Commit e9ff3586 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Generation of reference cubes now works in multiprocessing. Added tqdm to dependencies.

Former-commit-id: 6a996afc
Former-commit-id: 136bba65
parent 44c65532
......@@ -11,9 +11,10 @@ import scipy as sp
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from pandas import DataFrame
from typing import Union # noqa F401 # flake8 issue
from typing import Union, List # noqa F401 # flake8 issue
from tqdm import tqdm
from multiprocessing import Pool
from multiprocessing import Pool, cpu_count
import tempfile
from sklearn.cluster import k_means_ # noqa F401 # flake8 issue
from geoarray import GeoArray # noqa F401 # flake8 issue
......@@ -325,20 +326,9 @@ class KMeansRSImage(object):
return random_samples
global_shared_ref_cube = None
def mp_initializer_SpecHomo(ref_cube):
"""Declare global_shared_ref_cube needed for SpecHomo_Classifier.generate_reference_cube().
:param ref_cube: reference cube to be shared between multiprocessing workers
global global_shared_ref_cube
global_shared_ref_cube = ref_cube
class SpecHomo_Classifier(object):
def __init__(self, filelist_refs, v=False, logger=None, CPUs=None):
# type: (List[str], bool, logging.Logger, Union[None, int]) -> None
:param filelist_refs: list of reference images
......@@ -347,7 +337,8 @@ class SpecHomo_Classifier(object):
self.ref_cube = None
self.v = v
self.logger = logger or GMS_logger(__name__) # must be pickable
self.CPUs = CPUs
self.CPUs = CPUs or cpu_count()
self.tmpdir_multiproc = ''
def generate_reference_cube(self, tgt_satellite, tgt_sensor, n_clusters=10, tgt_n_samples=1000, path_out='',
......@@ -377,22 +368,34 @@ class SpecHomo_Classifier(object):
if self.v:
# Build the reference cube from the random samples of each image
# Build the reference cube from random samples of each image
# => rows: tgt_n_samples, columns: images, bands: spectral information
self.ref_cube = np.zeros((tgt_n_samples, len(self.ims_ref), len(tgt_srf.bands))) # filled in pool
# generate random spectra samples equally for each KMeans cluster
args = [(im_n, im, tgt_srf, n_clusters, tgt_n_samples) for im_n, im in enumerate(self.ims_ref)]
args = [(im, tgt_srf, n_clusters, tgt_n_samples) for im in self.ims_ref]
if self.CPUs > 1:
processes = len(self.ims_ref) if self.CPUs > len(self.ims_ref) else self.CPUs
with tempfile.TemporaryDirectory() as tmpdir:
self.tmpdir_multiproc = tmpdir
if self.CPUs is None or self.CPUs > 1:
with Pool(processes=len(self.ims_ref), initializer=mp_initializer_SpecHomo,
initargs=(self.ref_cube,)) as pool:
pool.starmap(self._get_uniform_random_samples, args)
self.ref_cube = global_shared_ref_cube
with Pool(processes) as pool:
pool.starmap(self._get_uniform_random_samples, args)
# combine temporarily saved random samples to ref_cube'Combining random samples to reference cube...')
self.ref_cube = np.zeros((tgt_n_samples, len(self.ims_ref), len(tgt_srf.bands)))
for im_n, im in enumerate(self.ims_ref):
path_randsampl = os.path.join(tmpdir, 'random_samples', os.path.basename(im))'Adding content of %s to reference cube...' % im)
self.ref_cube[:, im_n, :] = GeoArray(path_randsampl)[:]
for argset in args:
self.ref_cube = np.zeros((tgt_n_samples, len(self.ims_ref), len(tgt_srf.bands)))
for im_n, argset in enumerate(args):
# combine returned random samples to ref_cube
random_samples = self._get_uniform_random_samples(*argset)'Adding random samples for %s to reference cube...' % argset[0])
self.ref_cube[:, im_n, :] = random_samples
# save
if path_out:
......@@ -400,14 +403,13 @@ class SpecHomo_Classifier(object):
return self.ref_cube
def _get_uniform_random_samples(self, im_n, im_ref, tgt_srf, n_clusters, tgt_n_samples):
def _get_uniform_random_samples(self, im_ref, tgt_srf, n_clusters, tgt_n_samples):
im_name = os.path.basename(im_ref)
# read input image'Reading the input image %s...' % im_name)
im_gA = GeoArray(im_ref)
im_gA.cwl = np.array(im_gA.meta.loc['wavelength'], dtype=np.float).flatten()
# im_gA.to_mem()
wvl_unit = 'nanometers' if max(im_gA.cwl) > 15 else 'micrometers'
# perform spectral resampling of input image to match spectral properties of target sensor
......@@ -415,7 +417,8 @@ class SpecHomo_Classifier(object):
SR = SpectralResampler(im_gA.cwl, tgt_srf, wvl_unit=wvl_unit)
im_tgt = np.empty((*im_gA.shape[:2], len(tgt_srf.bands)))
for ((rS, rE), (cS, cE)), tiledata in tqdm(im_gA.tiles((1000, 1000))):
tiles = im_gA.tiles((1000, 1000))
for ((rS, rE), (cS, cE)), tiledata in (tqdm(tiles) if self.CPUs == 1 else tiles):
im_tgt[rS: rE + 1, cS: cE + 1, :] = SR.resample_image(tiledata)
im_tgt = GeoArray(im_tgt,, im_gA.prj)
......@@ -434,11 +437,8 @@ class SpecHomo_Classifier(object):
# combine the spectra (2D arrays) of all clusters to a single 2D array
random_samples = np.vstack([random_samples[clusterlabel] for clusterlabel in random_samples])
# reshape it so that we have the spectral information in 3rd dimension (x rows, 1 column and z bands)
#random_samples = random_samples.reshape((random_samples.shape[0], 1, random_samples.shape[1]))
# copy resulting random samples into self.ref_cube
if self.CPUs is None or self.CPUs > 1:
global_shared_ref_cube[:, im_n, :] = random_samples
# return random samples or cache them on disk in multiprocessing
if self.CPUs > 1:
GeoArray(random_samples, nodata=-9999).save(os.path.join(self.tmpdir_multiproc, 'random_samples', im_name))
self.ref_cube[:, im_n, :] = random_samples
return random_samples
......@@ -25,3 +25,4 @@ sqlalchemy
# fmask # not pip installable
......@@ -11,3 +11,4 @@ pyinstrument
......@@ -14,7 +14,7 @@ with open('HISTORY.rst') as history_file:
requirements = [
'matplotlib', 'numpy', 'scikit-learn', 'scipy', 'gdal', 'pyproj', 'shapely', 'ephem', 'pyorbital', 'dill', 'pytz',
'pandas', 'numba', 'spectral>=0.16', 'geopandas', 'iso8601', 'pyinstrument', 'geoalchemy2', 'sqlalchemy',
'psycopg2', 'py_tools_ds>=0.10.0', 'geoarray>=0.7.1', 'arosics>=0.6.6', 'six'
'psycopg2', 'py_tools_ds>=0.10.0', 'geoarray>=0.7.1', 'arosics>=0.6.6', 'six', 'tqdm'
# spectral<0.16 has some problems with writing signed integer 8bit data
# fmask # conda install -c conda-forge python-fmask
# 'pyhdf', # conda install --yes -c conda-forge pyhdf
......@@ -16,10 +16,10 @@ RUN /bin/bash -i -c "source /root/anaconda3/bin/activate ; \
conda install --yes -c conda-forge numpy gdal scikit-image scikit-learn matplotlib pyproj rasterio shapely basemap \
pykrige glymur pygrib pyproj cachetools pyhdf ephem python-fmask scipy ; \
conda install --yes -c conda-forge 'icu=58.*' lxml ; \
pip install pandas geopandas dicttoxml jsmin cerberus pyprind pint iso8601 tqdm mpld3 sphinx-argparse dill pytz \
pip install pandas geopandas dicttoxml jsmin cerberus pyprind pint iso8601 mpld3 sphinx-argparse dill pytz \
spectral>0.16 psycopg2 pyorbital pyinstrument geoalchemy2 sqlalchemy py_tools_ds>=0.10.0 \
geoarray>=0.7.1 arosics>=0.6.6 flake8 pycodestyle pylint pydocstyle nose nose2 nose-htmloutput \
coverage rednose six" # must include all the requirements needed to build the docs!
coverage rednose six tqdm" # must include all the requirements needed to build the docs!
# copy some needed stuff to /root
#COPY *.pkl /root/ # EXAMPLE
......@@ -33,10 +33,12 @@ class Test_SpecHomo_Classifier(unittest.TestCase):
def test_generate_reference_cube_L8(self):
ref_cube = self.SHC.generate_reference_cube('Landsat-8', 'OLI_TIRS', n_clusters=10, tgt_n_samples=1000)
self.assertIsInstance(ref_cube, np.ndarray)
self.assertTrue(np.any(ref_cube), msg='Reference cube for Landsat-8 is empty.')
def test_generate_reference_cube_S2A(self):
ref_cube = self.SHC.generate_reference_cube('Sentinel-2A', 'MSI', n_clusters=10, tgt_n_samples=1000)
self.assertIsInstance(ref_cube, np.ndarray)
self.assertTrue(np.any(ref_cube), msg='Reference cube for Sentinel-2A is empty.')
def test_generate_reference_cube_AST(self):
......@@ -50,5 +52,5 @@ class Test_SpecHomo_Classifier(unittest.TestCase):
SHC = SpecHomo_Classifier([testdata, testdata, ], v=False, CPUs=None)
ref_cube_mp = SHC.generate_reference_cube('Landsat-8', 'OLI_TIRS', n_clusters=10, tgt_n_samples=1000)
self.assertTrue(np.array_equal(ref_cube_sp, ref_cube_mp),
msg='Singleprocessing result is not equal to multiprocessing result.')
self.assertTrue(np.any(ref_cube_sp), msg='Singleprocessing result is empty.')
self.assertTrue(np.any(ref_cube_mp), msg='Multiprocessing result is empty.')
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment