From adc1d6bf8ee8e6ca3738ec7724a0129d5252598a Mon Sep 17 00:00:00 2001
From: Daniel Scheffler <danschef@gfz-potsdam.de>
Date: Tue, 5 Mar 2024 15:16:08 +0100
Subject: [PATCH] Replace interp1d in dead_pixel_correction.py.

Signed-off-by: Daniel Scheffler <danschef@gfz-potsdam.de>
---
 .../dead_pixel_correction.py                  | 39 +++++++++----------
 tests/test_dead_pixel_correction.py           |  8 ----
 2 files changed, 19 insertions(+), 28 deletions(-)

diff --git a/enpt/processors/dead_pixel_correction/dead_pixel_correction.py b/enpt/processors/dead_pixel_correction/dead_pixel_correction.py
index 13ffb740..f75a4443 100644
--- a/enpt/processors/dead_pixel_correction/dead_pixel_correction.py
+++ b/enpt/processors/dead_pixel_correction/dead_pixel_correction.py
@@ -38,7 +38,7 @@ import logging
 import numpy as np
 import numpy_indexed as npi
 from multiprocessing import Pool, cpu_count
-from scipy.interpolate import griddata, interp1d
+from scipy.interpolate import griddata, make_interp_spline
 from pandas import DataFrame
 from geoarray import GeoArray
 
@@ -70,8 +70,7 @@ class Dead_Pixel_Corrector(object):
         :param algorithm:           algorithm how to correct dead pixels
                                     'spectral': interpolate in the spectral domain
                                     'spatial':  interpolate in the spatial domain
-        :param interp_spectral:     spectral interpolation algorithm
-                                    (‘linear’, ‘nearest’, ‘zero’, ‘slinear’, ‘quadratic’, ‘cubic’, etc.)
+        :param interp_spectral:     spectral interpolation algorithm (‘linear’, ‘quadratic’, ‘cubic’)
         :param interp_spatial:      spatial interpolation algorithm ('linear', 'bilinear', 'cubic', 'spline')
         :param CPUs:                number of CPUs to use for interpolation (only relevant if algorithm = 'spatial')
         :param logger:
@@ -106,7 +105,7 @@ class Dead_Pixel_Corrector(object):
             raise ValueError("Dead pixel map and image to be corrected must have equal shape.")
 
         image_corrected = interp_nodata_along_axis(image2correct, axis=2, nodata=deadpixel_map[:],
-                                                   method=self.interp_alg_spectral, fill_value='extrapolate')
+                                                   method=self.interp_alg_spectral)
 
         return image_corrected
 
@@ -136,7 +135,7 @@ class Dead_Pixel_Corrector(object):
         # correct remaining nodata by spectral interpolation (e.g., outermost columns)
         if np.isnan(image2correct).any():
             image2correct = interp_nodata_along_axis(image2correct, axis=2, nodata=np.isnan(image2correct),
-                                                     method=self.interp_alg_spectral, fill_value='extrapolate')
+                                                     method=self.interp_alg_spectral)
 
         return image2correct
 
@@ -205,21 +204,24 @@ def interp_nodata_along_axis_2d(data_2d: np.ndarray,
                                 axis: int = 0,
                                 nodata: Union[np.ndarray, Number] = np.nan,
                                 method: str = 'linear',
-                                fill_value: Union[float, str] = 'extrapolate'):
-    """Interpolate a 2D array along the given axis (based on scipy.interpolate.interp1d).
+                                **kw):
+    """Interpolate a 2D array along the given axis (based on scipy.interpolate.make_interp_spline).
 
     :param data_2d:         data to interpolate
     :param axis:            axis to interpolate (0: along columns; 1: along rows)
     :param nodata:          nodata array in the shape of data or nodata value
-    :param method:          interpolation method (‘linear’, ‘nearest’, ‘zero’, ‘slinear’, ‘quadratic’, ‘cubic’, etc.)
-    :param fill_value:      value to fill into positions where no interpolation is possible
-                            - if 'extrapolate': extrapolate the missing values
+    :param method:          interpolation method (‘linear’, ‘quadratic’, ‘cubic’)
+    :param kw:              keyword arguments to be passed to scipy.interpolate.make_interp_spline
     :return:    interpolated array
     """
     if data_2d.ndim != 2:
         raise ValueError('Expected a 2D array. Received a %dD array.' % data_2d.ndim)
     if axis > data_2d.ndim:
         raise ValueError("axis=%d is out of bounds for data with %d dimensions." % (axis, data_2d.ndim))
+    if method not in ['linear', 'quadratic', 'cubic']:
+        raise ValueError(f"'{method}' is not a valid interpolation method. "
+                         f"Choose between 'linear', 'quadratic', and 'cubic'.")
+    degree = 1 if method == 'linear' else 2 if method == 'quadratic' else 3
 
     data_2d = data_2d if axis == 1 else data_2d.T
 
@@ -240,10 +242,8 @@ def interp_nodata_along_axis_2d(data_2d: np.ndarray,
 
             if goodpos.size > 1:
                 data_2d_grouped_rows = data_2d[indices_unique_rows]
-
                 data_2d_grouped_rows[:, badpos] = \
-                    interp1d(goodpos, data_2d_grouped_rows[:, goodpos],
-                             axis=1, kind=method, fill_value=fill_value, bounds_error=False)(badpos)
+                    make_interp_spline(goodpos, data_2d_grouped_rows[:, goodpos], axis=1, k=degree, **kw)(badpos)
 
                 data_2d[indices_unique_rows, :] = data_2d_grouped_rows
 
@@ -254,15 +254,14 @@ def interp_nodata_along_axis(data,
                              axis=0,
                              nodata: Union[np.ndarray, Number] = np.nan,
                              method: str = 'linear',
-                             fill_value: Union[float, str] = 'extrapolate'):
-    """Interpolate a 2D or 3D array along the given axis (based on scipy.interpolate.interp1d).
+                             **kw):
+    """Interpolate a 2D or 3D array along the given axis (based on scipy.interpolate.make_interp_spline).
 
     :param data:            data to interpolate
     :param axis:            axis to interpolate (0: along columns; 1: along rows, 2: along bands)
     :param nodata:          nodata array in the shape of data or nodata value
-    :param method:          interpolation method (‘linear’, ‘nearest’, ‘zero’, ‘slinear’, ‘quadratic’, ‘cubic’, etc.)
-    :param fill_value:      value to fill into positions where no interpolation is possible
-                            - if 'extrapolate': extrapolate the missing values
+    :param method:          interpolation method (‘linear’, 'quadratic', 'cubic')
+    :param kw:              keyword arguments to be passed to scipy.interpolate.make_interp_spline
     :return:    interpolated array
     """
     assert axis <= 2
@@ -272,7 +271,7 @@ def interp_nodata_along_axis(data,
         raise ValueError('No-data mask and data must have the same shape.')
 
     if data.ndim == 2:
-        return interp_nodata_along_axis_2d(data, axis=axis, nodata=nodata, method=method, fill_value=fill_value)
+        return interp_nodata_along_axis_2d(data, axis=axis, nodata=nodata, method=method, **kw)
 
     else:
         def reshape_input(In):
@@ -293,7 +292,7 @@ def interp_nodata_along_axis(data,
                     data_2d=reshape_input(data),
                     nodata=reshape_input(nodata) if isinstance(nodata, np.ndarray) else nodata,
                     axis=axis if axis != 2 else 1,
-                    method=method, fill_value=fill_value))
+                    method=method, **kw))
 
 
 def interp_nodata_spatially_2d(data_2d: np.ndarray,
diff --git a/tests/test_dead_pixel_correction.py b/tests/test_dead_pixel_correction.py
index 75eb2691..3c82d331 100644
--- a/tests/test_dead_pixel_correction.py
+++ b/tests/test_dead_pixel_correction.py
@@ -126,10 +126,6 @@ class Test_interp_nodata_along_axis_2d(TestCase):
         data_int = interp_nodata_along_axis_2d(self.get_data2d(), axis=0, nodata=mask_nodata, method='linear')
         assert np.array_equal(data_int, arr_exp), 'Computed %s.' % data_int
 
-        data_int = interp_nodata_along_axis_2d(self.get_data2d(), axis=0, method='linear', fill_value=-1)
-        arr_exp = np.array([[0, 0, 2], [3, 5, 5], [-1, 10, 8]])
-        assert np.array_equal(data_int, arr_exp), 'Computed %s.' % data_int
-
     def test_axis_1(self):
         data_int = interp_nodata_along_axis_2d(self.get_data2d(), axis=1, method='linear')
         arr_exp = np.array([[0, 0, 2], [3, 4, 5], [12, 10, 8]])
@@ -139,10 +135,6 @@ class Test_interp_nodata_along_axis_2d(TestCase):
         data_int = interp_nodata_along_axis_2d(self.get_data2d(), axis=1, nodata=mask_nodata, method='linear')
         assert np.array_equal(data_int, arr_exp), 'Computed %s.' % data_int
 
-        data_int = interp_nodata_along_axis_2d(self.get_data2d(), axis=1, method='linear', fill_value=-1)
-        arr_exp = np.array([[0, 0, 2], [3, 4, 5], [-1, 10, 8]])
-        assert np.array_equal(data_int, arr_exp), 'Computed %s.' % data_int
-
     def test_bad_args(self):
         with pytest.raises(ValueError):
             interp_nodata_along_axis_2d(self.get_data2d(), axis=3)
-- 
GitLab