Commit 8b532373 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Added RefCube class and placeholders for Classifier_Trainer, SpecHomo_Classifier.

parent a8542b88
Pipeline #1698 failed with stage
in 7 minutes and 56 seconds
......@@ -17,6 +17,8 @@ from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from glob import glob
import re
import json
from collections import OrderedDict
from sklearn.cluster import k_means_ # noqa F401 # flake8 issue
from sklearn.model_selection import train_test_split
......@@ -413,7 +415,7 @@ class RidgeRegression_RSImage(_MachineLearner_RSImage):
raise NotImplementedError()
super(RidgeRegression_RSImage, self).__init__(im_X, im_Y, test_size=test_size)
self.ridgeRegressor = LinearRegression().fit(self.train_X, self.train_Y)
self.ridgeRegressor = RidgeClassifier().fit(self.train_X, self.train_Y)
@property
def coefficients_(self):
......@@ -607,7 +609,8 @@ class ReferenceCube_Generator(object):
self.CPUs = CPUs or cpu_count()
# privates
self._refcubes = {sat_sen: None for sat_sen in self.tgt_sat_sen_list}
self._refcubes = \
{sat_sen: RefCube(satellite=sat_sen[0], sensor=sat_sen[1]) for sat_sen in self.tgt_sat_sen_list}
# validation
if dir_refcubes and not os.path.isdir(self.dir_refcubes):
......@@ -615,7 +618,7 @@ class ReferenceCube_Generator(object):
@property
def refcubes(self):
# type: () -> Dict[Tuple[str]: GeoArray]
# type: () -> Dict[Tuple[str, str]: RefCube]
if not self._refcubes:
# fill self._ref_cubes with GeoArray instances of already existing reference cubes read from disk
......@@ -632,7 +635,7 @@ class ReferenceCube_Generator(object):
# import the existing ref cube if it matches the target refcube specs
if correct_specs:
self._refcubes[(sat, sen)] = GeoArray(path_refcube)
self._refcubes[(sat, sen)] = RefCube(satellite=sat, sensor=sen, filepath=path_refcube)
return self._refcubes
......@@ -671,15 +674,14 @@ class ReferenceCube_Generator(object):
tgt_srf=tgt_srf)
# add the spectra as GeoArray instance to the in-mem ref cubes
self.add_spectra_to_refcube(unif_random_spectra_rsp,
tgt_sat_sen=(tgt_sat, tgt_sen), src_imname=src_im.basename)
refcube = self.refcubes[(tgt_sat, tgt_sen)] # type: RefCube
refcube.add_spectra(unif_random_spectra_rsp, src_imname=src_im.basename)
# update the reference cubes on disk
if self.dir_refcubes:
path_out = os.path.join(self.dir_refcubes, 'refcube__%s__%s__nclust%s__nsamp%s.bsq'
% (tgt_sat, tgt_sen, self.n_clusters, self.tgt_n_samples))
updated_refcube = self.refcubes[(tgt_sat, tgt_sen)] # type: GeoArray
updated_refcube.save(out_path=path_out, fmt=fmt_out)
refcube.save(path_out=os.path.join(self.dir_refcubes, 'refcube__%s__%s__nclust%s__nsamp%s.bsq'
% (tgt_sat, tgt_sen, self.n_clusters, self.tgt_n_samples)),
fmt=fmt_out)
return self.refcubes
......@@ -779,24 +781,126 @@ class ReferenceCube_Generator(object):
return tgt_im
def add_spectra_to_refcube(self, spectra, tgt_sat_sen, src_imname):
# type: (np.ndarray, tuple, str) -> None
"""Add a set of spectral signatures to the reference cube (in-mem variable self.refcubes).
:param spectra: 2D numpy array with rows: spectral samples / columns: spectral information (bands)
:param tgt_sat_sen: tuple of ('satellite', 'sensor')
:param src_imname: image basename of the source hyperspectral image
class RefCube(object):
def __init__(self, filepath='', satellite='', sensor=''):
# type: (str, str, str) -> None
# privates
self._col_imName_dict = dict()
# defaults
self.data = GeoArray(np.empty((0, 0, 0))) # type: GeoArray
self.srcImNames = []
# args/ kwargs
self.filepath = filepath
self.satellite = satellite
self.sensor = sensor
if filepath:
self.from_filepath(filepath)
@property
def n_images(self):
return self.data.shape[1]
@property
def n_signatures(self):
return self.data.shape[0]
@property
def col_imName_dict(self):
# type: () -> OrderedDict
return OrderedDict((col, imName) for col, imName in zip(range(self.n_images), self.srcImNames))
@col_imName_dict.setter
def col_imName_dict(self, col_imName_dict):
# type: (dict) -> None
self._col_imName_dict = col_imName_dict
self.srcImNames = list(col_imName_dict.values())
def add_refcube_array(self, refcube_array, src_imnames):
# type: (Union[str, np.ndarray], list) -> None
if self.data.size:
new_cube = np.hstack([self.data, refcube_array])
self.data = GeoArray(new_cube)
else:
self.data = GeoArray(refcube_array)
self.srcImNames.extend(src_imnames)
def add_spectra(self, spectra, src_imname):
# type: (np.ndarray, str) -> None
"""
:param spectra: 2D numpy array with rows: spectral samples / columns: spectral information (bands)
:param src_imname: image basename of the source hyperspectral image
"""
# reshape 2D spectra array to one image column (refcube is an image with spectral information in the 3rd dim.)
im_col = spectra.reshape(spectra.shape[0], 1, spectra.shape[1])
if self.refcubes[tgt_sat_sen] is not None:
# append spectra to existing reference cube
new_cube = np.hstack([self.refcubes[tgt_sat_sen], im_col])
self.refcubes[tgt_sat_sen] = GeoArray(new_cube)
if self.data.size:
# validation
if spectra.shape[0] != self.data.shape[0]:
raise ValueError('The number of signatures in the given spectra array does not match the dimensions of '
'the reference cube.')
# TODO add src_imname to GeoArray metadata
# append spectra to existing reference cube
new_cube = np.hstack([self.data, im_col])
self.data = GeoArray(new_cube)
else:
# add a new GeoArray instance containing given spectra as a single image line
self.refcubes[tgt_sat_sen] = GeoArray(im_col)
self.data = GeoArray(im_col)
# add source image name to list of image names
self.srcImNames.append(src_imname)
@property
def metadata(self):
attrs2include = ['satellite', 'sensor', 'filepath', 'n_signatures', 'n_images', 'col_imName_dict']
return OrderedDict((k, getattr(self, k)) for k in attrs2include)
def from_filepath(self, filepath):
self.data = GeoArray(filepath)
with open(os.path.splitext(filepath)[0] + '.meta', 'r') as metaF:
meta = json.load(metaF)
for k, v in meta.items():
if k in ['n_signatures', 'n_images']:
continue # skip pure getters
else:
setattr(self, k, v)
def save(self, path_out, fmt='ENVI'):
# type: (str, str) -> None
self.data.save(out_path=path_out, fmt=fmt)
# save metadata as JSON file
meta2write = self.metadata.copy()
meta2write['filepath'] = self.filepath or path_out
with open(os.path.splitext(path_out)[0] + '.meta', 'w') as metaF:
json.dump(meta2write, metaF, separators=(',', ': '), indent=4)
class Classifier_Trainer(object):
def __init__(self, src_sensor_refcube, tgt_sensor_refcube):
# type: (RefCube, RefCube) -> None
"""
:param src_sensor_refcube: file path of reference cube of source sensor
:param tgt_sensor_refcube: file path of reference cube of target sensor
"""
self.src_cube = src_sensor_refcube
self.tgt_cube = tgt_sensor_refcube
def plot_scattermatrix(self):
pass
class SpecHomo_Classifier(object):
def __init__(self):
pass
def from_refcubes(self, dir_refcubes):
pass
......@@ -18,6 +18,7 @@ from gms_preprocessing import __path__ # noqa E402 module level import not at t
from gms_preprocessing import set_config # noqa E402 module level import not at top of file
from gms_preprocessing.algorithms.L2B_P import ReferenceCube_Generator_OLD # noqa E402 module level import not at top of file
from gms_preprocessing.algorithms.L2B_P import ReferenceCube_Generator # noqa E402 module level import not at top of file
from gms_preprocessing.algorithms.L2B_P import RefCube # noqa E402 module level import not at top of file
testdata = os.path.join(__path__[0], '../tests/data/hy_spec_data/Bavaria_farmland_LMU_Hyspex_subset.bsq')
......@@ -112,8 +113,8 @@ class Test_ReferenceCube_Generator(unittest.TestCase):
def test_generate_reference_cube(self):
refcubes = self.SHC.generate_reference_cubes()
self.assertIsInstance(refcubes, dict)
self.assertIsInstance(refcubes[('Landsat-8', 'OLI_TIRS')], GeoArray)
self.assertEqual(refcubes[('Landsat-8', 'OLI_TIRS')].shape, (self.tgt_n_samples, len(self.testIms), 8))
self.assertIsInstance(refcubes[('Landsat-8', 'OLI_TIRS')], RefCube)
self.assertEqual(refcubes[('Landsat-8', 'OLI_TIRS')].data.shape, (self.tgt_n_samples, len(self.testIms), 8))
@unittest.SkipTest
def test_multiprocessing(self):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment