CoReg_local.py 36.8 KB
Newer Older
1
2
# -*- coding: utf-8 -*-

3
4
# AROSICS - Automated and Robust Open-Source Image Co-Registration Software
#
# Copyright (C) 2017-2020  Daniel Scheffler (GFZ Potsdam, daniel.scheffler@gfz-potsdam.de)
#
# This software was developed within the context of the GeoMultiSens project funded
# by the German Federal Ministry of Education and Research
# (project grant code: 01 IS 14 010 A-C).
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.

24
25
import warnings
import os
26
from copy import copy
27
from six import PY2
28
from typing import TYPE_CHECKING
29
30
31
32
33
34

# custom
try:
    import gdal
except ImportError:
    from osgeo import gdal
35
36
37
38
try:
    import pyfftw
except ImportError:
    pyfftw = None
39
import numpy as np
40
41
42

if TYPE_CHECKING:
    from matplotlib import pyplot as plt  # noqa F401
43

44
from .Tie_Point_Grid import Tie_Point_Grid
45
46
from .CoReg import COREG
from .DeShifter import DESHIFTER
47
from py_tools_ds.geo.coord_trafo import transform_any_prj, reproject_shapelyGeometry
48
from py_tools_ds.geo.map_info import geotransform2mapinfo
49
from geoarray import GeoArray
50

51
# module author metadata
__author__ = 'Daniel Scheffler'
52
53
54


class COREG_LOCAL(object):
Daniel Scheffler's avatar
Daniel Scheffler committed
55
56
57
58
59
60
61
62
63
    """
    COREG_LOCAL applies the algorithm to detect spatial shifts to the whole overlap area of the input images.

    Spatial shifts are calculated for each point in grid of which the parameters can be adjusted using keyword
    arguments. Shift correction performs a polynomial transformation using the calculated shifts of each point in the
    grid as GCPs. Thus this class can be used to correct for locally varying geometric distortions of the target image.

    See help(COREG_LOCAL) for documentation.
    """
64

65
    def __init__(self, im_ref, im_tgt, grid_res, max_points=None, window_size=(256, 256), path_out=None, fmt_out='ENVI',
                 out_crea_options=None, projectDir=None, r_b4match=1, s_b4match=1, max_iter=5, max_shift=5,
                 tieP_filter_level=3, min_reliability=60, rs_max_outlier=10, rs_tolerance=2.5, align_grids=True,
                 match_gsd=False, out_gsd=None, target_xyGrid=None, resamp_alg_deshift='cubic', resamp_alg_calc='cubic',
                 footprint_poly_ref=None, footprint_poly_tgt=None, data_corners_ref=None, data_corners_tgt=None,
                 outFillVal=-9999, nodata=(None, None), calc_corners=True, binary_ws=True, force_quadratic_win=True,
                 mask_baddata_ref=None, mask_baddata_tgt=None, CPUs=None, progress=True, v=False, q=False,
                 ignore_errors=True):
        """Get an instance of COREG_LOCAL.

        :param im_ref(str, GeoArray):   source path of reference image (any GDAL compatible image format is supported)
        :param im_tgt(str, GeoArray):   source path of image to be shifted (any GDAL compatible image format is
                                        supported)
        :param grid_res:                tie point grid resolution in pixels of the target image (x-direction)
        :param max_points(int):         maximum number of points used to find coregistration tie points
                                        NOTE: Points are selected randomly from the given point grid (specified by
                                        'grid_res'). If the grid does not provide enough points, all available points
                                        are chosen.
        :param window_size(tuple):      custom matching window size [pixels] (default: (256,256))
        :param path_out(str):           target path of the coregistered image
                                            - if None (default), no output is written to disk
                                            - if 'auto': /dir/of/im1/<im1>__shifted_to__<im0>.bsq
                                        If 'projectDir' is also given, the output file is placed in the project
                                        directory.
        :param fmt_out(str):            raster file format for output file. ignored if path_out is None. Can be any GDAL
                                        compatible raster file format (e.g. 'ENVI', 'GTIFF'; default: ENVI). Refer to
                                        http://www.gdal.org/formats_list.html to get a full list of supported formats.
        :param out_crea_options(list):  GDAL creation options for the output image,
                                        e.g. ["QUALITY=80", "REVERSIBLE=YES", "WRITE_METADATA=YES"]
        :param projectDir(str):         name of a project directory where to store all the output results. If given,
                                        name is inserted into all automatically generated output paths.
        :param r_b4match(int):          band of reference image to be used for matching (starts with 1; default: 1)
        :param s_b4match(int):          band of shift image to be used for matching (starts with 1; default: 1)
        :param max_iter(int):           maximum number of iterations for matching (default: 5)
        :param max_shift(int):          maximum shift distance in reference image pixel units (default: 5 px)
        :param tieP_filter_level(int):  filter tie points used for shift correction in different levels (default: 3).
                                        NOTE: lower levels are also included if a higher level is chosen
                                            - Level 0: no tie point filtering
                                            - Level 1: Reliability filtering - filter all tie points out that have a
                                                low reliability according to internal tests
                                            - Level 2: SSIM filtering - filters all tie points out where shift
                                                correction does not increase image similarity within matching window
                                                (measured by mean structural similarity index)
                                            - Level 3: RANSAC outlier detection
        :param min_reliability(float):  Tie point filtering: minimum reliability threshold, below which tie points are
                                        marked as false-positives (default: 60%)
                                        - accepts values between 0% (no reliability) and 100 % (perfect reliability)
                                        HINT: decrease this value in case of poor signal-to-noise ratio of your input
                                              data
        :param rs_max_outlier(float):   RANSAC tie point filtering: proportion of expected outliers (default: 10%)
        :param rs_tolerance(float):     RANSAC tie point filtering: percentage tolerance for max_outlier_percentage
                                                (default: 2.5%)
        :param out_gsd (list):          output pixel size in units of the reference coordinate system, given as a list
                                        of two values (default = pixel size of the input array); given values are
                                        overridden by match_gsd=True
        :param align_grids (bool):      True: align the input coordinate grid to the reference (does not affect the
                                        output pixel size as long as input and output pixel sizes are compatible
                                        (5:30 or 10:30 but not 4:30), default = True
        :param match_gsd (bool):        True: match the input pixel size to the reference pixel size,
                                        default = False
        :param target_xyGrid(list):     a list with a target x-grid and a target y-grid like [[15,45], [15,45]]
                                        This overrides 'out_gsd', 'align_grids' and 'match_gsd'.
        :param resamp_alg_deshift(str): the resampling algorithm to be used for shift correction (if necessary)
                                        valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average,
                                                          mode, max, min, med, q1, q3
                                        default: cubic
        :param resamp_alg_calc(str):    the resampling algorithm to be used for all warping processes during calculation
                                        of spatial shifts
                                        (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average,
                                                           mode, max, min, med, q1, q3)
                                        default: cubic (highly recommended)
        :param footprint_poly_ref(str): footprint polygon of the reference image (WKT string or
                                        shapely.geometry.Polygon),
                                        e.g. 'POLYGON ((299999 6000000, 299999 5890200, 409799 5890200, 409799 6000000,
                                                        299999 6000000))'
        :param footprint_poly_tgt(str): footprint polygon of the image to be shifted (WKT string or
                                        shapely.geometry.Polygon)
                                        e.g. 'POLYGON ((299999 6000000, 299999 5890200, 409799 5890200, 409799 6000000,
                                                        299999 6000000))'
        :param data_corners_ref(list):  map coordinates of data corners within reference image.
                                        ignored if footprint_poly_ref is given.
        :param data_corners_tgt(list):  map coordinates of data corners within image to be shifted.
                                        ignored if footprint_poly_tgt is given.
        :param outFillVal(int):         if given the generated tie point grid is filled with this value in case
                                        no match could be found during co-registration (default: -9999)
        :param nodata(tuple):           no data values for reference image and image to be shifted, given as
                                        (nodata_ref, nodata_tgt)
        :param calc_corners(bool):      calculate true positions of the dataset corners in order to get a useful
                                        matching window position within the actual image overlap
                                        (default: True; deactivated if 'data_corners_ref' and 'data_corners_tgt' are
                                        given)
        :param binary_ws(bool):         use binary X/Y dimensions for the matching window (default: True)
        :param force_quadratic_win(bool):   force a quadratic matching window (default: True)
        :param mask_baddata_ref(str, BadDataMask):
                                        path to a 2D boolean mask file (or an instance of BadDataMask) for the
                                        reference image where all bad data pixels (e.g. clouds) are marked with
                                        True and the remaining pixels with False. Must have the same geographic
                                        extent and projection like 'im_ref'. The mask is used to check if the
                                        chosen matching window position is valid in the sense of useful data.
                                        Otherwise this window position is rejected.
        :param mask_baddata_tgt(str, BadDataMask):
                                        path to a 2D boolean mask file (or an instance of BadDataMask) for the
                                        image to be shifted where all bad data pixels (e.g. clouds) are marked
                                        with True and the remaining pixels with False. Must have the same
                                        geographic extent and projection like 'im_tgt'. The mask is used to
                                        check if the chosen matching window position is valid in the sense of
                                        useful data. Otherwise this window position is rejected.
        :param CPUs(int):               number of CPUs to use during calculation of tie point grid
                                        (default: None, which means 'all CPUs available'; forced to 1 under Python 2)
        :param progress(bool):          show progress bars (default: True; disabled if q=True)
        :param v(bool):                 verbose mode (default: False)
        :param q(bool):                 quiet mode (default: False; overridden by v=True)
        :param ignore_errors(bool):     Useful for batch processing. (default: True)
                                        NOTE: not yet implemented for COREG_LOCAL.
        :raises AssertionError:         on invalid 'fmt_out', 'out_gsd' or 'tieP_filter_level'
        """
        # assertions / input validation
        assert gdal.GetDriverByName(fmt_out), "'%s' is not a supported GDAL driver." % fmt_out
        if match_gsd and out_gsd:
            warnings.warn("'-out_gsd' is ignored because '-match_gsd' is set.\n")
        if out_gsd:
            assert isinstance(out_gsd, list) and len(out_gsd) == 2, 'out_gsd must be a list with two values.'
        if PY2 and (CPUs is None or (isinstance(CPUs, int) and CPUs > 1)):
            CPUs = 1
            warnings.warn('Multiprocessing is currently not supported for Python 2. Using singleprocessing.')

        # Snapshot of all call arguments (including the possibly adjusted CPUs value). locals() only contains the
        # arguments and 'self' at this point - do not introduce other local variables above this line.
        self.params = dict([x for x in locals().items() if x[0] != "self" and not x[0].startswith('__')])

        self.imref = GeoArray(im_ref, nodata=nodata[0], progress=progress, q=q)
        self.im2shift = GeoArray(im_tgt, nodata=nodata[1], progress=progress, q=q)
        self.path_out = path_out  # updated by self.set_outpathes
        self.fmt_out = fmt_out
        self.out_creaOpt = out_crea_options
        self._projectDir = projectDir
        self.grid_res = grid_res
        self.max_points = max_points
        self.window_size = window_size
        self.max_shift = max_shift
        self.max_iter = max_iter
        self.tieP_filter_level = tieP_filter_level
        self.min_reliability = min_reliability
        self.rs_max_outlier = rs_max_outlier
        self.rs_tolerance = rs_tolerance
        self.align_grids = align_grids
        self.match_gsd = match_gsd
        self.out_gsd = out_gsd
        self.target_xyGrid = target_xyGrid
        self.rspAlg_DS = resamp_alg_deshift  # TODO convert integers to strings
        self.rspAlg_calc = resamp_alg_calc
        self.calc_corners = calc_corners
        self.nodata = nodata
        self.outFillVal = outFillVal
        self.bin_ws = binary_ws
        self.force_quadratic_win = force_quadratic_win
        self.CPUs = CPUs
        self.path_verbose_out = ''  # TODO
        self.v = v
        self.q = q if not v else False  # overridden by v
        self.progress = progress if not q else False  # overridden by q
        self.ignErr = ignore_errors  # FIXME this is not yet implemented for COREG_LOCAL

        assert self.tieP_filter_level in range(4), 'Invalid tie point filter level.'
        assert isinstance(self.imref, GeoArray) and isinstance(self.im2shift, GeoArray), \
            'Something went wrong with the creation of GeoArray instances for reference or target image. The created ' \
            'instances do not seem to belong to the GeoArray class. If you are working in Jupyter Notebook, reset ' \
            'the kernel and try again.'

        # reuse the output path logic of COREG on this instance
        COREG.__dict__['_set_outpathes'](self, self.imref, self.im2shift)

        # make sure that the output directory of coregistered image is the project directory if a project directory is
        # given
        if path_out and projectDir and os.path.basename(self.path_out):
            self.path_out = os.path.join(self.projectDir, os.path.basename(self.path_out))

        gdal.AllRegister()

        try:
            # ignore_errors must be False because in case COREG init fails, coregistration for the whole scene fails
            self.COREG_obj = COREG(self.imref, self.im2shift,
                                   ws=window_size,
                                   footprint_poly_ref=footprint_poly_ref,
                                   footprint_poly_tgt=footprint_poly_tgt,
                                   data_corners_ref=data_corners_ref,
                                   data_corners_tgt=data_corners_tgt,
                                   resamp_alg_calc=self.rspAlg_calc,
                                   calc_corners=calc_corners,
                                   r_b4match=r_b4match,
                                   s_b4match=s_b4match,
                                   max_iter=max_iter,
                                   max_shift=max_shift,
                                   nodata=nodata,
                                   mask_baddata_ref=None,  # see below
                                   mask_baddata_tgt=None,
                                   CPUs=self.CPUs,
                                   force_quadratic_win=self.force_quadratic_win,
                                   binary_ws=self.bin_ws,
                                   progress=self.progress,
                                   v=v,
                                   q=q,
                                   ignore_errors=False)
        except Exception:
            warnings.warn('\nFirst attempt to check if functionality of co-registration failed. Check your '
                          'input data and parameters. The following error occurred:', stacklevel=3)
            raise

        if pyfftw:
            self.check_if_fftw_works()

        # add bad data mask
        # (mask is not added during initialization of COREG object in order to avoid bad data area errors there)
        if mask_baddata_ref is not None:
            self.COREG_obj.ref.mask_baddata = mask_baddata_ref
        if mask_baddata_tgt is not None:
            self.COREG_obj.shift.mask_baddata = mask_baddata_tgt

        self._tiepoint_grid = None  # set by self.tiepoint_grid
        self._CoRegPoints_table = None  # set by self.CoRegPoints_table
        self._coreg_info = None  # set by self.coreg_info
        self.deshift_results = None  # set by self.correct_shifts()
        self._success = None  # set by self.success property
278

279
    def check_if_fftw_works(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
280
        """Assign the attribute 'fftw_works' to self.COREG_obj by executing shift calculation once with muted output."""
281
        # calculate global shift once in order to check is fftw works
282
283
284
285
        try:
            self.COREG_obj.q = True
            self.COREG_obj.v = False
            self.COREG_obj.calculate_spatial_shifts()
286
        except RuntimeError:
287
288
289
            if self.COREG_obj.fftw_works is not None:
                pass
            else:
290
291
292
                warnings.warn('\nFirst attempt to check if functionality of co-registration failed. Check your '
                              'input data and parameters. The following error occurred:', stacklevel=3)
                raise
293

Daniel Scheffler's avatar
Daniel Scheffler committed
294
295
296
        self.COREG_obj.q = self.q
        self.COREG_obj.v = self.v

297
298
299
    @property
    def projectDir(self):
        if self._projectDir:
300
            if len(os.path.split(self._projectDir)) == 1:
301
302
303
304
305
                return os.path.abspath(os.path.join(os.path.curdir, self._projectDir))
            else:
                return os.path.abspath(self._projectDir)
        else:
            # return a project name that not already has a corresponding folder on disk
306
            root_dir = os.path.dirname(self.im2shift.filePath) if self.im2shift.filePath else os.path.curdir
307
308
309
310
311
312
            fold_name = 'UntitledProject_1'

            while os.path.isdir(os.path.join(root_dir, fold_name)):
                fold_name = '%s_%s' % (fold_name.split('_')[0], int(fold_name.split('_')[-1]) + 1)

            self._projectDir = os.path.join(root_dir, fold_name)
313
314
315
            return self._projectDir

    @property
Daniel Scheffler's avatar
Daniel Scheffler committed
316
317
318
    def tiepoint_grid(self):
        if self._tiepoint_grid:
            return self._tiepoint_grid
319
        else:
Daniel Scheffler's avatar
Daniel Scheffler committed
320
            self._tiepoint_grid = Tie_Point_Grid(self.COREG_obj, self.grid_res,
321
322
323
324
325
326
327
328
329
330
331
332
333
                                                 max_points=self.max_points,
                                                 outFillVal=self.outFillVal,
                                                 resamp_alg_calc=self.rspAlg_calc,
                                                 tieP_filter_level=self.tieP_filter_level,
                                                 outlDetect_settings=dict(
                                                     min_reliability=self.min_reliability,
                                                     rs_max_outlier=self.rs_max_outlier,
                                                     rs_tolerance=self.rs_tolerance),
                                                 dir_out=self.projectDir,
                                                 CPUs=self.CPUs,
                                                 progress=self.progress,
                                                 v=self.v,
                                                 q=self.q)
334
335
            self._tiepoint_grid.get_CoRegPoints_table()

336
            if self.v:
Daniel Scheffler's avatar
Daniel Scheffler committed
337
                print('Visualizing CoReg points grid...')
338
                self.view_CoRegPoints(figsize=(10, 10))
Daniel Scheffler's avatar
Daniel Scheffler committed
339
            return self._tiepoint_grid
340
341
342

    @property
    def CoRegPoints_table(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
343
        """Return a GeoDataFrame containing all the results from coregistration for all points in the tie point grid.
Daniel Scheffler's avatar
Daniel Scheffler committed
344

Daniel Scheffler's avatar
Daniel Scheffler committed
345
346
347
        Columns of the GeoDataFrame: 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM','X_WIN_SIZE', 'Y_WIN_SIZE',
                                     'X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE'
        """
Daniel Scheffler's avatar
Daniel Scheffler committed
348
        return self.tiepoint_grid.CoRegPoints_table
349

350
351
    @property
    def success(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
352
        self._success = self.tiepoint_grid.GCPList != []
353
354
355
356
        if not self._success and not self.q:
            warnings.warn('No valid GCPs could by identified.')
        return self._success

357
    def show_image_footprints(self):
Daniel Scheffler's avatar
Daniel Scheffler committed
358
359
360
361
        """Show a web map containing the calculated footprints and overlap area of the input images.

        NOTE: This method is intended to be called from Jupyter Notebook.
        """
362
363
        return self.COREG_obj.show_image_footprints()

364
365
    def view_CoRegPoints(self, shapes2plot='points', attribute2plot='ABS_SHIFT', cmap=None, exclude_fillVals=True,
                         backgroundIm='tgt', hide_filtered=True, figsize=None, title='', vector_scale=1.,
366
367
                         savefigPath='', savefigDPI=96, showFig=True, vmin=None, vmax=None, return_map=False):
        # type: (str, str, plt.cm, bool, str, bool, tuple, str, float, str, int, bool, float, float, bool) -> ...
Daniel Scheffler's avatar
Daniel Scheffler committed
368
369
        """
        Show a map of the calculated tie point grid with the target image as background.
370

371
372
        :param shapes2plot:         <str> 'points': plot points representing values of 'attribute2plot' onto the map
                                          'vectors': plot shift vectors onto the map
373
        :param attribute2plot:      <str> the attribute of the tie point grid to be shown (default: 'ABS_SHIFT')
374
375
        :param cmap:                <plt.cm.<colormap>> a custom color map to be applied to the plotted grid points
                                                        (default: 'RdYlGn_r')
376
377
        :param exclude_fillVals:    <bool> whether to exclude those points of the grid where spatial shift detection
                                    failed
378
379
        :param backgroundIm:        <str> whether to use the target or the reference image as map background. Possible
                                          options are 'ref' and 'tgt' (default: 'tgt')
380
381
        :param hide_filtered:       <bool> hide all points that have been filtered out according to tie point filter
                                    level
382
        :param figsize:             <tuple> size of the figure to be viewed, e.g. (10,10)
383
384
        :param title:               <str> plot title
        :param vector_scale:        <float> scale factor for shift vector length (default: 1 -> no scaling)
385
386
        :param savefigPath:
        :param savefigDPI:
387
        :param showFig:             <bool> whether to show or to hide the figure
388
389
        :param vmin:
        :param vmax:
390
        :param return_map           <bool
391
392
        :return:
        """
393
394
        from matplotlib import pyplot as plt  # noqa
        from matplotlib.colors import Normalize
395
396
        from cartopy.crs import PlateCarree
        from mpl_toolkits.axes_grid1 import make_axes_locatable
397

398
        # get a map showing the reference or target image
399
400
401
        if backgroundIm not in ['tgt', 'ref']:
            raise ValueError('backgroundIm')
        backgroundIm = self.im2shift if backgroundIm == 'tgt' else self.imref
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
        fig, ax = backgroundIm.show_map(figsize=figsize, nodataVal=self.nodata[1], return_map=True,
                                        band=self.COREG_obj.shift.band4match)

        # set figure title
        dict_attr_title = dict(
            X_WIN_SIZE='size of the matching window in x-direction [pixels]',
            Y_WIN_SIZE='size of the matching window in y-direction [pixels]',
            X_SHIFT_PX='absolute shifts in x-direction [pixels]',
            Y_SHIFT_PX='absolute shifts in y-direction [pixels]',
            X_SHIFT_M='absolute shifts in x-direction [map units]',
            Y_SHIFT_M='absolute shifts in y-direction [map units]',
            ABS_SHIFT='absolute shift vector length [map units]',
            ANGLE='shift vector direction [angle in degrees]',
            SSIM_BEFORE='structural similarity index before co-registration',
            SSIM_AFTER='structural similarity index after co-registration',
            SSIM_IMPROVED='structural similarity index improvement through co-registration [yes/no]',
            RELIABILITY='reliability of the computed shift vector'
        )
        if title:
            ax.set_title(title)
        elif attribute2plot in dict_attr_title:
            ax.set_title(dict_attr_title[attribute2plot], pad=20)
        elif attribute2plot in self.CoRegPoints_table.columns:
            ax.set_title(attribute2plot)
        else:
            raise ValueError(attribute2plot, "Invalid value for 'attribute2plot'. Valid values are: %s."
                             % ", ".join(self.CoRegPoints_table.columns))
429

430
        # get GeoDataFrame containing everything needed for plotting
431
        outlierCols = [c for c in self.CoRegPoints_table.columns if 'OUTLIER' in c]
432
        attr2include = ['geometry', attribute2plot] + outlierCols + ['X_SHIFT_M', 'Y_SHIFT_M']
433
        GDF = self.CoRegPoints_table.loc[self.CoRegPoints_table.X_SHIFT_M != self.outFillVal, attr2include].copy()\
434
            if exclude_fillVals else self.CoRegPoints_table.loc[:, attr2include]
435
436

        # get LonLat coordinates for all points
437
438
439
        XY = np.array([geom.coords.xy for geom in GDF.geometry]).reshape(-1, 2)
        lon, lat = transform_any_prj(self.im2shift.projection, 4326, XY[:, 0], XY[:, 1])
        GDF['Lon'], GDF['Lat'] = lon, lat
440
441

        # get colors for all points
442
        palette = cmap if cmap is not None else plt.cm.get_cmap('RdYlGn_r')
443
444
        if cmap is None and attribute2plot == 'ANGLE':
            import cmocean
445
            palette = getattr(cmocean.cm, 'delta')
446

447
        if hide_filtered:
448
            if self.tieP_filter_level > 0:
449
                GDF = GDF[GDF.L1_OUTLIER.__eq__(False)].copy()
450
            if self.tieP_filter_level > 1:
451
                GDF = GDF[GDF.L2_OUTLIER.__eq__(False)].copy()
452
            if self.tieP_filter_level > 2:
453
                GDF = GDF[GDF.L3_OUTLIER.__eq__(False)].copy()
454
455
        else:
            marker = 'o' if len(GDF) < 10000 else '.'
456
            common_kw = dict(marker=marker, alpha=1.0, transform=PlateCarree())
457
458
            if self.tieP_filter_level > 0:
                # flag level 1 outliers
459
                GDF_filt = GDF[GDF.L1_OUTLIER.__eq__(True)].copy()
460
461
                ax.scatter(GDF_filt['Lon'], GDF_filt['Lat'], c='b', s=250, label='reliability',
                           **common_kw)
462
463
            if self.tieP_filter_level > 1:
                # flag level 2 outliers
464
                GDF_filt = GDF[GDF.L2_OUTLIER.__eq__(True)].copy()
465
466
                ax.scatter(GDF_filt['Lon'], GDF_filt['Lat'], c='r', s=150, label='SSIM',
                           **common_kw)
467
468
            if self.tieP_filter_level > 2:
                # flag level 3 outliers
469
                GDF_filt = GDF[GDF.L3_OUTLIER.__eq__(True)].copy()
470
471
                ax.scatter(GDF_filt['Lon'], GDF_filt['Lat'], c='y', s=250, label='RANSAC',
                           **common_kw)
472
473

            if self.tieP_filter_level > 0:
474
                ax.legend(loc=0, scatterpoints=1)
475

476
        # plot all points or vectors on top
Daniel Scheffler's avatar
Daniel Scheffler committed
477
        if not GDF.empty:
478
479
480
            vmin_auto, vmax_auto = \
                (np.percentile(GDF[attribute2plot], 0),
                 np.percentile(GDF[attribute2plot], 98)) \
481
                if attribute2plot != 'ANGLE' else (0, 360)
482
483
484
            vmin = vmin if vmin is not None else vmin_auto
            vmax = vmax if vmax is not None else vmax_auto

485
486
487
            if shapes2plot == 'vectors':
                # plot shift vectors
                # doc: https://matplotlib.org/devdocs/api/_as_gen/matplotlib.axes.Axes.quiver.html
488
489
490
491
492
493
494
495
496
497
498
499
                mappable = ax.quiver(
                    GDF['Lon'].values, GDF['Lat'].values,
                    -GDF['X_SHIFT_M'].values,
                    -GDF['Y_SHIFT_M'].values,  # invert absolute shifts to make arrows point to tgt
                    GDF[attribute2plot].clip(vmin, vmax),  # sets the colors
                    scale=1200 / vector_scale,  # larger values decrease the arrow length
                    width=.0015,  # arrow width (in relation to plot width)
                    # linewidth=1, # maybe use this to mark outliers instead of scatter points
                    cmap=palette,
                    pivot='middle',  # position the middle point of the arrows onto the tie point location
                    transform=PlateCarree()
                    )
500
501
502

            elif shapes2plot == 'points':
                # plot tie points
503
504
505
506
507
508
509
510
511
512
513
514
                mappable = ax.scatter(
                    GDF['Lon'], GDF['Lat'],
                    c=GDF[attribute2plot],
                    lw=0,
                    cmap=palette,
                    marker='o' if len(GDF) < 10000 else '.',
                    s=50,
                    alpha=1.0,
                    vmin=vmin,
                    vmax=vmax,
                    transform=PlateCarree())
                pass
515
            else:
516
517
                raise ValueError("The parameter 'shapes2plot' must be set to 'vectors' or 'points'. "
                                 "Received %s." % shapes2plot)
518

Daniel Scheffler's avatar
Daniel Scheffler committed
519
            # add colorbar
520
521
522
523
524
525
526
527
528
529
            divider = make_axes_locatable(ax)
            cax = divider.new_vertical(size="2%", pad=0.4, pack_start=True,
                                       axes_class=plt.Axes  # needed because ax is a GeoAxis instance
                                       )
            fig.add_axes(cax)
            fig.colorbar(mappable, cax=cax, cmap=palette,
                         norm=Normalize(vmin=vmin, vmax=vmax), orientation="horizontal")

            # hack to enlarge the figure on the top to avoid cutting off the title (everthing else has no effect)
            divider.new_vertical(size="2%", pad=0.4, pack_start=False, axes_class=plt.Axes)
530

Daniel Scheffler's avatar
Daniel Scheffler committed
531
532
533
534
        else:
            if not self.q:
                warnings.warn('Cannot plot any tie point because none is left after tie point validation.')

535
536
537
        if savefigPath:
            fig.savefig(savefigPath, dpi=savefigDPI)

538
        if return_map:
539
            return fig, ax
540

541
542
543
544
545
        if showFig and not self.q:
            plt.show(block=True)
        else:
            plt.close(fig)

546
547
    def view_CoRegPoints_folium(self, attribute2plot='ABS_SHIFT', cmap=None, exclude_fillVals=True):
        """Create an interactive folium map showing the target image, the tie points and the overlap polygon.

        NOTE: This function is still under construction (a UserWarning is emitted on every call).

        :param attribute2plot:      <str> column of the tie point table (self.CoRegPoints_table) that is exported
                                    together with the tie point geometries (default: 'ABS_SHIFT')
        :param cmap:                currently not used
        :param exclude_fillVals:    currently not used
        :return:                    <folium.Map> the map object containing all layers
        """
        warnings.warn(UserWarning('This function is still under construction and may not work as expected!'))
        assert self.CoRegPoints_table is not None, 'Calculate tie point grid first!'

        # plotting dependencies are only needed here, so they are imported lazily
        import folium
        import geojson
        from folium.raster_layers import ImageOverlay

        tgt = self.im2shift

        # footprint of the target image in geographic coordinates (EPSG:4326)
        lon_min, lat_min, lon_max, lat_max = \
            reproject_shapelyGeometry(tgt.box.mapPoly, tgt.projection, 4326).bounds
        map_center = [(lat_min + lat_max) / 2, (lon_min + lon_max) / 2]

        # first band of the target image
        band = tgt[:, :, 0]  # FIXME hardcoded band

        from py_tools_ds.geo.raster.reproject import warp_ndarray

        # the overlay image has to be transformed into web mercator projection (EPSG:3857)
        nodata_tgt = self.nodata[1]
        band_webmerc, _gt_webmerc, _prj_webmerc = \
            warp_ndarray(band, tgt.geotransform, tgt.projection,
                         in_nodata=nodata_tgt, out_nodata=nodata_tgt,
                         out_XYdims=(1000, 1000), q=True, out_prj='epsg:3857')

        # base map centered on the target image footprint
        fmap = folium.Map(location=map_center)
        ImageOverlay(
            image=band_webmerc,
            bounds=[[lat_min, lon_min], [lat_max, lon_max]],
            colormap=lambda val: (1, 0, 0, val),  # TODO a proper colormap must be given
        ).add_to(fmap)

        # tie point layer carrying the requested attribute
        tiepoints = self.CoRegPoints_table[['geometry', attribute2plot]]
        tiepoints.geometry.crs = tiepoints.crs
        folium.GeoJson(tiepoints).add_to(fmap)

        # overlap area of reference and target image
        overlap_ll = reproject_shapelyGeometry(self.COREG_obj.overlap_poly, tgt.epsg, 4326)
        folium.GeoJson(geojson.Feature(geometry=overlap_ll, properties={})).add_to(fmap)

        return fmap

587
    def _get_updated_map_info_meanShifts(self):
        """Return the map info of the target image, shifted by the mean X/Y map shifts of the tie point grid.

        The map info list is derived from the target image geotransform and projection. Only the entries at
        index 3 and 4 (presumably the map X/Y coordinates of the reference pixel - ENVI-style map info, TODO
        confirm) are replaced by their original value plus the mean shift, stored back as strings.
        """
        map_info = geotransform2mapinfo(self.im2shift.gt, self.im2shift.prj)
        shifted = copy(map_info)  # leave the freshly derived original untouched

        grid = self.tiepoint_grid
        for idx, offset in ((3, grid.mean_x_shift_map), (4, grid.mean_y_shift_map)):
            shifted[idx] = str(float(map_info[idx]) + offset)

        return shifted

595
596
    @property
    def coreg_info(self):
        """Return a dictionary containing everything needed to correct the local displacements of the target image.

        The dictionary is computed on first access and cached in self._coreg_info; later accesses return
        the cached object itself.
        """
        if not self._coreg_info:
            grid = self.tiepoint_grid

            info = {
                'GCPList': grid.GCPList,
                'mean_shifts_px': {'x': grid.mean_x_shift_px,
                                   'y': grid.mean_y_shift_px},
                'mean_shifts_map': {'x': grid.mean_x_shift_map,
                                    'y': grid.mean_y_shift_map},
                'updated map info means': self._get_updated_map_info_meanShifts(),
            }

            ref = self.imref
            ref_gt = ref.gt
            # NOTE(review): built from the reference image (self.imref), whereas
            # _get_updated_map_info_meanShifts() uses the target image (self.im2shift) - confirm intended.
            info['original map info'] = geotransform2mapinfo(ref_gt, ref.prj)
            info['reference projection'] = ref.prj
            info['reference geotransform'] = ref_gt
            # x/y coordinates of the first two grid lines of the reference geotransform (GDAL order assumed)
            info['reference grid'] = [[ref_gt[0], ref_gt[0] + ref_gt[1]],
                                      [ref_gt[3], ref_gt[3] + ref_gt[5]]]
            info['reference extent'] = {'cols': ref.xgsd, 'rows': ref.ygsd}  # FIXME not needed anymore
            info['success'] = self.success

            self._coreg_info = info

        return self._coreg_info

618
    def correct_shifts(self, max_GCP_count=None, cliptoextent=False, min_points_local_corr=5):
        """Perform a local shift correction using all points from the previously calculated tie point grid.

        NOTE: Only valid matches are used as GCP points.

        :param max_GCP_count:           <int> maximum number of GCPs to use (None or 0: use all GCPs)
        :param cliptoextent:            <bool> whether to clip the output image to its real extent
        :param min_points_local_corr:   <int> number of valid tie points, below which a global shift correction is
                                        performed instead of a local correction (global X/Y shift is then computed as
                                        the mean shift of the remaining points)(default: 5 tie points)
        :return:    the result of DESHIFTER.correct_shifts() (also stored in self.deshift_results), or None if the
                    tie point grid does not contain any GCP (a warning is emitted unless self.q is set)
        """
        # Work on a shallow copy: self.coreg_info returns the cached dictionary itself, so truncating the GCP list
        # in place would permanently alter self.coreg_info and affect all later calls of this method.
        coreg_info = dict(self.coreg_info)

        if self.tiepoint_grid.GCPList:
            if max_GCP_count:
                coreg_info['GCPList'] = coreg_info['GCPList'][:max_GCP_count]

            DS = DESHIFTER(self.im2shift, coreg_info,
                           path_out=self.path_out,
                           fmt_out=self.fmt_out,
                           out_crea_options=self.out_creaOpt,
                           align_grids=self.align_grids,
                           match_gsd=self.match_gsd,
                           out_gsd=self.out_gsd,
                           target_xyGrid=self.target_xyGrid,
                           min_points_local_corr=min_points_local_corr,
                           resamp_alg=self.rspAlg_DS,
                           cliptoextent=cliptoextent,
                           progress=self.progress,
                           v=self.v,
                           q=self.q)

            self.deshift_results = DS.correct_shifts()
            return self.deshift_results
        else:
            if not self.q:
                warnings.warn('Correction of geometric shifts failed because the input GCP list is empty!')
            return None