baseclasses.py 84.2 KB
Newer Older
1
2
# -*- coding: utf-8 -*-

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# geoarray, A fast Python interface for image geodata - either on disk or in memory.
#
# Copyright (C) 2019  Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de)
#
# This software was developed within the context of the GeoMultiSens project funded
# by the German Federal Ministry of Education and Research
# (project grant code: 01 IS 14 010 A-C).
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.


25
26
import os
import warnings
27
from pkgutil import find_loader
28
from collections import OrderedDict
29
from copy import copy, deepcopy
30
from typing import Union  # noqa F401
31
32

import numpy as np
33
from osgeo import gdal, gdal_array, gdalnumeric
34
35
36
37
from shapely.geometry import Polygon
from shapely.wkt import loads as shply_loads
# dill -> imported when dumping GeoArray

38
39
40
41
from py_tools_ds.convenience.object_oriented import alias_property
from py_tools_ds.geo.coord_calc import get_corner_coordinates
from py_tools_ds.geo.coord_grid import snap_bounds_to_pixGrid
from py_tools_ds.geo.coord_trafo import mapXY2imXY, imXY2mapXY, transform_any_prj, reproject_shapelyGeometry
42
from py_tools_ds.geo.projection import prj_equal, WKT2EPSG, EPSG2WKT, isLocal, CRS
43
44
from py_tools_ds.geo.raster.conversion import raster2polygon
from py_tools_ds.geo.vector.topology \
45
    import get_footprint_polygon, polyVertices_outside_poly, fill_holes_within_poly
46
47
from py_tools_ds.geo.vector.geometry import boxObj
from py_tools_ds.io.raster.gdal import get_GDAL_ds_inmem
48
from py_tools_ds.numeric.numbers import is_number
49
from py_tools_ds.numeric.array import get_array_tilebounds
50
51
52

#  internal imports
from .subsetting import get_array_at_mapPos
53
from .metadata import GDAL_Metadata
54
55

__author__ = 'Daniel Scheffler'
56
57
58


class GeoArray(object):
Daniel Scheffler's avatar
Daniel Scheffler committed
59
60
61
62
63
64
65
    """
    This class creates a fast Python interface for geodata - either on disk or in memory. It can be instanced
    with a file path or with a numpy array and the corresponding geoinformation. Instances can always be indexed
    like normal numpy arrays, no matter if GeoArray has been instanced from file or from an in-memory array.
    GeoArray provides a wide range of geo-related attributes belonging to the dataset as well as some functions for
    quickly visualizing the data as a map, a simple image or an interactive image.
    """
66
67
    def __init__(self, path_or_array, geotransform=None, projection=None, bandnames=None, nodata=None, progress=True,
                 q=False):
68
        # type: (Union[str, np.ndarray, GeoArray], tuple, str, list, float, bool, bool) -> None
Daniel Scheffler's avatar
Daniel Scheffler committed
69
        """Get an instance of GeoArray.
70
71
72
73
74
75
76
77
78
79
80
81
82

        :param path_or_array:   a numpy.ndarray or a valid file path
        :param geotransform:    GDAL geotransform of the given array or file on disk
        :param projection:      projection of the given array or file on disk as WKT string
                                (only needed if GeoArray is instanced with an array)
        :param bandnames:       names of the bands within the input array, e.g. ['mask_1bit', 'mask_clouds'],
                                (default: ['B1', 'B2', 'B3', ...])
        :param nodata:          nodata value
        :param progress:        show progress bars (default: True)
        :param q:               quiet mode (default: False)
        """
        # TODO implement compatibility to GDAL VRTs
        if not (isinstance(path_or_array, (str, np.ndarray, GeoArray)) or
83
           issubclass(getattr(path_or_array, '__class__'), GeoArray)):
84
            raise ValueError("%s parameter 'arg' takes only string, np.ndarray or GeoArray(and subclass) instances. "
85
                             "Got %s." % (self.__class__.__name__, type(path_or_array)))
86
87

        if path_or_array is None:
88
            raise ValueError("The %s parameter 'path_or_array' must not be None!" % self.__class__.__name__)
89
90
91
92
93

        if isinstance(path_or_array, str):
            assert ' ' not in path_or_array, "The given path contains whitespaces. This is not supported by GDAL."

            if not os.path.exists(path_or_array):
94
                raise FileNotFoundError(path_or_array)
95

96
97
        if isinstance(path_or_array, GeoArray) or issubclass(getattr(path_or_array, '__class__'), GeoArray):
            self.__dict__ = path_or_array.__dict__.copy()
98
            self._initParams = dict([x for x in locals().items() if x[0] != "self"])
99
100
            self.geotransform = geotransform or self.geotransform
            self.projection = projection or self.projection
101
            self.bandnames = bandnames or list(self.bandnames.keys())
102
103
104
            self._nodata = nodata if nodata is not None else self._nodata
            self.progress = False if progress is False else self.progress
            self.q = q if q is not None else self.q
105
106

        else:
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
            self._initParams = dict([x for x in locals().items() if x[0] != "self"])
            self.arg = path_or_array
            self._arr = path_or_array if isinstance(path_or_array, np.ndarray) else None
            self.filePath = path_or_array if isinstance(path_or_array, str) and path_or_array else None
            self.basename = os.path.splitext(os.path.basename(self.filePath))[0] if not self.is_inmem else 'IN_MEM'
            self.progress = progress
            self.q = q
            self._arr_cache = None  # dict containing key 'pos' and 'arr_cached'
            self._geotransform = None
            self._projection = None
            self._shape = None
            self._dtype = None
            self._nodata = nodata
            self._mask_nodata = None
            self._mask_baddata = None
122
123
            self._footprint_poly = None
            self._gdalDataset_meta_already_set = False
124
125
            self._metadata = None
            self._bandnames = None
126
127

            if bandnames:
128
                self.bandnames = bandnames  # use property in order to validate given value
129
            if geotransform:
130
                self.geotransform = geotransform  # use property in order to validate given value
131
            if projection:
132
                self.projection = projection  # use property in order to validate given value
133
134
135
136

            if self.filePath:
                self.set_gdalDataset_meta()

137
138
139
140
141
            if 'nodata' in self._initParams and self._initParams['nodata'] is not None:
                self._validate_nodataVal()

    def _validate_nodataVal(self):
        """Check if a given nodata value is within the valid value range of the data type."""
142
143
144
145
146
147
148
149
150
151
        _nodata = self._initParams['nodata']

        if np.issubdtype(self.dtype, np.integer):
            dt_min, dt_max = np.iinfo(self.dtype).min, np.iinfo(self.dtype).max
        elif np.issubdtype(self.dtype, np.floating):
            dt_min, dt_max = np.finfo(self.dtype).min, np.finfo(self.dtype).max
        else:
            return

        if not dt_min <= _nodata <= dt_max:
152
            raise ValueError("The given no-data value (%s) is out range for data type %s."
153
                             % (self._initParams['nodata'], str(np.dtype(self.dtype))))
154

155
156
157
158
159
160
    @property
    def arr(self):
        return self._arr

    @arr.setter
    def arr(self, ndarray):
161
162
        assert isinstance(ndarray, np.ndarray), "'arr' can only be set to a numpy array! Got %s." % type(ndarray)
        # assert ndarray.shape == self.shape, "'arr' can only be set to a numpy array with shape %s. Received %s. " \
163
164
165
166
167
        #                                    "If you need to change the dimensions, create a new instance of %s." \
        #                                    %(self.shape, ndarray.shape, self.__class__.__name__)
        #  THIS would avoid warping like this: geoArr.arr, geoArr.gt, geoArr.prj = warp(...)

        if ndarray.shape != self.shape:
168
            self.flush_cache()  # the cached array is not useful anymore
169
170
171
172
173

        self._arr = ndarray

    @property
    def bandnames(self):
174
        if self._bandnames and len(self._bandnames) == self.bands:
175
176
            return self._bandnames
        else:
177
            del self.bandnames  # runs deleter which sets it to default values
178
179
180
181
182
183
184
            return self._bandnames

    @bandnames.setter
    def bandnames(self, list_bandnames):
        # type: (list) -> None

        if list_bandnames:
185
186
187
188
189
190
191
192
            if not isinstance(list_bandnames, list):
                raise TypeError("A list must be given when setting the 'bandnames' attribute. "
                                "Received %s." % type(list_bandnames))
            if len(list_bandnames) != self.bands:
                raise ValueError('Number of given bandnames does not match number of bands in array.')
            if len(list(set([type(b) for b in list_bandnames]))) != 1 or not isinstance(list_bandnames[0], str):
                raise ValueError("'bandnames must be a set of strings. Got other datatypes in there.'")

193
            bN_dict = OrderedDict((band, i) for i, band in enumerate(list_bandnames))
194
195

            if len(bN_dict) != self.bands:
196
                raise ValueError('Bands must have unique names. Received band list: %s' % list_bandnames)
197
198
199

            self._bandnames = bN_dict

200
            try:
201
                self.metadata.band_meta['band_names'] = list_bandnames
202
203
204
            except AttributeError:
                # in case self._metadata is None
                pass
205
206
207
208
209
210
211
        else:
            del self.bandnames

    @bandnames.deleter
    def bandnames(self):
        self._bandnames = OrderedDict(('B%s' % band, i) for i, band in enumerate(range(1, self.bands + 1)))
        if self._metadata is not None:
212
            self.metadata.band_meta['band_names'] = list(self._bandnames.keys())
213

214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
    @property
    def is_inmem(self):
        """Check if associated image array is completely loaded into memory."""
        return isinstance(self.arr, np.ndarray)

    @property
    def shape(self):
        """Get the array shape of the associated image array."""
        if self.is_inmem:
            return self.arr.shape
        else:
            if self._shape:
                return self._shape
            else:
                self.set_gdalDataset_meta()
                return self._shape

    @property
    def ndim(self):
        """Get the number dimensions of the associated image array."""
        return len(self.shape)

    @property
    def rows(self):
        """Get the number of rows of the associated image array."""
        return self.shape[0]

    @property
    def columns(self):
        """Get the number of columns of the associated image array."""

        return self.shape[1]

    cols = alias_property('columns')

    @property
    def bands(self):
        """Get the number of bands of the associated image array."""
252
        return self.shape[2] if len(self.shape) > 2 else 1
253
254
255
256
257
258
259
260
261
262
263
264
265
266

    @property
    def dtype(self):
        """Get the numpy data type of the associated image array."""
        if self._dtype:
            return self._dtype
        elif self.is_inmem:
            return self.arr.dtype
        else:
            self.set_gdalDataset_meta()
            return self._dtype

    @property
    def geotransform(self):
267
        """Get the GDAL GeoTransform of the associated image, e.g., (283500.0, 5.0, 0.0, 4464500.0, 0.0, -5.0)"""
268
269
270
271
272
273
        if self._geotransform:
            return self._geotransform
        elif not self.is_inmem:
            self.set_gdalDataset_meta()
            return self._geotransform
        else:
274
            return [0, 1, 0, 0, 0, -1]
275
276
277

    @geotransform.setter
    def geotransform(self, gt):
        # type: (Union[list, tuple]) -> None
        """Set the GDAL GeoTransform.

        :param gt:  list or tuple of six numbers
        """
        has_valid_layout = isinstance(gt, (list, tuple)) and len(gt) == 6
        assert has_valid_layout, 'geotransform must be a list with 6 numbers. Got %s.' % str(gt)

        for value in gt:
            assert is_number(value), \
                "geotransform must contain only numbers. Got '%s' (type: %s)." % (value, type(value))

        self._geotransform = gt
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305

    gt = alias_property('geotransform')

    @property
    def xgsd(self):
        """Get the X resolution in units of the given or detected projection."""
        return self.geotransform[1]

    @property
    def ygsd(self):
        """Get the Y resolution in units of the given or detected projection."""
        return abs(self.geotransform[5])

    @property
    def xygrid_specs(self):
        """
        Get the specifications for the X/Y coordinate grid, e.g. [[15,30], [0,30]] for a coordinate with its origin
        at X/Y[15,0] and a GSD of X/Y[15,30].
        """

306
        def get_grid(gt, xgsd, ygsd): return [[gt[0], gt[0] + xgsd], [gt[3], gt[3] - ygsd]]
307
308
309
310
311
312
313
314
315
316
317
318
319
        return get_grid(self.geotransform, self.xgsd, self.ygsd)

    @property
    def projection(self):
        """
        Get the projection of the associated image. Setting the projection is only allowed if GeoArray has been
        instanced from memory or the associated file on disk has no projection.
        """

        if self._projection:
            return self._projection
        elif not self.is_inmem:
            self.set_gdalDataset_meta()
320
            return self._projection  # or "LOCAL_CS[\"MAP\"]"
321
        else:
322
            return ''  # '"LOCAL_CS[\"MAP\"]"
323
324
325

    @projection.setter
    def projection(self, prj):
326
        # type: (str) -> None
327
        self._projection = prj
328
329
330
331
332

    prj = alias_property('projection')

    @property
    def epsg(self):
        # type: () -> int
        """Get the EPSG code of the projection of the GeoArray (converted from its WKT string via WKT2EPSG)."""
        wkt = self.projection
        return WKT2EPSG(wkt)

    @epsg.setter
    def epsg(self, epsg_code):
        # type: (int) -> None
        """Set the projection of the GeoArray from an EPSG code (converted via EPSG2WKT).

        :param epsg_code:  the EPSG code of the projection to be set
        """
        wkt = EPSG2WKT(epsg_code)
        self.projection = wkt

    @property
    def box(self):
        """Get a boxObj built from the geotransform, the projection and the corner coordinates of the image."""
        corner_coords = get_corner_coordinates(gt=self.geotransform, cols=self.columns, rows=self.rows)
        mapPoly = get_footprint_polygon(corner_coords)

        return boxObj(gt=self.geotransform, prj=self.projection, mapPoly=mapPoly)

347
348
349
350
    @property
    def is_map_geo(self):
        # type: () -> bool
        """
351
        Return 'True' if the GeoArray instance has a valid geoinformation with map instead of image coordinates.
352
        """
353
        return all([self.gt, list(self.gt) != [0, 1, 0, 0, 0, -1], self.prj])
354

355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
    @property
    def nodata(self):
        """
        Get the nodata value of the GeoArray. If GeoArray has been instanced with a file path the file is checked
        for an existing nodata value. Otherwise (if no value is exlicitly given during object instanciation) the nodata
        value is tried to be automatically detected.
        """

        if self._nodata is not None:
            return self._nodata
        else:
            # try to get nodata value from file
            if not self.is_inmem:
                self.set_gdalDataset_meta()
            if self._nodata is None:
370
                self.find_noDataVal()
371
372
373
374
375
376
                if self._nodata == 'ambiguous':
                    warnings.warn('Nodata value could not be clearly identified. It has been set to None.')
                    self._nodata = None
                else:
                    if self._nodata is not None and not self.q:
                        print("Automatically detected nodata value for %s '%s': %s"
377
                              % (self.__class__.__name__, self.basename, self._nodata))
378
379
380
381
            return self._nodata

    @nodata.setter
    def nodata(self, value):
382
        # type: (Union[int, None]) -> None
383
384
        self._nodata = value

385
386
387
        if self._metadata and value is not None:
            self.metadata.global_meta.update({'data ignore value': str(value)})

388
389
390
391
392
393
394
395
396
    @property
    def mask_nodata(self):
        """
        Get the nodata mask of the associated image array. It is calculated using all image bands.
        """

        if self._mask_nodata is not None:
            return self._mask_nodata
        else:
397
            self.calc_mask_nodata()  # sets self._mask_nodata
398
399
400
401
402
403
404
405
406
407
408
            return self._mask_nodata

    @mask_nodata.setter
    def mask_nodata(self, mask):
        """Set the nodata mask.

        A mask without own geotransform / projection inherits them from the image. The mask must consist of a
        single band and must match the image in rows, columns, geotransform and projection.

        :param mask:    Can be a file path, a numpy array or an instance of GeoArray. None deletes the mask.
        """
        if mask is None:
            del self.mask_nodata
            return

        from .masks import NoDataMask

        geoArr_mask = NoDataMask(mask, progress=self.progress, q=self.q)

        # Geotransforms may be given as list or tuple. Normalize them to lists before comparing them, otherwise a
        # tuple would never be recognized as the default geotransform or as equal to a list with the same values.
        mask_gt = geoArr_mask.gt
        mask_has_no_gt = mask_gt is None or list(mask_gt) == [0, 1, 0, 0, 0, -1]
        geoArr_mask.gt = self.gt if mask_has_no_gt else mask_gt
        geoArr_mask.prj = geoArr_mask.prj if geoArr_mask.prj else self.prj
        imName = "the %s '%s'" % (self.__class__.__name__, self.basename)

        assert geoArr_mask.bands == 1, \
            'Expected one single band as nodata mask for %s. Got %s bands.' % (self.basename, geoArr_mask.bands)
        assert geoArr_mask.shape[:2] == self.shape[:2], 'The provided nodata mask must have the same number of ' \
                                                        'rows and columns as the %s itself.' % imName
        assert list(geoArr_mask.gt) == list(self.gt), \
            'The geotransform of the given nodata mask for %s must match the geotransform of the %s itself. ' \
            'Got %s.' % (imName, self.__class__.__name__, geoArr_mask.gt)
        # a mask without projection can only occur here if the image itself has no projection
        assert not geoArr_mask.prj or prj_equal(geoArr_mask.prj, self.prj), \
            'The projection of the given nodata mask for the %s must match the projection of the %s itself.' \
            % (imName, self.__class__.__name__)

        self._mask_nodata = geoArr_mask

    @mask_nodata.deleter
    def mask_nodata(self):
        self._mask_nodata = None
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450

    @property
    def mask_baddata(self):
        """
        Returns the bad data mask for the associated image array if it has been explicitly previously. It can be set
         by passing a file path, a numpy array or an instance of GeoArray to the setter of this property.
        """

        return self._mask_baddata

    @mask_baddata.setter
    def mask_baddata(self, mask):
        """Set the bad data mask.

        A mask without own geotransform / projection inherits them from the image. The mask must consist of a
        single band and must match the image in rows, columns, geotransform and projection.

        :param mask:    Can be a file path, a numpy array or an instance of GeoArray. None deletes the mask.
        """
        if mask is None:
            del self.mask_baddata
            return

        from .masks import BadDataMask

        geoArr_mask = BadDataMask(mask, progress=self.progress, q=self.q)

        # Geotransforms may be given as list or tuple. Normalize them to lists before comparing them, otherwise a
        # tuple would never be recognized as the default geotransform or as equal to a list with the same values.
        mask_gt = geoArr_mask.gt
        mask_has_no_gt = mask_gt is None or list(mask_gt) == [0, 1, 0, 0, 0, -1]
        geoArr_mask.gt = self.gt if mask_has_no_gt else mask_gt
        geoArr_mask.prj = geoArr_mask.prj if geoArr_mask.prj else self.prj
        imName = "the %s '%s'" % (self.__class__.__name__, self.basename)

        assert geoArr_mask.bands == 1, \
            'Expected one single band as bad data mask for %s. Got %s bands.' % (self.basename, geoArr_mask.bands)
        assert geoArr_mask.shape[:2] == self.shape[:2], 'The provided bad data mask must have the same number of ' \
                                                        'rows and columns as the %s itself.' % imName
        assert list(geoArr_mask.gt) == list(self.gt), \
            'The geotransform of the given bad data mask for %s must match the geotransform of the %s itself. ' \
            'Got %s.' % (imName, self.__class__.__name__, geoArr_mask.gt)
        # Same rule as for the nodata mask: a mask without projection can only occur here if the image itself
        # has no projection, so there is nothing to compare in that case.
        assert not geoArr_mask.prj or prj_equal(geoArr_mask.prj, self.prj), \
            'The projection of the given bad data mask for the %s must match the projection of the %s itself.' \
            % (imName, self.__class__.__name__)

        self._mask_baddata = geoArr_mask

    @mask_baddata.deleter
    def mask_baddata(self):
        self._mask_baddata = None
474
475
476
477
478
479
480
481
482
483

    @property
    def footprint_poly(self):
        # FIXME should return polygon in image coordinates if no projection is available
        """
        Get the footprint polygon of the associated image array (returns an instance of shapely.geometry.Polygon.
        """

        if self._footprint_poly is None:
            assert self.mask_nodata is not None, 'A nodata mask is needed for calculating the footprint polygon. '
484
            if False not in self.mask_nodata[:]:
485
486
487
488
                # do not run raster2polygon if whole image is filled with data
                self._footprint_poly = self.box.mapPoly
            else:
                try:
489
                    multipolygon = raster2polygon(self.mask_nodata.astype(np.uint8), self.gt, self.prj, exact=False,
490
                                                  progress=self.progress, q=self.q, maxfeatCount=10, timeout=5)
491
                    self._footprint_poly = fill_holes_within_poly(multipolygon)
492
                except (RuntimeError, TimeoutError):
493
494
495
496
                    if not self.q:
                        warnings.warn("\nCalculation of footprint polygon failed for %s '%s'. Using outer bounds. One "
                                      "reason could be that the nodata value appears within the actual image (not only "
                                      "as fill value). To avoid this use another nodata value. Current nodata value is "
497
                                      "%s." % (self.__class__.__name__, self.basename, self.nodata))
498
499
500
                    self._footprint_poly = self.box.mapPoly

            # validation
501
            assert not polyVertices_outside_poly(self._footprint_poly, self.box.mapPoly, tolerance=1e-5), \
502
503
504
                "Computing footprint polygon for %s '%s' failed. The resulting polygon is partly or completely " \
                "outside of the image bounds." % (self.__class__.__name__, self.basename)
            # assert self._footprint_poly
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
            # for XY in self.corner_coord:
            #    assert self.GeoArray.box.mapPoly.contains(Point(XY)) or self.GeoArray.box.mapPoly.touches(Point(XY)), \
            #        "The corner position '%s' is outside of the %s." % (XY, self.imName)

        return self._footprint_poly

    @footprint_poly.setter
    def footprint_poly(self, poly):
        """Set the footprint polygon.

        :param poly:        a shapely polygon or a WKT string (which is loaded into a shapely geometry)
        :raises ValueError: if any other type is given
        """
        if not isinstance(poly, (Polygon, str)):
            raise ValueError("'footprint_poly' can only be set from a shapely polygon or a WKT string.")

        # polygons are stored as they are, WKT strings have to be parsed first
        self._footprint_poly = poly if isinstance(poly, Polygon) else shply_loads(poly)

    @property
    def metadata(self):
        """
523
        Returns a DataFrame containing all available metadata (read from file if available).
524
525
526
527
        Use 'metadata[band_index].to_dict()' to get a metadata dictionary for a specific band.
        Use 'metadata.loc[row_name].to_dict()' to get all metadata values of the same key for all bands as dictionary.
        Use 'metadata.loc[row_name, band_index] = value' to set a new value.

528
        :return:  pandas.DataFrame
529
530
531
532
533
        """

        if self._metadata is not None:
            return self._metadata
        else:
534
            default = GDAL_Metadata(nbands=self.bands, nodata_allbands=self._nodata)
535

536
537
538
539
540
541
542
543
            self._metadata = default
            if not self.is_inmem:
                self.set_gdalDataset_meta()
                return self._metadata
            else:
                return self._metadata

    @metadata.setter
544
545
546
547
548
549
    def metadata(self, meta):
        """Set the metadata of the GeoArray.

        :param meta:        instance of GDAL_Metadata with the same number of bands as the GeoArray
        :raises ValueError: if another type is given or if the number of bands does not match
        """
        is_valid = isinstance(meta, GDAL_Metadata) and meta.bands == self.bands

        if not is_valid:
            cls_name = self.__class__.__name__
            raise ValueError("%s.metadata can only be set with an instance of geoarray.metadata.GDAL_Metadata of "
                             "which the band number corresponds to the band number of %s."
                             % (cls_name, cls_name))

        self._metadata = meta
550
551
552
553

    meta = alias_property('metadata')

    def __getitem__(self, given):
554
        if isinstance(given, (int, float, slice, np.integer, np.floating)) and self.ndim == 3:
555
556
557
558
559
560
561
562
563
564
565
566
567
            # handle 'given' as index for 3rd (bands) dimension
            if self.is_inmem:
                return self.arr[:, :, given]
            else:
                return self.from_path(self.arg, [given])

        elif isinstance(given, str):
            # behave like a dictionary and return the corresponding band
            if self.bandnames:
                if given not in self.bandnames:
                    raise ValueError("'%s' is not a known band. Known bands are: %s"
                                     % (given, ', '.join(list(self.bandnames.keys()))))
                if self.is_inmem:
568
                    return self.arr if self.ndim == 2 else self.arr[:, :, self.bandnames[given]]
569
570
571
572
                else:
                    return self.from_path(self.arg, [self.bandnames[given]])
            else:
                raise ValueError('String indices are only supported if %s has been instanced with bandnames given.'
573
                                 % self.__class__.__name__)
574
575
576
577

        elif isinstance(given, (tuple, list)):
            # handle requests like geoArr[[1,2],[3,4]  -> not implemented in from_path if array is not in mem
            types = [type(i) for i in given]
Daniel Scheffler's avatar
Daniel Scheffler committed
578

579
            if list in types or tuple in types:
580
581
582
583
584
585
586
587

                # avoid that the whole cube is read if only data from a single band is requested
                if not self.is_inmem \
                   and len(given) == 3 \
                   and isinstance(given[2], (int, float, np.integer, np.floating)):
                    band_subset = GeoArray(self.filePath)[:, :, given[2]]
                    return band_subset[given[:2]]

588
589
                self.to_mem()

590
            if len(given) == 3:
591
592

                # handle strings in the 3rd dim of 'given' -> convert them to a band index
593
                if isinstance(given[2], str):
594
595
596
597
598
599
600
601
                    if self.bandnames:
                        if given[2] not in self.bandnames:
                            raise ValueError("'%s' is not a known band. Known bands are: %s"
                                             % (given[2], ', '.join(list(self.bandnames.keys()))))

                        band_idx = self.bandnames[given[2]]
                        # NOTE: the string in the 3rd is ignored if ndim==2 and band_idx==0
                        if self.is_inmem:
602
                            return self.arr if (self.ndim == 2 and band_idx == 0) else self.arr[:, :, band_idx]
603
                        else:
604
605
                            getitem_params = \
                                given[:2] if (self.ndim == 2 and band_idx == 0) else given[:2] + (band_idx,)
606
607
608
609
610
611
612
                            return self.from_path(self.arg, getitem_params)
                    else:
                        raise ValueError(
                            'String indices are only supported if %s has been instanced with bandnames given.'
                            % self.__class__.__name__)

                # in case a third dim is requested from 2D-array -> ignore 3rd dim if 3rd dim is 0
613
                elif self.ndim == 2 and given[2] == 0:
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
                    if self.is_inmem:
                        return self.arr[given[:2]]
                    else:
                        return self.from_path(self.arg, given[:2])

        # if nothing has been returned until here -> behave like a numpy array
        if self.is_inmem:
            return self.arr[given]
        else:
            getitem_params = [given] if isinstance(given, slice) else given
            return self.from_path(self.arg, getitem_params)

    def __setitem__(self, idx, array2set):
        """Overwrites the pixel values of GeoArray.arr with the given array.

        :param idx:         <int, list, slice> the index position to overwrite
        :param array2set:   <np.ndarray> array to be set. Must be compatible to the given index position.
        :return:
        """

        if self.is_inmem:
            self.arr[idx] = array2set
        else:
            raise NotImplementedError('Item assignment for %s instances that are not in memory is not yet supported.'
638
                                      % self.__class__.__name__)
639
640
641

    def __getattr__(self, attr):
        # check if the requested attribute can not be present because GeoArray has been instanced with an array
642
643
        attrsNot2Link2np = ['__deepcopy__']   # attributes we don't want to inherit from numpy.ndarray

644
645
        if attr not in self.__dir__() and not self.is_inmem and attr in ['shape', 'dtype', 'geotransform',
                                                                         'projection']:
646
647
            self.set_gdalDataset_meta()

648
649
        if attr in self.__dir__():  # __dir__() includes also methods and properties
            return self.__getattribute__(attr)  # __getattribute__ avoids infinite loop
650
        elif attr not in attrsNot2Link2np and hasattr(np.array([]), attr):
651
652
            return self[:].__getattribute__(attr)
        else:
653
            raise AttributeError("%s object has no attribute '%s'." % (self.__class__.__name__, attr))
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670

    def __getstate__(self):
        """Define how the attributes of the instance are pickled.

        The array cache is flushed beforehand in order to avoid pickling the cache.
        """
        self.flush_cache()

        state = self.__dict__
        return state

    def __setstate__(self, state):
        """Restore the instance attributes from the given state when the instance is unpickled.

        NOTE: This method has been implemented because otherwise pickled and unpickled instances show recursion errors
        within __getattr__ when requesting any attribute.

        :param state:   the object that is assigned to the instance __dict__ (as returned by __getstate__)
        """

        self.__dict__ = state

671
672
673
674
675
676
677
678
679
    def calc_mask_nodata(self, fromBand=None, overwrite=False, flag='all'):
        # type: (int, bool, str) -> np.ndarray
        """Calculate a no data mask with values False (=nodata) and True (=data).

        The mask is only calculated if no nodata mask has been set so far or if 'overwrite' is True. The result is
        assigned to self.mask_nodata.

        NOTE: Bands that consist of nodata only are identified by an element-wise comparison with the nodata value.
              Comparing the band mean with the nodata value is not sufficient because data values may average to
              the nodata value (e.g., -1 and 1 in case of nodata=0).

        :param fromBand:   index of the band to be used (if None, all bands are used)
        :param overwrite:  whether to overwrite existing nodata mask that has already been calculated
        :param flag:       algorithm how to flag pixels (default: 'all')
                           'all': flag those pixels as nodata that contain the nodata value in ALL bands
                           'any': flag those pixels as nodata that contain the nodata value in ANY band
        :raises ValueError: if 'flag' is neither 'all' nor 'any'
        :return:           the calculated 2D boolean mask or None if an already existing mask has been kept
        """
        if self._mask_nodata is None or overwrite:
            if flag not in ['all', 'any']:
                raise ValueError(flag)

            assert self.ndim in [2, 3], "Only 2D or 3D arrays are supported. Got a %sD array." % self.ndim
            arr = self[:, :, fromBand] if self.ndim == 3 and fromBand is not None else self[:]

            if self.nodata is None:
                # without a nodata value, all pixels are considered as data
                mask = np.ones((self.rows, self.cols), bool)

            else:
                # True where a pixel value equals the nodata value (np.nan needs np.isnan because nan != nan)
                is_nodata = np.isnan(arr) if np.isnan(self.nodata) else arr == self.nodata

                # bands that contain nothing but nodata (a scalar in case of a 2D array)
                emptybands = np.all(is_nodata, axis=(0, 1))

                if np.all(emptybands):
                    mask = np.full(arr.shape[:2], False)
                elif arr.ndim == 2:
                    mask = ~is_nodata
                else:
                    # start with the first band containing data and refine the mask with the remaining data bands
                    is_nodata_remain = is_nodata[:, :, ~emptybands][:, :, 1:]

                    mask = ~is_nodata[:, :, np.argwhere(~emptybands)[0][0]]  # True where 1st data band has data

                    if flag == 'all':
                        # ALL bands need to contain nodata to flag the mask as such
                        # overwrite the mask at nodata positions (False) with True in case there is data in ANY band
                        mask[~mask] = np.any(~is_nodata_remain[~mask], axis=1)
                    else:
                        # ANY band needs to contain nodata to flag the mask as such
                        # overwrite the mask at data positions (True) with False in case there is nodata in ANY band
                        mask[mask] = ~np.any(is_nodata_remain[mask], axis=1)

            self.mask_nodata = mask

            return mask
741

742
743
744
745
746
747
748
749
    def find_noDataVal(self, bandIdx=0, sz=3):
        """Try to derive the no data value from homogeneous corner pixels within 3x3 windows (by default).

        A corner window is regarded as homogeneous if its standard deviation is 0 or NaN. The mean values of the
        homogeneous windows are the candidates for the nodata value. The result is also assigned to self.nodata.

        :param bandIdx: index of the band to be analysed
        :param sz:      window size in which corner pixels are analysed
        :return:        - None if no corner window is homogeneous
                        - np.nan if the standard deviation of the candidates is NaN
                        - 'ambiguous' if the candidates differ from each other
                        - otherwise the value of the first candidate
        """
        first, last = slice(0, sz), slice(-sz, None)
        corners = [(first, first), (first, last), (last, last), (last, first)]  # UL, UR, LR, LL

        # collect the mean of each homogeneous corner window
        candidates = []
        for rowsl, colsl in corners:
            window = self[rowsl, colsl, bandIdx]
            spread = np.std(window)
            if spread == 0 or np.isnan(spread):
                candidates.append(np.mean(window))

        if not candidates:
            # all corners are filled with data
            nodata = None
        else:
            cand_spread = np.std(candidates)

            if cand_spread == 0:
                # all candidates agree -> nodata value identified
                if len(candidates) <= 2:
                    # the value is only backed by one or two image corners
                    warnings.warn("\nAutomatic nodata value detection returned the value %s for GeoArray '%s' but this "
                                  "seems to be unreliable (occurs in only %s). To avoid automatic detection, just pass "
                                  "the correct nodata value."
                                  % (candidates[0], self.basename, ('2 image corners' if len(candidates) == 2 else
                                                                    '1 image corner')))
                nodata = candidates[0]
            elif np.isnan(cand_spread):
                # at least one of the possible values is np.nan
                nodata = np.nan
            else:
                # different possible nodata values have been found in the image corner
                nodata = 'ambiguous'

        self.nodata = nodata
        return nodata
776

777
778
779
780
781
782
783
784
785
786
    def set_gdalDataset_meta(self):
        """Read shape, data type, geotransform, projection, nodata value and metadata of the file via GDAL.

        The file is only read at the first call in order to avoid overwriting of user defined attributes
        that are defined after object instantiation. Afterwards, the instance is flagged as already set.

        :raises Exception:  if GDAL is not able to open the file
        :return:
        """
        if not self._gdalDataset_meta_already_set:
            assert self.filePath
            dataset = gdal.Open(self.filePath)
            if not dataset:
                raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

            # set private class variables (in order to avoid recursion error)
            dims = [dataset.RasterYSize, dataset.RasterXSize]
            if dataset.RasterCount > 1:
                # the band dimension is only added for multi-band files
                dims.append(dataset.RasterCount)
            self._shape = tuple(dims)
            self._dtype = gdal_array.GDALTypeCodeToNumericTypeCode(dataset.GetRasterBand(1).DataType)

            # for some reason GDAL reads arbitrary geotransforms as (0, 1, 0, 0, 0, 1) instead of (0, 1, 0, 0, 0, -1)
            # => force ygsd to be negative
            geotrans = list(dataset.GetGeoTransform())
            geotrans[5] = -abs(geotrans[5])
            self._geotransform = geotrans

            # consequently use WKT1 strings here as GDAL always exports transformation results as WKT1
            prj_wkt = dataset.GetProjection()
            if isLocal(prj_wkt):
                self._projection = ''
            else:
                self._projection = CRS(prj_wkt).to_wkt(version="WKT1_GDAL")

            # only use the nodata value of the file if no nodata value has been passed at instantiation
            if 'nodata' not in self._initParams or self._initParams['nodata'] is None:
                first_band = dataset.GetRasterBand(1)
                # FIXME this does not support different nodata values within the same file
                self.nodata = first_band.GetNoDataValue()

            # set metadata attribute
            if not self.is_inmem and self.filePath:
                self._metadata = GDAL_Metadata(filePath=self.filePath)
            else:
                # metadata cannot be read from disk -> set it to the default
                self._metadata = GDAL_Metadata(nbands=self.bands, nodata_allbands=self._nodata)

            # copy over the band names
            band_meta = self.metadata.band_meta
            if 'band_names' in band_meta and band_meta['band_names']:
                self.bandnames = band_meta['band_names']

            # drop the reference to the GDAL dataset
            # noinspection PyUnusedLocal
            dataset = None

        self._gdalDataset_meta_already_set = True

    def from_path(self, path, getitem_params=None):
        # type: (str, list) -> np.ndarray
        """Read a GDAL compatible raster image from disk, with respect to the given image position.

        NOTE: If the requested array position is already in cache, it is returned from there.
        NOTE: Only start and stop of the given slices are evaluated (the step is ignored).

        :param path:            <str> the file path of the image to read
        :param getitem_params:  <list> a list of slices in the form [row_slice, col_slice, band_slice]
                                (a single entry is interpreted as band index, None reads the whole image)
        :raises ValueError:     if the requested rows, columns or bands are out of bounds
        :raises Exception:      if GDAL is not able to open or to read the file
        :return out_arr:        <np.ndarray> the output array
        """
        ds = gdal.Open(path)
        if not ds:
            raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

        R, C, B = ds.RasterYSize, ds.RasterXSize, ds.RasterCount
        del ds

        # convert getitem_params to subset area to be read #
        rS, rE, cS, cE, bS, bE, bL = [None] * 7

        # populate rS, rE, cS, cE, bS, bE, bL
        if getitem_params:
            # populate rS, rE, cS, cE
            if len(getitem_params) >= 2:
                givenR, givenC = getitem_params[:2]
                if isinstance(givenR, slice):
                    rS = givenR.start
                    rE = givenR.stop - 1 if givenR.stop is not None else None  # end positions are inclusive
                elif isinstance(givenR, (int, np.integer)):
                    rS = givenR
                    rE = givenR
                if isinstance(givenC, slice):
                    cS = givenC.start
                    cE = givenC.stop - 1 if givenC.stop is not None else None
                elif isinstance(givenC, (int, np.integer)):
                    cS = givenC
                    cE = givenC

            # populate bS, bE, bL
            if len(getitem_params) in [1, 3]:
                givenB = getitem_params[2] if len(getitem_params) == 3 else getitem_params[0]
                if isinstance(givenB, slice):
                    bS = givenB.start
                    bE = givenB.stop - 1 if givenB.stop is not None else None
                elif isinstance(givenB, (int, np.integer)):
                    bS = givenB
                    bE = givenB
                elif isinstance(givenB, (tuple, list)):
                    typesInGivenB = [type(i) for i in givenB]
                    assert len(list(set(typesInGivenB))) == 1, \
                        'Mixed data types within the list of bands are not supported.'
                    if isinstance(givenB[0], (int, np.integer)):
                        bL = list(givenB)
                    elif isinstance(givenB[0], str):
                        # band names are translated via self.bandnames
                        bL = [self.bandnames[i] for i in givenB]
                elif isinstance(givenB, str):
                    bL = [self.bandnames[givenB]]

        # set defaults for not given values
        rS = rS if rS is not None else 0
        rE = rE if rE is not None else R - 1
        cS = cS if cS is not None else 0
        cE = cE if cE is not None else C - 1
        bS = bS if bS is not None else 0
        bE = bE if bE is not None else B - 1
        bL = list(range(bS, bE + 1)) if not bL else bL

        # convert negative to positive ones
        rS = rS if rS >= 0 else self.rows + rS
        rE = rE if rE >= 0 else self.rows + rE
        cS = cS if cS >= 0 else self.columns + cS
        cE = cE if cE >= 0 else self.columns + cE
        bS = bS if bS >= 0 else self.bands + bS
        bE = bE if bE >= 0 else self.bands + bE
        bL = [b if b >= 0 else (self.bands + b) for b in bL]

        # validate subset area bounds to be read
        def msg(v, idx, sz):
            # FIXME numpy raises that error ONLY for the 2nd axis
            return '%s is out of bounds for axis %s with size %s' % (v, idx, sz)

        for val, axIdx, axSize in zip([rS, rE, cS, cE, bS, bE], [0, 0, 1, 1, 2, 2], [R, R, C, C, B, B]):
            if not 0 <= val <= axSize - 1:
                raise ValueError(msg(val, axIdx, axSize))

        # summarize requested array position in arr_pos
        # NOTE: # bandlist must be string because truth value of an array with more than one element is ambiguous
        arr_pos = dict(rS=rS, rE=rE, cS=cS, cE=cE, bS=bS, bE=bE, bL=bL)

        def _ensure_np_shape_consistency_3D_2D(arr):
            """Ensure numpy output shape consistency according to the given indexing parameters.

            This may require 3D to 2D conversion in case out_arr can be represented by a 2D array AND index has been
            provided as integer (avoids shapes like (1,2,2). It also may require 2D to 3D conversion in case only one
            band has been requested and the 3rd dimension has been provided as a slice.

            NOTE: -> numpy also returns a 2D array in that case
            NOTE: if array is indexed with a slice -> keep it a 3D array
            """
            # a single value -> return as float/int
            if arr.ndim == 2 and arr.size == 1:
                arr = arr[0, 0]

            # 2D -> 3D
            if arr.ndim == 2 and isinstance(getitem_params, (tuple, list)) and len(getitem_params) == 3 and \
                    isinstance(getitem_params[2], slice):
                arr = arr[:, :, np.newaxis]

            # 3D -> 2D
            # (skipped if no indexing parameters are given: there is no integer index that could drop an axis)
            if 1 in arr.shape and getitem_params and len(getitem_params) != 1:
                outshape = []
                for i, sh in enumerate(arr.shape):
                    # only drop axes of length 1 that have been indexed with a number (not with a slice)
                    if sh == 1 and i < len(getitem_params) and \
                            isinstance(getitem_params[i], (int, np.integer, float, np.floating)):
                        pass
                    else:
                        outshape.append(sh)

                arr = arr.reshape(*outshape)

            return arr

        # check if the requested array position is already in cache -> if yes, return it from there
        if self._arr_cache is not None and self._arr_cache['pos'] == arr_pos:
            out_arr = self._arr_cache['arr_cached']
            out_arr = _ensure_np_shape_consistency_3D_2D(out_arr)

        else:
            # TODO insert a multiprocessing.Lock here in order to prevent IO bottlenecks?
            # read subset area from disk
            if bL == list(range(0, B)):
                # all bands are requested -> read them at once
                tempArr = gdalnumeric.LoadFile(path, cS, rS, cE - cS + 1, rE - rS + 1)

                # validate the read result BEFORE using it
                if tempArr is None:
                    raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

                # multi-band results are re-ordered from (bands, rows, columns) to (rows, columns, bands)
                out_arr = np.swapaxes(np.swapaxes(tempArr, 0, 2), 0, 1) if B > 1 else tempArr

            else:
                ds = gdal.Open(path)
                if len(bL) == 1:
                    band = ds.GetRasterBand(bL[0] + 1)  # GDAL band numbers start at 1
                    out_arr = band.ReadAsArray(cS, rS, cE - cS + 1, rE - rS + 1)
                    if out_arr is None:
                        raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())
                    del band

                else:
                    # NOTE(review): np.empty() defaults to float64, i.e., the data type of the file is not
                    #               preserved if a subset of multiple bands is read - confirm whether intended
                    out_arr = np.empty((rE - rS + 1, cE - cS + 1, len(bL)))
                    for i, bIdx in enumerate(bL):
                        band = ds.GetRasterBand(bIdx + 1)
                        band_arr = band.ReadAsArray(cS, rS, cE - cS + 1, rE - rS + 1)

                        # check the result of the band read (out_arr itself can never be None here)
                        if band_arr is None:
                            raise Exception('Error reading file:  ' + gdal.GetLastErrorMsg())

                        out_arr[:, :, i] = band_arr
                        del band

                del ds

            out_arr = _ensure_np_shape_consistency_3D_2D(out_arr)

            # only set self.arr if the whole cube has been read (in order to avoid sudden shape changes)
            if out_arr.shape == self.shape:
                self.arr = out_arr

            # write _arr_cache
            self._arr_cache = dict(pos=arr_pos, arr_cached=out_arr)

        return out_arr  # TODO implement check of returned datatype (e.g. NoDataMask should always return bool
        # TODO -> would be np.int8 if an int8 file is read from disk
987
988
989
990
991
992
993
994

    def save(self, out_path, fmt='ENVI', creationOptions=None):
        # type: (str, str, list) -> None
        """Write the raster data to disk.

        :param out_path:        <str> output path
        :param fmt:             <str> the output format / GDAL driver code to be used for output creation, e.g. 'ENVI'
                                Refer to http://www.gdal.org/formats_list.html to get a full list of supported formats.
995
996
        :param creationOptions: <list> GDAL creation options,
                                e.g., ["QUALITY=80", "REVERSIBLE=YES", "WRITE_METADATA=YES"]
997
998
999
        """

        if not self.q:
1000
            print('Writing GeoArray of size %s to %s.' % (self.shape, out_path))
For faster browsing, not all history is shown. View entire blame