diff --git a/enpt/processors/dead_pixel_correction/dead_pixel_correction.py b/enpt/processors/dead_pixel_correction/dead_pixel_correction.py index 13ffb7401ed3e2d27384a418c766d7d2b15623e0..f75a4443c0eff11f5a28763cdcdf3f13863776cc 100644 --- a/enpt/processors/dead_pixel_correction/dead_pixel_correction.py +++ b/enpt/processors/dead_pixel_correction/dead_pixel_correction.py @@ -38,7 +38,7 @@ import logging import numpy as np import numpy_indexed as npi from multiprocessing import Pool, cpu_count -from scipy.interpolate import griddata, interp1d +from scipy.interpolate import griddata, make_interp_spline from pandas import DataFrame from geoarray import GeoArray @@ -70,8 +70,7 @@ class Dead_Pixel_Corrector(object): :param algorithm: algorithm how to correct dead pixels 'spectral': interpolate in the spectral domain 'spatial': interpolate in the spatial domain - :param interp_spectral: spectral interpolation algorithm - (‘linear’, ‘nearest’, ‘zero’, ‘slinear’, ‘quadratic’, ‘cubic’, etc.) + :param interp_spectral: spectral interpolation algorithm (‘linear’, ‘quadratic’, ‘cubic’) :param interp_spatial: spatial interpolation algorithm ('linear', 'bilinear', 'cubic', 'spline') :param CPUs: number of CPUs to use for interpolation (only relevant if algorithm = 'spatial') :param logger: @@ -106,7 +105,7 @@ class Dead_Pixel_Corrector(object): raise ValueError("Dead pixel map and image to be corrected must have equal shape.") image_corrected = interp_nodata_along_axis(image2correct, axis=2, nodata=deadpixel_map[:], - method=self.interp_alg_spectral, fill_value='extrapolate') + method=self.interp_alg_spectral) return image_corrected @@ -136,7 +135,7 @@ class Dead_Pixel_Corrector(object): # correct remaining nodata by spectral interpolation (e.g., outermost columns) if np.isnan(image2correct).any(): image2correct = interp_nodata_along_axis(image2correct, axis=2, nodata=np.isnan(image2correct), - method=self.interp_alg_spectral, fill_value='extrapolate') + method=self.interp_alg_spectral) return image2correct @@ -205,21 +204,24 @@ def interp_nodata_along_axis_2d(data_2d: np.ndarray, axis: int = 0, nodata: Union[np.ndarray, Number] = np.nan, method: str = 'linear', - fill_value: Union[float, str] = 'extrapolate'): - """Interpolate a 2D array along the given axis (based on scipy.interpolate.interp1d). + **kw): + """Interpolate a 2D array along the given axis (based on scipy.interpolate.make_interp_spline). :param data_2d: data to interpolate :param axis: axis to interpolate (0: along columns; 1: along rows) :param nodata: nodata array in the shape of data or nodata value - :param method: interpolation method (‘linear’, ‘nearest’, ‘zero’, ‘slinear’, ‘quadratic’, ‘cubic’, etc.) - :param fill_value: value to fill into positions where no interpolation is possible - - if 'extrapolate': extrapolate the missing values + :param method: interpolation method (‘linear’, ‘quadratic’, ‘cubic’) + :param kw: keyword arguments to be passed to scipy.interpolate.make_interp_spline :return: interpolated array """ if data_2d.ndim != 2: raise ValueError('Expected a 2D array. Received a %dD array.' % data_2d.ndim) if axis > data_2d.ndim: raise ValueError("axis=%d is out of bounds for data with %d dimensions." % (axis, data_2d.ndim)) + if method not in ['linear', 'quadratic', 'cubic']: + raise ValueError(f"'{method}' is not a valid interpolation method. " + f"Choose between 'linear', 'quadratic', and 'cubic'.") + degree = 1 if method == 'linear' else 2 if method == 'quadratic' else 3 data_2d = data_2d if axis == 1 else data_2d.T @@ -240,10 +242,8 @@ def interp_nodata_along_axis_2d(data_2d: np.ndarray, if goodpos.size > 1: data_2d_grouped_rows = data_2d[indices_unique_rows] - data_2d_grouped_rows[:, badpos] = \ - interp1d(goodpos, data_2d_grouped_rows[:, goodpos], - axis=1, kind=method, fill_value=fill_value, bounds_error=False)(badpos) + make_interp_spline(goodpos, data_2d_grouped_rows[:, goodpos], axis=1, k=degree, **kw)(badpos) data_2d[indices_unique_rows, :] = data_2d_grouped_rows @@ -254,15 +254,14 @@ def interp_nodata_along_axis(data, axis=0, nodata: Union[np.ndarray, Number] = np.nan, method: str = 'linear', - fill_value: Union[float, str] = 'extrapolate'): - """Interpolate a 2D or 3D array along the given axis (based on scipy.interpolate.interp1d). + **kw): + """Interpolate a 2D or 3D array along the given axis (based on scipy.interpolate.make_interp_spline). :param data: data to interpolate :param axis: axis to interpolate (0: along columns; 1: along rows, 2: along bands) :param nodata: nodata array in the shape of data or nodata value - :param method: interpolation method (‘linear’, ‘nearest’, ‘zero’, ‘slinear’, ‘quadratic’, ‘cubic’, etc.) - :param fill_value: value to fill into positions where no interpolation is possible - - if 'extrapolate': extrapolate the missing values + :param method: interpolation method (‘linear’, 'quadratic', 'cubic') + :param kw: keyword arguments to be passed to scipy.interpolate.make_interp_spline :return: interpolated array """ assert axis <= 2 @@ -272,7 +271,7 @@ def interp_nodata_along_axis(data, raise ValueError('No-data mask and data must have the same shape.') if data.ndim == 2: - return interp_nodata_along_axis_2d(data, axis=axis, nodata=nodata, method=method, fill_value=fill_value) + return interp_nodata_along_axis_2d(data, axis=axis, nodata=nodata, method=method, **kw) else: def reshape_input(In): @@ -293,7 +292,7 @@ def interp_nodata_along_axis(data, data_2d=reshape_input(data), nodata=reshape_input(nodata) if isinstance(nodata, np.ndarray) else nodata, axis=axis if axis != 2 else 1, - method=method, fill_value=fill_value)) + method=method, **kw)) def interp_nodata_spatially_2d(data_2d: np.ndarray, diff --git a/tests/test_dead_pixel_correction.py b/tests/test_dead_pixel_correction.py index 75eb2691fd8b4f2625c1183013020e9224c1a8e0..3c82d331d667c8c848a1c11222cd1b5560f91e94 100644 --- a/tests/test_dead_pixel_correction.py +++ b/tests/test_dead_pixel_correction.py @@ -126,10 +126,6 @@ class Test_interp_nodata_along_axis_2d(TestCase): data_int = interp_nodata_along_axis_2d(self.get_data2d(), axis=0, nodata=mask_nodata, method='linear') assert np.array_equal(data_int, arr_exp), 'Computed %s.' % data_int - data_int = interp_nodata_along_axis_2d(self.get_data2d(), axis=0, method='linear', fill_value=-1) - arr_exp = np.array([[0, 0, 2], [3, 5, 5], [-1, 10, 8]]) - assert np.array_equal(data_int, arr_exp), 'Computed %s.' % data_int - def test_axis_1(self): data_int = interp_nodata_along_axis_2d(self.get_data2d(), axis=1, method='linear') arr_exp = np.array([[0, 0, 2], [3, 4, 5], [12, 10, 8]]) @@ -139,10 +135,6 @@ class Test_interp_nodata_along_axis_2d(TestCase): data_int = interp_nodata_along_axis_2d(self.get_data2d(), axis=1, nodata=mask_nodata, method='linear') assert np.array_equal(data_int, arr_exp), 'Computed %s.' % data_int - data_int = interp_nodata_along_axis_2d(self.get_data2d(), axis=1, method='linear', fill_value=-1) - arr_exp = np.array([[0, 0, 2], [3, 4, 5], [-1, 10, 8]]) - assert np.array_equal(data_int, arr_exp), 'Computed %s.' % data_int - def test_bad_args(self): with pytest.raises(ValueError): interp_nodata_along_axis_2d(self.get_data2d(), axis=3)