baseclasses.py 84.4 KB
Newer Older
1
2
# -*- coding: utf-8 -*-

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# geoarray, A fast Python interface for image geodata - either on disk or in memory.
#
# Copyright (C) 2019  Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de)
#
# This software was developed within the context of the GeoMultiSens project funded
# by the German Federal Ministry of Education and Research
# (project grant code: 01 IS 14 010 A-C).
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.


25
26
import os
import warnings
27
from pkgutil import find_loader
28
from collections import OrderedDict
29
from copy import copy, deepcopy
30
from typing import Union  # noqa F401
31
32

import numpy as np
33
from osgeo import gdal, gdal_array, gdalnumeric
34
35
36
37
from shapely.geometry import Polygon
from shapely.wkt import loads as shply_loads
# dill -> imported when dumping GeoArray

38
39
40
41
from py_tools_ds.convenience.object_oriented import alias_property
from py_tools_ds.geo.coord_calc import get_corner_coordinates
from py_tools_ds.geo.coord_grid import snap_bounds_to_pixGrid
from py_tools_ds.geo.coord_trafo import mapXY2imXY, imXY2mapXY, transform_any_prj, reproject_shapelyGeometry
42
from py_tools_ds.geo.projection import prj_equal, WKT2EPSG, EPSG2WKT, isLocal, CRS
43
44
from py_tools_ds.geo.raster.conversion import raster2polygon
from py_tools_ds.geo.vector.topology \
45
    import get_footprint_polygon, polyVertices_outside_poly, fill_holes_within_poly
46
47
from py_tools_ds.geo.vector.geometry import boxObj
from py_tools_ds.io.raster.gdal import get_GDAL_ds_inmem
48
from py_tools_ds.numeric.numbers import is_number
49
from py_tools_ds.numeric.array import get_array_tilebounds
50
51
52

#  internal imports
from .subsetting import get_array_at_mapPos
53
from .metadata import GDAL_Metadata
54
55

__author__ = 'Daniel Scheffler'
56
57
58


class GeoArray(object):
59
60
61
62
63
64
    """A class providing a fast Python interface for geodata - either on disk or in memory.

    GeoArray can be instanced with a file path or with a numpy array and the corresponding geoinformation. Instances
    can always be indexed and sliced like normal numpy arrays, no matter if it has been instanced from file or from an
    in-memory array. GeoArray provides a wide range of geo-related attributes belonging to the dataset as well as
    some functions for quickly visualizing the data as a map, a simple image or an interactive image.
Daniel Scheffler's avatar
Daniel Scheffler committed
65
    """
66

67
68
    def __init__(self, path_or_array, geotransform=None, projection=None, bandnames=None, nodata=None, progress=True,
                 q=False):
69
        # type: (Union[str, np.ndarray, GeoArray], tuple, str, list, float, bool, bool) -> None
Daniel Scheffler's avatar
Daniel Scheffler committed
70
        """Get an instance of GeoArray.
71
72
73
74
75
76
77
78
79
80
81
82

        :param path_or_array:   a numpy.ndarray or a valid file path
        :param geotransform:    GDAL geotransform of the given array or file on disk
        :param projection:      projection of the given array or file on disk as WKT string
                                (only needed if GeoArray is instanced with an array)
        :param bandnames:       names of the bands within the input array, e.g. ['mask_1bit', 'mask_clouds'],
                                (default: ['B1', 'B2', 'B3', ...])
        :param nodata:          nodata value
        :param progress:        show progress bars (default: True)
        :param q:               quiet mode (default: False)
        """
        if not (isinstance(path_or_array, (str, np.ndarray, GeoArray)) or
83
           issubclass(getattr(path_or_array, '__class__'), GeoArray)):
84
            raise ValueError("%s parameter 'arg' takes only string, np.ndarray or GeoArray(and subclass) instances. "
85
                             "Got %s." % (self.__class__.__name__, type(path_or_array)))
86
87

        if path_or_array is None:
88
            raise ValueError("The %s parameter 'path_or_array' must not be None!" % self.__class__.__name__)
89
90
91
92
93

        if isinstance(path_or_array, str):
            assert ' ' not in path_or_array, "The given path contains whitespaces. This is not supported by GDAL."

            if not os.path.exists(path_or_array):
94
                raise FileNotFoundError(path_or_array)
95

96
97
        if isinstance(path_or_array, GeoArray) or issubclass(getattr(path_or_array, '__class__'), GeoArray):
            self.__dict__ = path_or_array.__dict__.copy()
98
            self._initParams = dict([x for x in locals().items() if x[0] != "self"])
99
100
            self.geotransform = geotransform or self.geotransform
            self.projection = projection or self.projection
101
            self.bandnames = bandnames or list(self.bandnames.keys())
102
103
104
            self._nodata = nodata if nodata is not None else self._nodata
            self.progress = False if progress is False else self.progress
            self.q = q if q is not None else self.q
105
106

        else:
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
            self._initParams = dict([x for x in locals().items() if x[0] != "self"])
            self.arg = path_or_array
            self._arr = path_or_array if isinstance(path_or_array, np.ndarray) else None
            self.filePath = path_or_array if isinstance(path_or_array, str) and path_or_array else None
            self.basename = os.path.splitext(os.path.basename(self.filePath))[0] if not self.is_inmem else 'IN_MEM'
            self.progress = progress
            self.q = q
            self._arr_cache = None  # dict containing key 'pos' and 'arr_cached'
            self._geotransform = None
            self._projection = None
            self._shape = None
            self._dtype = None
            self._nodata = nodata
            self._mask_nodata = None
            self._mask_baddata = None
122
123
            self._footprint_poly = None
            self._gdalDataset_meta_already_set = False
124
125
            self._metadata = None
            self._bandnames = None
126
127

            if bandnames:
128
                self.bandnames = bandnames  # use property in order to validate given value
129
            if geotransform:
130
                self.geotransform = geotransform  # use property in order to validate given value
131
            if projection:
132
                self.projection = projection  # use property in order to validate given value
133
134
135
136

            if self.filePath:
                self.set_gdalDataset_meta()

137
138
139
140
141
            if 'nodata' in self._initParams and self._initParams['nodata'] is not None:
                self._validate_nodataVal()

    def _validate_nodataVal(self):
        """Check if a given nodata value is within the valid value range of the data type."""
142
143
144
145
146
147
148
149
150
151
        _nodata = self._initParams['nodata']

        if np.issubdtype(self.dtype, np.integer):
            dt_min, dt_max = np.iinfo(self.dtype).min, np.iinfo(self.dtype).max
        elif np.issubdtype(self.dtype, np.floating):
            dt_min, dt_max = np.finfo(self.dtype).min, np.finfo(self.dtype).max
        else:
            return

        if not dt_min <= _nodata <= dt_max:
152
            raise ValueError("The given no-data value (%s) is out range for data type %s."
153
                             % (self._initParams['nodata'], str(np.dtype(self.dtype))))
154

155
156
157
158
159
160
    @property
    def arr(self):
        return self._arr

    @arr.setter
    def arr(self, ndarray):
161
162
        assert isinstance(ndarray, np.ndarray), "'arr' can only be set to a numpy array! Got %s." % type(ndarray)
        # assert ndarray.shape == self.shape, "'arr' can only be set to a numpy array with shape %s. Received %s. " \
163
164
165
166
167
        #                                    "If you need to change the dimensions, create a new instance of %s." \
        #                                    %(self.shape, ndarray.shape, self.__class__.__name__)
        #  THIS would avoid warping like this: geoArr.arr, geoArr.gt, geoArr.prj = warp(...)

        if ndarray.shape != self.shape:
168
            self.flush_cache()  # the cached array is not useful anymore
169
170
171
172
173

        self._arr = ndarray

    @property
    def bandnames(self):
174
        if self._bandnames and len(self._bandnames) == self.bands:
175
176
            return self._bandnames
        else:
177
            del self.bandnames  # runs deleter which sets it to default values
178
179
180
181
182
183
184
            return self._bandnames

    @bandnames.setter
    def bandnames(self, list_bandnames):
        # type: (list) -> None

        if list_bandnames:
185
186
187
188
189
190
191
192
            if not isinstance(list_bandnames, list):
                raise TypeError("A list must be given when setting the 'bandnames' attribute. "
                                "Received %s." % type(list_bandnames))
            if len(list_bandnames) != self.bands:
                raise ValueError('Number of given bandnames does not match number of bands in array.')
            if len(list(set([type(b) for b in list_bandnames]))) != 1 or not isinstance(list_bandnames[0], str):
                raise ValueError("'bandnames must be a set of strings. Got other datatypes in there.'")

193
            bN_dict = OrderedDict((band, i) for i, band in enumerate(list_bandnames))
194
195

            if len(bN_dict) != self.bands:
196
                raise ValueError('Bands must have unique names. Received band list: %s' % list_bandnames)
197
198
199

            self._bandnames = bN_dict

200
            try:
201
                self.metadata.band_meta['band_names'] = list_bandnames
202
203
204
            except AttributeError:
                # in case self._metadata is None
                pass
205
206
207
208
209
210
211
        else:
            del self.bandnames

    @bandnames.deleter
    def bandnames(self):
        self._bandnames = OrderedDict(('B%s' % band, i) for i, band in enumerate(range(1, self.bands + 1)))
        if self._metadata is not None:
212
            self.metadata.band_meta['band_names'] = list(self._bandnames.keys())
213

214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
    @property
    def is_inmem(self):
        """Check if associated image array is completely loaded into memory."""
        return isinstance(self.arr, np.ndarray)

    @property
    def shape(self):
        """Get the array shape of the associated image array."""
        if self.is_inmem:
            return self.arr.shape
        else:
            if self._shape:
                return self._shape
            else:
                self.set_gdalDataset_meta()
                return self._shape

    @property
    def ndim(self):
        """Get the number dimensions of the associated image array."""
        return len(self.shape)

    @property
    def rows(self):
        """Get the number of rows of the associated image array."""
        return self.shape[0]

    @property
    def columns(self):
        """Get the number of columns of the associated image array."""
        return self.shape[1]

    cols = alias_property('columns')

    @property
    def bands(self):
        """Get the number of bands of the associated image array."""
251
        return self.shape[2] if len(self.shape) > 2 else 1
252
253
254
255
256
257
258
259
260
261
262
263
264
265

    @property
    def dtype(self):
        """Get the numpy data type of the associated image array."""
        if self._dtype:
            return self._dtype
        elif self.is_inmem:
            return self.arr.dtype
        else:
            self.set_gdalDataset_meta()
            return self._dtype

    @property
    def geotransform(self):
266
        """Get the GDAL GeoTransform of the associated image, e.g., (283500.0, 5.0, 0.0, 4464500.0, 0.0, -5.0)."""
267
268
269
270
271
272
        if self._geotransform:
            return self._geotransform
        elif not self.is_inmem:
            self.set_gdalDataset_meta()
            return self._geotransform
        else:
273
            return [0, 1, 0, 0, 0, -1]
274
275
276

    @geotransform.setter
    def geotransform(self, gt):
        # type: (Union[list, tuple]) -> None
        """Set the GDAL GeoTransform after checking that it is a list or tuple of 6 numbers.

        :param gt:  the geotransform to be set
        """
        has_valid_layout = isinstance(gt, (list, tuple)) and len(gt) == 6
        assert has_valid_layout, 'geotransform must be a list with 6 numbers. Got %s.' % str(gt)

        for element in gt:
            assert is_number(element), \
                "geotransform must contain only numbers. Got '%s' (type: %s)." % (element, type(element))

        self._geotransform = gt
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299

    gt = alias_property('geotransform')

    @property
    def xgsd(self):
        """Get the X resolution in units of the given or detected projection."""
        return self.geotransform[1]

    @property
    def ygsd(self):
        """Get the Y resolution in units of the given or detected projection."""
        return abs(self.geotransform[5])

    @property
    def xygrid_specs(self):
300
        """Get the specifications for the X/Y coordinate grid.
301

302
303
304
        This returns for example [[15,30], [0,30]] for a coordinate
        with its origin at X/Y[15,0] and a GSD of X/Y[15,30].
        """
305
        def get_grid(gt, xgsd, ygsd): return [[gt[0], gt[0] + xgsd], [gt[3], gt[3] - ygsd]]
306
307
308
309
        return get_grid(self.geotransform, self.xgsd, self.ygsd)

    @property
    def projection(self):
310
        """Get the projection of the associated image.
311

312
313
314
        Setting the projection is only allowed if GeoArray has been instanced from memory or the associated file on
        disk has no projection.
        """
315
316
317
318
        if self._projection:
            return self._projection
        elif not self.is_inmem:
            self.set_gdalDataset_meta()
319
            return self._projection  # or "LOCAL_CS[\"MAP\"]"
320
        else:
321
            return ''  # '"LOCAL_CS[\"MAP\"]"
322
323
324

    @projection.setter
    def projection(self, prj):
325
        # type: (str) -> None
326
        self._projection = prj
327
328
329
330
331

    prj = alias_property('projection')

    @property
    def epsg(self):
        # type: () -> int
        """Return the EPSG code derived from the projection of the GeoArray via WKT2EPSG."""
        wkt = self.projection
        return WKT2EPSG(wkt)

    @epsg.setter
    def epsg(self, epsg_code):
        # type: (int) -> None
        """Set the projection from the given EPSG code (converted via EPSG2WKT).

        :param epsg_code:  the EPSG code to be converted
        """
        wkt = EPSG2WKT(epsg_code)
        self.projection = wkt

    @property
    def box(self):
        """Return a boxObj built from the geotransform, the projection and the polygon of the image corners."""
        corner_coords = get_corner_coordinates(gt=self.geotransform, cols=self.columns, rows=self.rows)
        mapPoly = get_footprint_polygon(corner_coords)

        return boxObj(gt=self.geotransform, prj=self.projection, mapPoly=mapPoly)

346
347
348
    @property
    def is_map_geo(self):
        # type: () -> bool
349
        """Return 'True' if the image has a valid geoinformation with map instead of image coordinates."""
350
        return all([self.gt, list(self.gt) != [0, 1, 0, 0, 0, -1], self.prj])
351

352
353
    @property
    def nodata(self):
354
        """Get the nodata value of the GeoArray instance.
355

356
357
358
359
        If GeoArray has been instanced with a file path the metadata of the file on disk is checked for an existing
        nodata value. Otherwise (if no value is exlicitly given during object instanciation) an automatic detection
        based on 3x3 windows at each image corner is run that analyzes the mean and standard deviation of these windows.
        """
360
361
362
363
364
365
366
        if self._nodata is not None:
            return self._nodata
        else:
            # try to get nodata value from file
            if not self.is_inmem:
                self.set_gdalDataset_meta()
            if self._nodata is None:
367
                self.find_noDataVal()
368
369
370
371
372
373
                if self._nodata == 'ambiguous':
                    warnings.warn('Nodata value could not be clearly identified. It has been set to None.')
                    self._nodata = None
                else:
                    if self._nodata is not None and not self.q:
                        print("Automatically detected nodata value for %s '%s': %s"
374
                              % (self.__class__.__name__, self.basename, self._nodata))
375
376
377
378
            return self._nodata

    @nodata.setter
    def nodata(self, value):
379
        # type: (Union[int, None]) -> None
380
381
        self._nodata = value

382
383
384
        if self._metadata and value is not None:
            self.metadata.global_meta.update({'data ignore value': str(value)})

385
386
    @property
    def mask_nodata(self):
387
        """Get the nodata mask of the associated image array. It is generated based on all image bands."""
388
389
390
        if self._mask_nodata is not None:
            return self._mask_nodata
        else:
391
            self.calc_mask_nodata()  # sets self._mask_nodata
392
393
394
395
            return self._mask_nodata

    @mask_nodata.setter
    def mask_nodata(self, mask):
396
        """Set the bad data mask.
397
398
399
400
401

        :param mask:    Can be a file path, a numpy array or an instance o GeoArray.
        """
        if mask is not None:
            from .masks import NoDataMask
402
403
            geoArr_mask = NoDataMask(mask, progress=self.progress, q=self.q)
            geoArr_mask.gt = geoArr_mask.gt if geoArr_mask.gt not in [None, [0, 1, 0, 0, 0, -1]] else self.gt
404
            geoArr_mask.prj = geoArr_mask.prj if geoArr_mask.prj else self.prj
405
            imName = "the %s '%s'" % (self.__class__.__name__, self.basename)
406
407
408
409

            assert geoArr_mask.bands == 1, \
                'Expected one single band as nodata mask for %s. Got %s bands.' % (self.basename, geoArr_mask.bands)
            assert geoArr_mask.shape[:2] == self.shape[:2], 'The provided nodata mask must have the same number of ' \
410
                                                            'rows and columns as the %s itself.' % imName
411
412
            assert geoArr_mask.gt == self.gt, \
                'The geotransform of the given nodata mask for %s must match the geotransform of the %s itself. ' \
413
                'Got %s.' % (imName, self.__class__.__name__, geoArr_mask.gt)
414
415
            assert not geoArr_mask.prj or prj_equal(geoArr_mask.prj, self.prj), \
                'The projection of the given nodata mask for the %s must match the projection of the %s itself.' \
416
                % (imName, self.__class__.__name__)
417
418

            self._mask_nodata = geoArr_mask
419
420
421
422
423
424
        else:
            del self.mask_nodata

    @mask_nodata.deleter
    def mask_nodata(self):
        self._mask_nodata = None
425
426
427

    @property
    def mask_baddata(self):
428
        """Return the bad data mask.
429

430
431
        Note: The mask must be explicitly set to a file path or a numpy array before.
        """
432
433
434
435
436
437
438
439
440
441
        return self._mask_baddata

    @mask_baddata.setter
    def mask_baddata(self, mask):
        """Set bad data mask.

        :param mask:    Can be a file path, a numpy array or an instance o GeoArray.
        """
        if mask is not None:
            from .masks import BadDataMask
442
443
            geoArr_mask = BadDataMask(mask, progress=self.progress, q=self.q)
            geoArr_mask.gt = geoArr_mask.gt if geoArr_mask.gt not in [None, [0, 1, 0, 0, 0, -1]] else self.gt
444
            geoArr_mask.prj = geoArr_mask.prj if geoArr_mask.prj else self.prj
445
            imName = "the %s '%s'" % (self.__class__.__name__, self.basename)
446
447
448
449

            assert geoArr_mask.bands == 1, \
                'Expected one single band as bad data mask for %s. Got %s bands.' % (self.basename, geoArr_mask.bands)
            assert geoArr_mask.shape[:2] == self.shape[:2], 'The provided bad data mask must have the same number of ' \
450
                                                            'rows and columns as the %s itself.' % imName
451
452
            assert geoArr_mask.gt == self.gt, \
                'The geotransform of the given bad data mask for %s must match the geotransform of the %s itself. ' \
453
                'Got %s.' % (imName, self.__class__.__name__, geoArr_mask.gt)
454
455
            assert prj_equal(geoArr_mask.prj, self.prj), \
                'The projection of the given bad data mask for the %s must match the projection of the %s itself.' \
456
                % (imName, self.__class__.__name__)
457
458

            self._mask_baddata = geoArr_mask
459
460
461
462
463
464
        else:
            del self.mask_baddata

    @mask_baddata.deleter
    def mask_baddata(self):
        self._mask_baddata = None
465
466
467

    @property
    def footprint_poly(self):
468
        """Get the footprint polygon of the associated image array (shapely.geometry.Polygon)."""
469
470
471
        # FIXME should return polygon in image coordinates if no projection is available
        if self._footprint_poly is None:
            assert self.mask_nodata is not None, 'A nodata mask is needed for calculating the footprint polygon. '
472
            if False not in self.mask_nodata[:]:
473
474
475
476
                # do not run raster2polygon if whole image is filled with data
                self._footprint_poly = self.box.mapPoly
            else:
                try:
477
                    multipolygon = raster2polygon(self.mask_nodata.astype(np.uint8), self.gt, self.prj, exact=False,
478
                                                  progress=self.progress, q=self.q, maxfeatCount=10, timeout=15)
479
                    self._footprint_poly = fill_holes_within_poly(multipolygon)
480
                except (RuntimeError, TimeoutError):
481
482
483
484
                    if not self.q:
                        warnings.warn("\nCalculation of footprint polygon failed for %s '%s'. Using outer bounds. One "
                                      "reason could be that the nodata value appears within the actual image (not only "
                                      "as fill value). To avoid this use another nodata value. Current nodata value is "
485
                                      "%s." % (self.__class__.__name__, self.basename, self.nodata))
486
487
488
                    self._footprint_poly = self.box.mapPoly

            # validation
489
            assert not polyVertices_outside_poly(self._footprint_poly, self.box.mapPoly, tolerance=1e-5), \
490
491
492
                "Computing footprint polygon for %s '%s' failed. The resulting polygon is partly or completely " \
                "outside of the image bounds." % (self.__class__.__name__, self.basename)
            # assert self._footprint_poly
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
            # for XY in self.corner_coord:
            #    assert self.GeoArray.box.mapPoly.contains(Point(XY)) or self.GeoArray.box.mapPoly.touches(Point(XY)), \
            #        "The corner position '%s' is outside of the %s." % (XY, self.imName)

        return self._footprint_poly

    @footprint_poly.setter
    def footprint_poly(self, poly):
        """Set the footprint polygon.

        :param poly:        a shapely Polygon or a WKT string (which is parsed with shapely)
        :raises ValueError: if the given value is neither a shapely Polygon nor a string
        """
        if isinstance(poly, Polygon):
            parsed = poly
        elif isinstance(poly, str):
            parsed = shply_loads(poly)
        else:
            raise ValueError("'footprint_poly' can only be set from a shapely polygon or a WKT string.")

        self._footprint_poly = parsed

    @property
    def metadata(self):
510
511
        """Return a DataFrame containing all available metadata (read from file if available).

512
513
514
515
        Use 'metadata[band_index].to_dict()' to get a metadata dictionary for a specific band.
        Use 'metadata.loc[row_name].to_dict()' to get all metadata values of the same key for all bands as dictionary.
        Use 'metadata.loc[row_name, band_index] = value' to set a new value.

516
        :return:  pandas.DataFrame
517
518
519
520
        """
        if self._metadata is not None:
            return self._metadata
        else:
521
            default = GDAL_Metadata(nbands=self.bands, nodata_allbands=self._nodata)
522

523
524
525
526
527
528
529
530
            self._metadata = default
            if not self.is_inmem:
                self.set_gdalDataset_meta()
                return self._metadata
            else:
                return self._metadata

    @metadata.setter
    def metadata(self, meta):
        """Set the metadata object.

        :param meta:        an instance of GDAL_Metadata with the same number of bands as this instance
        :raises ValueError: if the given object is no GDAL_Metadata instance or has a different band number
        """
        is_valid = isinstance(meta, GDAL_Metadata) and meta.bands == self.bands

        if not is_valid:
            raise ValueError("%s.metadata can only be set with an instance of geoarray.metadata.GDAL_Metadata of "
                             "which the band number corresponds to the band number of %s."
                             % (self.__class__.__name__, self.__class__.__name__))

        self._metadata = meta
537
538
539
540

    meta = alias_property('metadata')

    def __getitem__(self, given):
541
        if isinstance(given, (int, float, slice, np.integer, np.floating)) and self.ndim == 3:
542
543
544
545
546
547
548
549
550
551
552
553
554
            # handle 'given' as index for 3rd (bands) dimension
            if self.is_inmem:
                return self.arr[:, :, given]
            else:
                return self.from_path(self.arg, [given])

        elif isinstance(given, str):
            # behave like a dictionary and return the corresponding band
            if self.bandnames:
                if given not in self.bandnames:
                    raise ValueError("'%s' is not a known band. Known bands are: %s"
                                     % (given, ', '.join(list(self.bandnames.keys()))))
                if self.is_inmem:
555
                    return self.arr if self.ndim == 2 else self.arr[:, :, self.bandnames[given]]
556
557
558
559
                else:
                    return self.from_path(self.arg, [self.bandnames[given]])
            else:
                raise ValueError('String indices are only supported if %s has been instanced with bandnames given.'
560
                                 % self.__class__.__name__)
561
562
563
564

        elif isinstance(given, (tuple, list)):
            # handle requests like geoArr[[1,2],[3,4]  -> not implemented in from_path if array is not in mem
            types = [type(i) for i in given]
Daniel Scheffler's avatar
Daniel Scheffler committed
565

566
            if list in types or tuple in types:
567
568
569
570
571
572
573
574

                # avoid that the whole cube is read if only data from a single band is requested
                if not self.is_inmem \
                   and len(given) == 3 \
                   and isinstance(given[2], (int, float, np.integer, np.floating)):
                    band_subset = GeoArray(self.filePath)[:, :, given[2]]
                    return band_subset[given[:2]]

575
576
                self.to_mem()

577
            if len(given) == 3:
578
579

                # handle strings in the 3rd dim of 'given' -> convert them to a band index
580
                if isinstance(given[2], str):
581
582
583
584
585
586
587
588
                    if self.bandnames:
                        if given[2] not in self.bandnames:
                            raise ValueError("'%s' is not a known band. Known bands are: %s"
                                             % (given[2], ', '.join(list(self.bandnames.keys()))))

                        band_idx = self.bandnames[given[2]]
                        # NOTE: the string in the 3rd is ignored if ndim==2 and band_idx==0
                        if self.is_inmem:
589
                            return self.arr if (self.ndim == 2 and band_idx == 0) else self.arr[:, :, band_idx]
590
                        else:
591
592
                            getitem_params = \
                                given[:2] if (self.ndim == 2 and band_idx == 0) else given[:2] + (band_idx,)
593
594
595
596
597
598
599
                            return self.from_path(self.arg, getitem_params)
                    else:
                        raise ValueError(
                            'String indices are only supported if %s has been instanced with bandnames given.'
                            % self.__class__.__name__)

                # in case a third dim is requested from 2D-array -> ignore 3rd dim if 3rd dim is 0
600
                elif self.ndim == 2 and given[2] == 0:
601
602
603
604
605
606
607
608
609
610
611
612
613
                    if self.is_inmem:
                        return self.arr[given[:2]]
                    else:
                        return self.from_path(self.arg, given[:2])

        # if nothing has been returned until here -> behave like a numpy array
        if self.is_inmem:
            return self.arr[given]
        else:
            getitem_params = [given] if isinstance(given, slice) else given
            return self.from_path(self.arg, getitem_params)

    def __setitem__(self, idx, array2set):
614
        """Overwrite the pixel values of GeoArray.arr with the given array.
615
616
617
618
619
620
621
622

        :param idx:         <int, list, slice> the index position to overwrite
        :param array2set:   <np.ndarray> array to be set. Must be compatible to the given index position.
        """
        if self.is_inmem:
            self.arr[idx] = array2set
        else:
            raise NotImplementedError('Item assignment for %s instances that are not in memory is not yet supported.'
623
                                      % self.__class__.__name__)
624
625
626

    def __getattr__(self, attr):
        # check if the requested attribute can not be present because GeoArray has been instanced with an array
627
628
        attrsNot2Link2np = ['__deepcopy__']   # attributes we don't want to inherit from numpy.ndarray

629
630
        if attr not in self.__dir__() and not self.is_inmem and attr in ['shape', 'dtype', 'geotransform',
                                                                         'projection']:
631
632
            self.set_gdalDataset_meta()

633
634
        if attr in self.__dir__():  # __dir__() includes also methods and properties
            return self.__getattribute__(attr)  # __getattribute__ avoids infinite loop
635
        elif attr not in attrsNot2Link2np and hasattr(np.array([]), attr):
636
637
            return self[:].__getattribute__(attr)
        else:
638
            raise AttributeError("%s object has no attribute '%s'." % (self.__class__.__name__, attr))
639
640

    def __getstate__(self):
        """Return the state to be pickled (e.g., by multiprocessing.Pool).

        The array cache is flushed first in order to avoid that it is pickled as well.
        """
        self.flush_cache()

        state = self.__dict__
        return state

    def __setstate__(self, state):
        """Restore the attributes of an unpickled GeoArray instance (e.g., when used by multiprocessing.Pool).

        NOTE: Without this method, pickled and unpickled instances show recursion errors within __getattr__ as soon
              as any attribute is requested.

        :param state:  the instance dictionary as returned by __getstate__
        """
        # the pickled dictionary directly replaces the instance dictionary
        self.__dict__ = state

655
656
    def calc_mask_nodata(self, fromBand=None, overwrite=False, flag='all'):
        # type: (int, bool, str) -> np.ndarray
        """Calculate a no data mask with values False (=nodata) and True (=data).

        The mask is only (re-)calculated if no mask has been set so far or if 'overwrite' is True. A calculated
        mask is assigned to self.mask_nodata.

        NOTE: In case of 3D arrays, bands that consist of nodata only are skipped for both flags, i.e., they do
              not contribute to the mask (unless ALL bands consist of nodata only -> mask is False everywhere).

        :param fromBand:   index of the band to be used (if None, all bands are used)
        :param overwrite:  whether to overwrite existing nodata mask that has already been calculated
        :param flag:       algorithm how to flag pixels (default: 'all')
                           'all': flag those pixels as nodata that contain the nodata value in ALL bands
                           'any': flag those pixels as nodata that contain the nodata value in ANY band
        :return:  the calculated 2D boolean mask or None if an already existing mask has been kept
        :raises ValueError:  if 'flag' is neither 'all' nor 'any'
        """
        if self._mask_nodata is None or overwrite:
            if flag not in ['all', 'any']:
                raise ValueError(flag)

            assert self.ndim in [2, 3], "Only 2D or 3D arrays are supported. Got a %sD array." % self.ndim
            arr = self[:, :, fromBand] if self.ndim == 3 and fromBand is not None else self[:]

            if self.nodata is None:
                # without a nodata value, every pixel is considered as data
                mask = np.ones((self.rows, self.cols), bool)

            else:
                # element-wise nodata test (True = nodata); np.nan needs a dedicated test because nan != nan
                # NOTE: Comparing band means with the nodata value (as done previously) is not sufficient because
                #       the mean of a band may equal the nodata value although the band contains valid data.
                nodatamask = np.isnan(arr) if np.isnan(self.nodata) else arr == self.nodata

                # True for each band that contains nothing but nodata (a scalar in case of 2D arrays)
                nodatabands = np.all(np.all(nodatamask, axis=0), axis=0)

                if np.all(nodatabands):
                    mask = np.full(arr.shape[:2], False)
                elif arr.ndim == 2:
                    mask = ~nodatamask
                else:
                    idx_databands = np.flatnonzero(~nodatabands)
                    nodata_remain = nodatamask[:, :, idx_databands[1:]]

                    mask = ~nodatamask[:, :, idx_databands[0]]  # True where 1st data band has data

                    if flag == 'all':
                        # ALL bands need to contain nodata to flag the mask as such
                        # overwrite the mask at nodata positions (False) with True in case there is data in ANY band
                        mask[~mask] = np.any(~nodata_remain[~mask], axis=1)
                    else:
                        # ANY band needs to contain nodata to flag the mask as such
                        # overwrite the mask at data positions (True) with False in case there is nodata in ANY band
                        mask[mask] = ~np.any(nodata_remain[mask], axis=1)

            self.mask_nodata = mask

            return mask
725

726
    def find_noDataVal(self, bandIdx=0, sz=3):
        """Try to derive the no data value from homogeneous corner windows of the image (3x3 pixels by default).

        A corner window is regarded as homogeneous if the standard deviation of its values is 0 or NaN. The result
        is assigned to self.nodata before it is returned.

        :param bandIdx:  index of the band to be analysed
        :param sz:       window size in which corner pixels are analysed
        :return:  - None if none of the corner windows is homogeneous (all corners are filled with data)
                  - the mean of the first homogeneous corner window if all homogeneous windows share the same value
                  - np.nan if at least one of the possible values is np.nan
                  - 'ambiguous' if the homogeneous corner windows contain different values
        """
        corner_windows = [
            self[0:sz, 0:sz, bandIdx],  # UL
            self[0:sz, -sz:, bandIdx],  # UR
            self[-sz:, -sz:, bandIdx],  # LR
            self[-sz:, 0:sz, bandIdx],  # LL
        ]

        win_means = [np.mean(window) for window in corner_windows]
        win_stds = [np.std(window) for window in corner_windows]

        # only homogeneous windows provide a candidate for the nodata value
        candidates = [win_mean for win_mean, win_std in zip(win_means, win_stds)
                      if win_std == 0 or np.isnan(win_std)]

        if not candidates:
            # all corners are filled with data
            nodata = None
        else:
            spread = np.std(candidates)

            if spread == 0:
                # nodata value clearly identified
                n_corners = len(candidates)

                if n_corners <= 2:
                    # the value has only been found in one or two corners
                    corners_str = '2 image corners' if n_corners == 2 else '1 image corner'
                    warnings.warn("\nAutomatic nodata value detection returned the value %s for GeoArray '%s' but this "
                                  "seems to be unreliable (occurs in only %s). To avoid automatic detection, just pass "
                                  "the correct nodata value."
                                  % (candidates[0], self.basename, corners_str))

                nodata = candidates[0]

            elif np.isnan(spread):
                # at least one of the possible values is np.nan
                nodata = np.nan

            else:
                # different possible nodata values have been found in the image corner
                nodata = 'ambiguous'

        self.nodata = nodata
        return nodata
761

762
    def set_gdalDataset_meta(self):
        """Retrieve GDAL metadata from file.

        Shape, data type, geotransform, projection, nodata value (unless a nodata value has been passed at
        instanciation), metadata and band names are read from self.filePath.

        This is only executed once to avoid overwriting of user defined attributes,
        that are defined after object instanciation.

        :raises Exception:  if GDAL is not able to open the file
        """
        if not self._gdalDataset_meta_already_set:
            assert self.filePath
            dataset = gdal.Open(self.filePath)

            if not dataset:
                raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

            # set private class variables (in order to avoid recursion error)
            dims = [dataset.RasterYSize, dataset.RasterXSize]
            if dataset.RasterCount > 1:
                # the band dimension is only added for multi-band images
                dims = dims + [dataset.RasterCount]
            self._shape = tuple(dims)

            # the data type is taken from the first band
            self._dtype = gdal_array.GDALTypeCodeToNumericTypeCode(dataset.GetRasterBand(1).DataType)
            self._geotransform = list(dataset.GetGeoTransform())

            # for some reason GDAL reads arbitrary geotransforms as (0, 1, 0, 0, 0, 1) instead of (0, 1, 0, 0, 0, -1)
            self._geotransform[5] = -abs(self._geotransform[5])  # => force ygsd to be negative

            # consequently use WKT1 strings here as GDAL always exports transformation results as WKT1
            wkt = dataset.GetProjection()
            if not isLocal(wkt):
                self._projection = CRS(wkt).to_wkt(version="WKT1_GDAL")
            else:
                self._projection = ''

            # only read the nodata value from file if the user did not provide one at instanciation
            nodata_provided = 'nodata' in self._initParams and self._initParams['nodata'] is not None
            if not nodata_provided:
                first_band = dataset.GetRasterBand(1)
                # FIXME this does not support different nodata values within the same file
                self.nodata = first_band.GetNoDataValue()

            # set metadata attribute
            if not self.is_inmem and self.filePath:
                self._metadata = GDAL_Metadata(filePath=self.filePath)
            else:
                # metadata cannot be read from disk -> set it to the default
                self._metadata = GDAL_Metadata(nbands=self.bands, nodata_allbands=self._nodata)

            # copy over the band names
            band_meta = self.metadata.band_meta
            if 'band_names' in band_meta and band_meta['band_names']:
                self.bandnames = band_meta['band_names']

            # release the reference to the GDAL dataset
            del dataset

        self._gdalDataset_meta_already_set = True

    def from_path(self, path, getitem_params=None):
        # type: (str, list) -> np.ndarray
        """Read a GDAL compatible raster image from disk, with respect to the given image position.

        NOTE: If the requested array position is already in cache, it is returned from there.
        NOTE: Only the start and stop attributes of slices are evaluated, i.e., a slice step is ignored.

        :param path:            <str> the file path of the image to read
        :param getitem_params:  <list> a list of slices in the form [row_slice, col_slice, band_slice]
        :return out_arr:        <np.ndarray> the output array
        :raises Exception:      if GDAL is not able to open the file or to read the requested subset
        :raises ValueError:     if the requested position is out of the image bounds
        """
        ds = gdal.Open(path)
        if not ds:
            raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

        R, C, B = ds.RasterYSize, ds.RasterXSize, ds.RasterCount
        del ds

        # convert getitem_params to subset area to be read #
        # (start/end indices of rows, columns, bands -> end indices are inclusive; bL = list of band indices)
        rS, rE, cS, cE, bS, bE, bL = [None] * 7

        # populate rS, rE, cS, cE, bS, bE, bL
        if getitem_params:
            # populate rS, rE, cS, cE
            if len(getitem_params) >= 2:
                givenR, givenC = getitem_params[:2]
                if isinstance(givenR, slice):
                    rS = givenR.start
                    rE = givenR.stop - 1 if givenR.stop is not None else None
                elif isinstance(givenR, (int, np.integer)):
                    rS = givenR
                    rE = givenR
                if isinstance(givenC, slice):
                    cS = givenC.start
                    cE = givenC.stop - 1 if givenC.stop is not None else None
                elif isinstance(givenC, (int, np.integer)):
                    cS = givenC
                    cE = givenC

            # populate bS, bE, bL
            if len(getitem_params) in [1, 3]:
                givenB = getitem_params[2] if len(getitem_params) == 3 else getitem_params[0]
                if isinstance(givenB, slice):
                    bS = givenB.start
                    bE = givenB.stop - 1 if givenB.stop is not None else None
                elif isinstance(givenB, (int, np.integer)):
                    bS = givenB
                    bE = givenB
                elif isinstance(givenB, (tuple, list)):
                    typesInGivenB = [type(i) for i in givenB]
                    assert len(list(set(typesInGivenB))) == 1, \
                        'Mixed data types within the list of bands are not supported.'
                    if isinstance(givenB[0], (int, np.integer)):
                        bL = list(givenB)
                    elif isinstance(givenB[0], str):
                        bL = [self.bandnames[i] for i in givenB]
                elif type(givenB) in [str]:
                    bL = [self.bandnames[givenB]]

        # set defaults for not given values
        rS = rS if rS is not None else 0
        rE = rE if rE is not None else R - 1
        cS = cS if cS is not None else 0
        cE = cE if cE is not None else C - 1
        bS = bS if bS is not None else 0
        bE = bE if bE is not None else B - 1

        # convert negative to positive ones
        rS = rS if rS >= 0 else self.rows + rS
        rE = rE if rE >= 0 else self.rows + rE
        cS = cS if cS >= 0 else self.columns + cS
        cE = cE if cE >= 0 else self.columns + cE
        bS = bS if bS >= 0 else self.bands + bS
        bE = bE if bE >= 0 else self.bands + bE

        # NOTE: The default band list must be derived AFTER bS/bE have been converted to positive indices.
        #       Otherwise, a band slice with a negative start (e.g., -2:) would result in a wrong list of bands.
        if bL:
            bL = [b if b >= 0 else (self.bands + b) for b in bL]
        else:
            bL = list(range(bS, bE + 1))

        # validate subset area bounds to be read
        def msg(v, idx, sz):
            # FIXME numpy raises that error ONLY for the 2nd axis
            return '%s is out of bounds for axis %s with size %s' % (v, idx, sz)

        for val, axIdx, axSize in zip([rS, rE, cS, cE, bS, bE], [0, 0, 1, 1, 2, 2], [R, R, C, C, B, B]):
            if not 0 <= val <= axSize - 1:
                raise ValueError(msg(val, axIdx, axSize))

        # explicitly given band lists are not covered by bS/bE -> validate each band index separately
        for bIdx in bL:
            if not 0 <= bIdx <= B - 1:
                raise ValueError(msg(bIdx, 2, B))

        # summarize requested array position in arr_pos (used as key for the array cache)
        arr_pos = dict(rS=rS, rE=rE, cS=cS, cE=cE, bS=bS, bE=bE, bL=bL)

        def _ensure_np_shape_consistency_3D_2D(arr):
            """Ensure numpy output shape consistency according to the given indexing parameters.

            This may require 3D to 2D conversion in case out_arr can be represented by a 2D array AND index has been
            provided as integer (avoids shapes like (1,2,2). It also may require 2D to 3D conversion in case only one
            band has been requested and the 3rd dimension has been provided as a slice.

            NOTE: -> numpy also returns a 2D array in that case
            NOTE: if array is indexed with a slice -> keep it a 3D array
            """
            # a single value -> return as float/int
            if arr.ndim == 2 and arr.size == 1:
                arr = arr[0, 0]

            # 2D -> 3D
            if arr.ndim == 2 and isinstance(getitem_params, (tuple, list)) and len(getitem_params) == 3 and \
                    isinstance(getitem_params[2], slice):
                arr = arr[:, :, np.newaxis]

            # 3D -> 2D
            # (nothing to do without indexing parameters, e.g., if the whole image is read without any index)
            if 1 in arr.shape and getitem_params and len(getitem_params) != 1:
                outshape = []
                for i, sh in enumerate(arr.shape):
                    # only drop those dimensions of size 1 that have been indexed with a number
                    if sh == 1 and i < len(getitem_params) and \
                            isinstance(getitem_params[i], (int, np.integer, float, np.floating)):
                        pass
                    else:
                        outshape.append(sh)

                arr = arr.reshape(*outshape)

            return arr

        # check if the requested array position is already in cache -> if yes, return it from there
        # NOTE(review): the cache holds the array AFTER the shape adaption below while arr_pos does not
        #               distinguish between integer and slice indexing -> confirm that this is intended
        if self._arr_cache is not None and self._arr_cache['pos'] == arr_pos:
            out_arr = self._arr_cache['arr_cached']
            out_arr = _ensure_np_shape_consistency_3D_2D(out_arr)

        else:
            # TODO insert a multiprocessing.Lock here in order to prevent IO bottlenecks?
            # read subset area from disk
            xsize, ysize = cE - cS + 1, rE - rS + 1

            if bL == list(range(0, B)):
                # all bands are requested -> read them at once
                tempArr = gdalnumeric.LoadFile(path, cS, rS, xsize, ysize)

                # the result must be checked BEFORE any further processing
                if tempArr is None:
                    raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

                # bands first -> bands last
                out_arr = np.swapaxes(np.swapaxes(tempArr, 0, 2), 0, 1) if B > 1 else tempArr

            else:
                ds = gdal.Open(path)
                if len(bL) == 1:
                    band = ds.GetRasterBand(bL[0] + 1)
                    out_arr = band.ReadAsArray(cS, rS, xsize, ysize)

                    if out_arr is None:
                        raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

                    del band

                else:
                    out_arr = None
                    for i, bIdx in enumerate(bL):
                        band = ds.GetRasterBand(bIdx + 1)
                        band_arr = band.ReadAsArray(cS, rS, xsize, ysize)

                        if band_arr is None:
                            raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

                        if out_arr is None:
                            # allocate the output with the data type of the file instead of the float64 default
                            out_arr = np.empty((ysize, xsize, len(bL)), dtype=band_arr.dtype)
                        elif band_arr.dtype != out_arr.dtype:
                            # bands with differing data types -> promote to a data type that can hold both
                            out_arr = out_arr.astype(np.result_type(out_arr.dtype, band_arr.dtype))

                        out_arr[:, :, i] = band_arr

                        del band

                del ds

            out_arr = _ensure_np_shape_consistency_3D_2D(out_arr)

            # only set self.arr if the whole cube has been read (in order to avoid sudden shape changes)
            if out_arr.shape == self.shape:
                self.arr = out_arr

            # write _arr_cache
            self._arr_cache = dict(pos=arr_pos, arr_cached=out_arr)

        return out_arr  # TODO implement check of returned datatype (e.g. NoDataMask should always return bool
        # TODO -> would be np.int8 if an int8 file is read from disk
972
973
974
975
976
977
978

    def save(self, out_path, fmt='ENVI', creationOptions=None):
        # type: (str, str, list) -> None
        """Write the raster data to disk.

        :param out_path:        <str> output path
        :param fmt:             <str> the output format / GDAL driver code to be used for output creation, e.g. 'ENVI'
Daniel Scheffler's avatar
Daniel Scheffler committed
979
980
                                Refer to https://gdal.org/drivers/raster/index.html to get a full list of supported
                                formats.
981
982
        :param creationOptions: <list> GDAL creation options,
                                e.g., ["QUALITY=80", "REVERSIBLE=YES", "WRITE_METADATA=YES"]
983
984
        """
        if not self.q:
985
986
            print('Writing GeoArray of size %s to %s.' % (self.shape, out_path))
        assert self.ndim in [2, 3], 'Only 2D- or 3D arrays are supported.'
987
988
989
990
991
992
993
994
995

        driver = gdal.GetDriverByName(fmt)
        if driver is None:
            raise Exception("'%s' is not a supported GDAL driver. Refer to www.gdal.org/formats_list.html for full "
                            "list of GDAL driver codes." % fmt)

        if not os.path.isdir(os.path.dirname(out_path)):
            os.makedirs(os.path.dirname(out_path))

996
997
        envi_metadict = self.metadata.to_ENVI_metadict()

998
999
1000
        #####################
        # write raster data #
        #####################
For faster browsing, not all history is shown. View entire blame