Commit 465a5063 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Fix for returning an array with the wrong number of bands if...


Fix for returning an array with the wrong number of bands if vswir_overlap_algorithm='swir_only'. Added class Test_VNIR_SWIR_Stacker. Added further input validation to VNIR_SWIR_Stacker.
Signed-off-by: Daniel Scheffler's avatarDaniel Scheffler <danschef@gfz-potsdam.de>
parent 7dbbe6c8
Pipeline #7271 passed with stage
in 49 minutes and 14 seconds
......@@ -32,10 +32,11 @@ EnPT module 'orthorectification' for transforming an EnMAP image from sensor to
based on a pixel- and band-wise coordinate-layer (geolayer).
"""
from .orthorectification import Orthorectifier
from .orthorectification import Orthorectifier, VNIR_SWIR_Stacker
__author__ = 'Daniel Scheffler'
__all__ = [
'__author__',
'Orthorectifier'
'Orthorectifier',
'VNIR_SWIR_Stacker'
]
......@@ -210,6 +210,12 @@ class VNIR_SWIR_Stacker(object):
raise ValueError((self.vnir.gt, self.swir.gt), 'VNIR and SWIR geoinformation should be equal.')
if not prj_equal(self.vnir.prj, self.swir.prj):
raise ValueError((self.vnir.prj, self.swir.prj), 'VNIR and SWIR projection should be equal.')
if self.vnir.bands != len(self.wvls.vnir):
raise ValueError("The number of VNIR bands must be equal to the number of elements in 'vnir_wvls': "
"%d != %d" % (self.vnir.bands, len(self.wvls.vnir)))
if self.swir.bands != len(self.wvls.swir):
raise ValueError("The number of SWIR bands must be equal to the number of elements in 'swir_wvls': "
"%d != %d" % (self.swir.bands, len(self.wvls.swir)))
def _get_stack_order_by_wvl(self) -> Tuple[np.ndarray, np.ndarray]:
"""Stack bands ordered by wavelengths."""
......@@ -259,7 +265,7 @@ class VNIR_SWIR_Stacker(object):
wvls_vswir_sorted = np.hstack([wvls_vnir_cut, self.wvls.swir])
idx_vnir_lastband = np.argmin(np.abs(self.wvls.vnir - wvls_vnir_cut.max()))
return np.dstack([self.vnir[:, :, :idx_vnir_lastband], self.swir[:]]), wvls_vswir_sorted
return np.dstack([self.vnir[:, :, :idx_vnir_lastband + 1], self.swir[:]]), wvls_vswir_sorted
def compute_stack(self, algorithm: str) -> GeoArray:
"""Stack VNIR and SWIR bands with respect to their spectral overlap.
......
......@@ -40,9 +40,13 @@ from unittest import TestCase
from zipfile import ZipFile
import tempfile
import shutil
from copy import deepcopy
import numpy as np
from geoarray import GeoArray
from py_tools_ds.geo.projection import EPSG2WKT
from enpt.processors.orthorectification import Orthorectifier
from enpt.processors.orthorectification import Orthorectifier, VNIR_SWIR_Stacker
from enpt.options.config import config_for_testing, config_for_testing_dlr, EnPTConfig
from enpt.io.reader import L1B_Reader
from enpt.model.images import EnMAPL2Product_MapGeo
......@@ -52,7 +56,7 @@ __author__ = 'Daniel Scheffler'
class Test_Orthorectifier(TestCase):
def setUp(self):
self.config = EnPTConfig(**config_for_testing) # FIXME still the Alpine dataset
self.config = EnPTConfig(**config_for_testing)
# create a temporary directory
# NOTE: This must exist during the whole runtime of Test_Orthorectifier, otherwise
......@@ -117,3 +121,71 @@ class Test_Orthorectifier_DLR(TestCase):
np.mean(L2_obj.data[:, :, 0][L2_obj.data[:, :, 0] != L2_obj.data.nodata]),
rtol=0.01
))
class Test_VNIR_SWIR_Stacker(TestCase):
def setUp(self):
self.vnir_gA = GeoArray(np.random.randint(0, 255, (10, 10, 10)),
geotransform=(331185.0, 30.0, -0.0, 5840115.0, -0.0, -30.0),
projection=EPSG2WKT(32633))
self.swir_gA = GeoArray(np.random.randint(0, 255, (10, 10, 20)),
geotransform=(331185.0, 30.0, -0.0, 5840115.0, -0.0, -30.0),
projection=EPSG2WKT(32633))
self.vnir_wvls = np.arange(900, 1000, 10)
self.swir_wvls = np.arange(935, 1135, 10)
self.VSSt = VNIR_SWIR_Stacker(vnir=self.vnir_gA, swir=self.swir_gA,
vnir_wvls=self.vnir_wvls, swir_wvls=self.swir_wvls)
def test_validate_input(self):
# unequal geotransform
swir_gA = deepcopy(self.swir_gA)
swir_gA.gt = (331185.0, 10.0, -0.0, 5840115.0, -0.0, -10.0)
with self.assertRaises(ValueError):
VNIR_SWIR_Stacker(vnir=self.vnir_gA, swir=swir_gA,
vnir_wvls=self.vnir_wvls, swir_wvls=self.swir_wvls)
# unequal projection
swir_gA = deepcopy(self.swir_gA)
swir_gA.prj = EPSG2WKT(32632)
with self.assertRaises(ValueError):
VNIR_SWIR_Stacker(vnir=self.vnir_gA, swir=swir_gA,
vnir_wvls=self.vnir_wvls, swir_wvls=self.swir_wvls)
# wrong length of provided wavelength
with self.assertRaises(ValueError):
VNIR_SWIR_Stacker(vnir=self.vnir_gA, swir=self.swir_gA,
vnir_wvls=np.array(list(self.vnir_wvls) + [1]), swir_wvls=self.swir_wvls)
with self.assertRaises(ValueError):
VNIR_SWIR_Stacker(vnir=self.vnir_gA, swir=self.swir_gA,
vnir_wvls=self.vnir_wvls, swir_wvls=np.array(list(self.swir_wvls) + [1]))
def validate_output(self, gA_stacked: GeoArray):
self.assertIsInstance(gA_stacked, GeoArray)
self.assertEqual(gA_stacked.gt, self.vnir_gA.gt)
self.assertEqual(gA_stacked.prj, self.vnir_gA.prj)
self.assertEqual(gA_stacked.shape[:2], self.vnir_gA.shape[:2])
self.assertTrue('wavelength' in gA_stacked.meta.band_meta and
gA_stacked.meta.band_meta['wavelength'])
self.assertEqual(gA_stacked.bands, len(gA_stacked.meta.band_meta['wavelength']))
def test_get_stack_order_by_wvl(self):
gA_stacked = self.VSSt.compute_stack(algorithm='order_by_wvl')
self.validate_output(gA_stacked)
def test_get_stack_average(self):
gA_stacked = self.VSSt.compute_stack(algorithm='average')
self.validate_output(gA_stacked)
def test_get_stack_vnir_only(self):
gA_stacked = self.VSSt.compute_stack(algorithm='vnir_only')
self.validate_output(gA_stacked)
def test_get_stack_swir_only(self):
gA_stacked = self.VSSt.compute_stack(algorithm='swir_only')
self.validate_output(gA_stacked)
def test_compute_stack(self):
# wrong input algorithm
with self.assertRaises(ValueError):
self.VSSt.compute_stack(algorithm='mean')
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment