Commit 136bba65 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Generation of reference cubes now works in multiprocessing. Added tqdm to dependencies.


Former-commit-id: 6a996afc
parent 633f5097
...@@ -11,9 +11,10 @@ import scipy as sp ...@@ -11,9 +11,10 @@ import scipy as sp
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from sklearn.cluster import KMeans from sklearn.cluster import KMeans
from pandas import DataFrame from pandas import DataFrame
from typing import Union # noqa F401 # flake8 issue from typing import Union, List # noqa F401 # flake8 issue
from tqdm import tqdm from tqdm import tqdm
from multiprocessing import Pool from multiprocessing import Pool, cpu_count
import tempfile
from sklearn.cluster import k_means_ # noqa F401 # flake8 issue from sklearn.cluster import k_means_ # noqa F401 # flake8 issue
from geoarray import GeoArray # noqa F401 # flake8 issue from geoarray import GeoArray # noqa F401 # flake8 issue
...@@ -325,20 +326,9 @@ class KMeansRSImage(object): ...@@ -325,20 +326,9 @@ class KMeansRSImage(object):
return random_samples return random_samples
global_shared_ref_cube = None
def mp_initializer_SpecHomo(ref_cube):
"""Declare global_shared_ref_cube needed for SpecHomo_Classifier.generate_reference_cube().
:param ref_cube: reference cube to be shared between multiprocessing workers
"""
global global_shared_ref_cube
global_shared_ref_cube = ref_cube
class SpecHomo_Classifier(object): class SpecHomo_Classifier(object):
def __init__(self, filelist_refs, v=False, logger=None, CPUs=None): def __init__(self, filelist_refs, v=False, logger=None, CPUs=None):
# type: (List[str], bool, logging.Logger, Union[None, int]) -> None
""" """
:param filelist_refs: list of reference images :param filelist_refs: list of reference images
...@@ -347,7 +337,8 @@ class SpecHomo_Classifier(object): ...@@ -347,7 +337,8 @@ class SpecHomo_Classifier(object):
self.ref_cube = None self.ref_cube = None
self.v = v self.v = v
self.logger = logger or GMS_logger(__name__) # must be pickable self.logger = logger or GMS_logger(__name__) # must be pickable
self.CPUs = CPUs self.CPUs = CPUs or cpu_count()
self.tmpdir_multiproc = ''
def generate_reference_cube(self, tgt_satellite, tgt_sensor, n_clusters=10, tgt_n_samples=1000, path_out='', def generate_reference_cube(self, tgt_satellite, tgt_sensor, n_clusters=10, tgt_n_samples=1000, path_out='',
fmt_out='ENVI'): fmt_out='ENVI'):
...@@ -377,22 +368,34 @@ class SpecHomo_Classifier(object): ...@@ -377,22 +368,34 @@ class SpecHomo_Classifier(object):
if self.v: if self.v:
tgt_srf.plot_srfs() tgt_srf.plot_srfs()
# Build the reference cube from the random samples of each image # Build the reference cube from random samples of each image
# => rows: tgt_n_samples, columns: images, bands: spectral information # => rows: tgt_n_samples, columns: images, bands: spectral information
self.ref_cube = np.zeros((tgt_n_samples, len(self.ims_ref), len(tgt_srf.bands))) # filled in pool
# generate random spectra samples equally for each KMeans cluster # generate random spectra samples equally for each KMeans cluster
args = [(im_n, im, tgt_srf, n_clusters, tgt_n_samples) for im_n, im in enumerate(self.ims_ref)] args = [(im, tgt_srf, n_clusters, tgt_n_samples) for im in self.ims_ref]
if self.CPUs > 1:
processes = len(self.ims_ref) if self.CPUs > len(self.ims_ref) else self.CPUs
with tempfile.TemporaryDirectory() as tmpdir:
self.tmpdir_multiproc = tmpdir
if self.CPUs is None or self.CPUs > 1: with Pool(processes) as pool:
with Pool(processes=len(self.ims_ref), initializer=mp_initializer_SpecHomo, pool.starmap(self._get_uniform_random_samples, args)
initargs=(self.ref_cube,)) as pool:
pool.starmap(self._get_uniform_random_samples, args) # combine temporarily saved random samples to ref_cube
self.ref_cube = global_shared_ref_cube self.logger.info('Combining random samples to reference cube...')
self.ref_cube = np.zeros((tgt_n_samples, len(self.ims_ref), len(tgt_srf.bands)))
for im_n, im in enumerate(self.ims_ref):
path_randsampl = os.path.join(tmpdir, 'random_samples', os.path.basename(im))
self.logger.info('Adding content of %s to reference cube...' % im)
self.ref_cube[:, im_n, :] = GeoArray(path_randsampl)[:]
else: else:
for argset in args: self.ref_cube = np.zeros((tgt_n_samples, len(self.ims_ref), len(tgt_srf.bands)))
self._get_uniform_random_samples(*argset) for im_n, argset in enumerate(args):
# combine returned random samples to ref_cube
random_samples = self._get_uniform_random_samples(*argset)
self.logger.info('Adding random samples for %s to reference cube...' % argset[0])
self.ref_cube[:, im_n, :] = random_samples
# save # save
if path_out: if path_out:
...@@ -400,14 +403,13 @@ class SpecHomo_Classifier(object): ...@@ -400,14 +403,13 @@ class SpecHomo_Classifier(object):
return self.ref_cube return self.ref_cube
def _get_uniform_random_samples(self, im_n, im_ref, tgt_srf, n_clusters, tgt_n_samples): def _get_uniform_random_samples(self, im_ref, tgt_srf, n_clusters, tgt_n_samples):
im_name = os.path.basename(im_ref) im_name = os.path.basename(im_ref)
# read input image # read input image
self.logger.info('Reading the input image %s...' % im_name) self.logger.info('Reading the input image %s...' % im_name)
im_gA = GeoArray(im_ref) im_gA = GeoArray(im_ref)
im_gA.cwl = np.array(im_gA.meta.loc['wavelength'], dtype=np.float).flatten() im_gA.cwl = np.array(im_gA.meta.loc['wavelength'], dtype=np.float).flatten()
# im_gA.to_mem()
wvl_unit = 'nanometers' if max(im_gA.cwl) > 15 else 'micrometers' wvl_unit = 'nanometers' if max(im_gA.cwl) > 15 else 'micrometers'
# perform spectral resampling of input image to match spectral properties of target sensor # perform spectral resampling of input image to match spectral properties of target sensor
...@@ -415,7 +417,8 @@ class SpecHomo_Classifier(object): ...@@ -415,7 +417,8 @@ class SpecHomo_Classifier(object):
SR = SpectralResampler(im_gA.cwl, tgt_srf, wvl_unit=wvl_unit) SR = SpectralResampler(im_gA.cwl, tgt_srf, wvl_unit=wvl_unit)
im_tgt = np.empty((*im_gA.shape[:2], len(tgt_srf.bands))) im_tgt = np.empty((*im_gA.shape[:2], len(tgt_srf.bands)))
for ((rS, rE), (cS, cE)), tiledata in tqdm(im_gA.tiles((1000, 1000))): tiles = im_gA.tiles((1000, 1000))
for ((rS, rE), (cS, cE)), tiledata in (tqdm(tiles) if self.CPUs == 1 else tiles):
im_tgt[rS: rE + 1, cS: cE + 1, :] = SR.resample_image(tiledata) im_tgt[rS: rE + 1, cS: cE + 1, :] = SR.resample_image(tiledata)
im_tgt = GeoArray(im_tgt, im_gA.gt, im_gA.prj) im_tgt = GeoArray(im_tgt, im_gA.gt, im_gA.prj)
...@@ -434,11 +437,8 @@ class SpecHomo_Classifier(object): ...@@ -434,11 +437,8 @@ class SpecHomo_Classifier(object):
# combine the spectra (2D arrays) of all clusters to a single 2D array # combine the spectra (2D arrays) of all clusters to a single 2D array
random_samples = np.vstack([random_samples[clusterlabel] for clusterlabel in random_samples]) random_samples = np.vstack([random_samples[clusterlabel] for clusterlabel in random_samples])
# reshape it so that we have the spectral information in 3rd dimension (x rows, 1 column and z bands) # return random samples or cache them on disk in multiprocessing
#random_samples = random_samples.reshape((random_samples.shape[0], 1, random_samples.shape[1])) if self.CPUs > 1:
GeoArray(random_samples, nodata=-9999).save(os.path.join(self.tmpdir_multiproc, 'random_samples', im_name))
# copy resulting random samples into self.ref_cube
if self.CPUs is None or self.CPUs > 1:
global_shared_ref_cube[:, im_n, :] = random_samples
else: else:
self.ref_cube[:, im_n, :] = random_samples return random_samples
...@@ -25,3 +25,4 @@ sqlalchemy ...@@ -25,3 +25,4 @@ sqlalchemy
psycopg2 psycopg2
# fmask # not pip installable # fmask # not pip installable
six six
tqdm
...@@ -11,3 +11,4 @@ pyinstrument ...@@ -11,3 +11,4 @@ pyinstrument
geoalchemy2 geoalchemy2
sqlalchemy sqlalchemy
six six
tqdm
...@@ -14,7 +14,7 @@ with open('HISTORY.rst') as history_file: ...@@ -14,7 +14,7 @@ with open('HISTORY.rst') as history_file:
requirements = [ requirements = [
'matplotlib', 'numpy', 'scikit-learn', 'scipy', 'gdal', 'pyproj', 'shapely', 'ephem', 'pyorbital', 'dill', 'pytz', 'matplotlib', 'numpy', 'scikit-learn', 'scipy', 'gdal', 'pyproj', 'shapely', 'ephem', 'pyorbital', 'dill', 'pytz',
'pandas', 'numba', 'spectral>=0.16', 'geopandas', 'iso8601', 'pyinstrument', 'geoalchemy2', 'sqlalchemy', 'pandas', 'numba', 'spectral>=0.16', 'geopandas', 'iso8601', 'pyinstrument', 'geoalchemy2', 'sqlalchemy',
'psycopg2', 'py_tools_ds>=0.10.0', 'geoarray>=0.7.1', 'arosics>=0.6.6', 'six' 'psycopg2', 'py_tools_ds>=0.10.0', 'geoarray>=0.7.1', 'arosics>=0.6.6', 'six', 'tqdm'
# spectral<0.16 has some problems with writing signed integer 8bit data # spectral<0.16 has some problems with writing signed integer 8bit data
# fmask # conda install -c conda-forge python-fmask # fmask # conda install -c conda-forge python-fmask
# 'pyhdf', # conda install --yes -c conda-forge pyhdf # 'pyhdf', # conda install --yes -c conda-forge pyhdf
......
...@@ -16,10 +16,10 @@ RUN /bin/bash -i -c "source /root/anaconda3/bin/activate ; \ ...@@ -16,10 +16,10 @@ RUN /bin/bash -i -c "source /root/anaconda3/bin/activate ; \
conda install --yes -c conda-forge numpy gdal scikit-image scikit-learn matplotlib pyproj rasterio shapely basemap \ conda install --yes -c conda-forge numpy gdal scikit-image scikit-learn matplotlib pyproj rasterio shapely basemap \
pykrige glymur pygrib pyproj cachetools pyhdf ephem python-fmask scipy ; \ pykrige glymur pygrib pyproj cachetools pyhdf ephem python-fmask scipy ; \
conda install --yes -c conda-forge 'icu=58.*' lxml ; \ conda install --yes -c conda-forge 'icu=58.*' lxml ; \
pip install pandas geopandas dicttoxml jsmin cerberus pyprind pint iso8601 tqdm mpld3 sphinx-argparse dill pytz \ pip install pandas geopandas dicttoxml jsmin cerberus pyprind pint iso8601 mpld3 sphinx-argparse dill pytz \
spectral>0.16 psycopg2 pyorbital pyinstrument geoalchemy2 sqlalchemy py_tools_ds>=0.10.0 \ spectral>0.16 psycopg2 pyorbital pyinstrument geoalchemy2 sqlalchemy py_tools_ds>=0.10.0 \
geoarray>=0.7.1 arosics>=0.6.6 flake8 pycodestyle pylint pydocstyle nose nose2 nose-htmloutput \ geoarray>=0.7.1 arosics>=0.6.6 flake8 pycodestyle pylint pydocstyle nose nose2 nose-htmloutput \
coverage rednose six" # must include all the requirements needed to build the docs! coverage rednose six tqdm" # must include all the requirements needed to build the docs!
# copy some needed stuff to /root # copy some needed stuff to /root
#COPY *.pkl /root/ # EXAMPLE #COPY *.pkl /root/ # EXAMPLE
......
...@@ -33,10 +33,12 @@ class Test_SpecHomo_Classifier(unittest.TestCase): ...@@ -33,10 +33,12 @@ class Test_SpecHomo_Classifier(unittest.TestCase):
def test_generate_reference_cube_L8(self): def test_generate_reference_cube_L8(self):
ref_cube = self.SHC.generate_reference_cube('Landsat-8', 'OLI_TIRS', n_clusters=10, tgt_n_samples=1000) ref_cube = self.SHC.generate_reference_cube('Landsat-8', 'OLI_TIRS', n_clusters=10, tgt_n_samples=1000)
self.assertIsInstance(ref_cube, np.ndarray) self.assertIsInstance(ref_cube, np.ndarray)
self.assertTrue(np.any(ref_cube), msg='Reference cube for Landsat-8 is empty.')
def test_generate_reference_cube_S2A(self): def test_generate_reference_cube_S2A(self):
ref_cube = self.SHC.generate_reference_cube('Sentinel-2A', 'MSI', n_clusters=10, tgt_n_samples=1000) ref_cube = self.SHC.generate_reference_cube('Sentinel-2A', 'MSI', n_clusters=10, tgt_n_samples=1000)
self.assertIsInstance(ref_cube, np.ndarray) self.assertIsInstance(ref_cube, np.ndarray)
self.assertTrue(np.any(ref_cube), msg='Reference cube for Sentinel-2A is empty.')
@unittest.SkipTest @unittest.SkipTest
def test_generate_reference_cube_AST(self): def test_generate_reference_cube_AST(self):
...@@ -50,5 +52,5 @@ class Test_SpecHomo_Classifier(unittest.TestCase): ...@@ -50,5 +52,5 @@ class Test_SpecHomo_Classifier(unittest.TestCase):
SHC = SpecHomo_Classifier([testdata, testdata, ], v=False, CPUs=None) SHC = SpecHomo_Classifier([testdata, testdata, ], v=False, CPUs=None)
ref_cube_mp = SHC.generate_reference_cube('Landsat-8', 'OLI_TIRS', n_clusters=10, tgt_n_samples=1000) ref_cube_mp = SHC.generate_reference_cube('Landsat-8', 'OLI_TIRS', n_clusters=10, tgt_n_samples=1000)
self.assertTrue(np.array_equal(ref_cube_sp, ref_cube_mp), self.assertTrue(np.any(ref_cube_sp), msg='Singleprocessing result is empty.')
msg='Singleprocessing result is not equal to multiprocessing result.') self.assertTrue(np.any(ref_cube_mp), msg='Multiprocessing result is empty.')
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment