Commit 0dbca01c authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Implemented (not working) workflow to share output array of...

Implemented (not working) workflow to share output array of generate_reference_cube between multiprocessing workers.
parent 06125776
Pipeline #1493 failed with stage
in 9 minutes and 57 seconds
......@@ -325,15 +325,29 @@ class KMeansRSImage(object):
return random_samples
global_shared_ref_cube = None
def mp_initializer_SpecHomo(ref_cube):
"""Declare global_shared_ref_cube needed for SpecHomo_Classifier.generate_reference_cube().
:param ref_cube: reference cube to be shared between multiprocessing workers
global global_shared_ref_cube
global_shared_ref_cube = ref_cube
class SpecHomo_Classifier(object):
def __init__(self, filelist_refs, v=False, logger=None):
def __init__(self, filelist_refs, v=False, logger=None, CPUs=None):
:param filelist_refs: list of reference images
self.ims_ref = filelist_refs
self.ref_cube = None
self.v = v
self.logger = logger or GMS_logger(__name__) # must be pickable
self.CPUs = CPUs
def generate_reference_cube(self, tgt_satellite, tgt_sensor, n_clusters=10, tgt_n_samples=1000, path_out='',
......@@ -363,23 +377,30 @@ class SpecHomo_Classifier(object):
if self.v:
# generate random spectra samples equally for each KMeans cluster
with Pool(processes=len(self.ims_ref)) as pool:
args = [(im, tgt_srf, n_clusters, tgt_n_samples) for im in self.ims_ref]
random_samples_all_ims_list = pool.starmap(self._get_uniform_random_samples, args)
# Build the reference cube from the random samples of each image
# => rows: tgt_n_samples, columns: images, bands: spectral information'Combining randomly sampled spectra to a single reference cube...')
ref_cube = np.hstack(random_samples_all_ims_list)
self.ref_cube = np.zeros((tgt_n_samples, len(self.ims_ref), len(tgt_srf.bands))) # filled in pool
# generate random spectra samples equally for each KMeans cluster
args = [(im_n, im, tgt_srf, n_clusters, tgt_n_samples) for im_n, im in enumerate(self.ims_ref)]
if self.CPUs is None or self.CPUs > 1:
with Pool(processes=len(self.ims_ref), initializer=mp_initializer_SpecHomo,
initargs=(self.ref_cube,)) as pool:
pool.starmap(self._get_uniform_random_samples, args)
self.ref_cube = global_shared_ref_cube
for argset in args:
# save
if path_out:
GeoArray(ref_cube).save(out_path=path_out, fmt=fmt_out)
GeoArray(self.ref_cube).save(out_path=path_out, fmt=fmt_out)
return ref_cube
return self.ref_cube
def _get_uniform_random_samples(self, im_ref, tgt_srf, n_clusters, tgt_n_samples):
def _get_uniform_random_samples(self, im_n, im_ref, tgt_srf, n_clusters, tgt_n_samples):
im_name = os.path.basename(im_ref)
# read input image
......@@ -414,6 +435,10 @@ class SpecHomo_Classifier(object):
random_samples = np.vstack([random_samples[clusterlabel] for clusterlabel in random_samples])
# reshape it so that we have the spectral information in 3rd dimension (x rows, 1 column and z bands)
random_samples = random_samples.reshape((random_samples.shape[0], 1, random_samples.shape[1]))
#random_samples = random_samples.reshape((random_samples.shape[0], 1, random_samples.shape[1]))
return random_samples
# copy resulting random samples into self.ref_cube
if self.CPUs is None or self.CPUs > 1:
global_shared_ref_cube[:, im_n, :] = random_samples
self.ref_cube[:, im_n, :] = random_samples
......@@ -37,3 +37,18 @@ class Test_SpecHomo_Classifier(unittest.TestCase):
def test_generate_reference_cube_S2A(self):
ref_cube = self.SHC.generate_reference_cube('Sentinel-2A', 'MSI', n_clusters=10, tgt_n_samples=1000)
self.assertIsInstance(ref_cube, np.ndarray)
def test_generate_reference_cube_AST(self):
ref_cube = self.SHC.generate_reference_cube('Terra', 'ASTER', n_clusters=10, tgt_n_samples=1000)
self.assertIsInstance(ref_cube, np.ndarray)
def test_multiprocessing(self):
SHC = SpecHomo_Classifier([testdata, testdata, ], v=False, CPUs=1)
ref_cube_sp = SHC.generate_reference_cube('Landsat-8', 'OLI_TIRS', n_clusters=10, tgt_n_samples=1000)
SHC = SpecHomo_Classifier([testdata, testdata, ], v=False, CPUs=None)
ref_cube_mp = SHC.generate_reference_cube('Landsat-8', 'OLI_TIRS', n_clusters=10, tgt_n_samples=1000)
self.assertTrue(np.array_equal(ref_cube_sp, ref_cube_mp),
msg='Singleprocessing result is not equal to multiprocessing result.')
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment