Geom_Quality_Grid.py 34.2 KB
Newer Older
1
2
3
4
5
6
7
# -*- coding: utf-8 -*-
__author__='Daniel Scheffler'

import collections
import multiprocessing
import os
import warnings
8
import time
9
10

# custom
11
12
13
14
try:
    import gdal
except ImportError:
    from osgeo import gdal
15
import numpy as np
16
from geopandas         import GeoDataFrame, GeoSeries
17
18
from pykrige.ok        import OrdinaryKriging
from shapely.geometry  import Point
19
20
from skimage.measure   import points_in_poly, ransac
from skimage.transform import AffineTransform, PolynomialTransform
21
22

# internal modules
23
from .CoReg  import COREG
24
25
26
from .       import io    as IO
from py_tools_ds.ptds.geo.projection          import isProjectedOrGeographic, get_UTMzone
from py_tools_ds.ptds.io.pathgen              import get_generic_outpath
27
from py_tools_ds.ptds.processing.progress_mon import ProgressBar
28
29
30
31
32
33
34
35



global_shared_imref    = None
global_shared_im2shift = None


class Geom_Quality_Grid(object):
36
37
    """See help(Geom_Quality_Grid) for documentation!"""

38
39
    def __init__(self, COREG_obj, grid_res, max_points=None, outFillVal=-9999, resamp_alg_calc='cubic',
                 tieP_filter_level=2, dir_out=None, CPUs=None, progress=True, v=False, q=False):
        """Detect spatial shifts within the whole overlap area of the input images.

        Spatial shifts are calculated for each point of a regular grid whose density can be
        adjusted via keyword arguments. The calculated shifts of all grid points can then serve
        as GCPs for a polynomial transformation, so 'Geom_Quality_Grid' can be used to correct
        locally varying geometric distortions of the target image.

        :param COREG_obj(object):       an instance of COREG class
        :param grid_res:                grid resolution in pixels of the target image
        :param max_points(int):         maximum number of points used to find coregistration tie points
                                        NOTE: Points are selected randomly from the given point grid
                                        (specified by 'grid_res'). If the grid does not provide enough
                                        points, all available points are chosen.
        :param outFillVal(int):         if given the generated geometric quality grid is filled with this
                                        value in case no match could be found during co-registration
                                        (default: -9999)
        :param resamp_alg_calc(str):    the resampling algorithm to be used for all warping processes
                                        during calculation of spatial shifts
                                        (valid algorithms: nearest, bilinear, cubic, cubic_spline,
                                        lanczos, average, mode, max, min, med, q1, q3)
                                        default: cubic (highly recommended)
        :param tieP_filter_level(int):  filter tie points used for shift correction in different levels
                                        (default: 2). NOTE: lower levels are also included if a higher
                                        level is chosen
                                            - Level 0: no tie point filtering
                                            - Level 1: Reliability filtering - filter all tie points out
                                              that have a low reliability according to internal tests
                                            - Level 2: SSIM filtering - filters all tie points out where
                                              shift correction does not increase image similarity within
                                              the matching window (measured by mean structural similarity
                                              index)
                                            - Level 3: RANSAC outlier detection
        :param dir_out(str):            output directory to be used for all outputs if nothing else is
                                        given to the individual methods
        :param CPUs(int):               number of CPUs to use during calculation of the geometric quality
                                        grid (default: None, which means 'all CPUs available')
        :param progress(bool):          show progress bars (default: True)
        :param v(bool):                 verbose mode (default: False)
        :param q(bool):                 quiet mode (default: False)
        """
        if not isinstance(COREG_obj, COREG):
            raise ValueError("'COREG_obj' must be an instance of COREG class.")

        self.COREG_obj         = COREG_obj
        self.grid_res          = grid_res
        self.max_points        = max_points
        self.outFillVal        = outFillVal
        self.rspAlg_calc       = resamp_alg_calc
        self.tieP_filter_level = tieP_filter_level
        self.dir_out           = dir_out
        self.CPUs              = CPUs
        self.v                 = v
        # verbose mode disables quiet mode; quiet mode disables progress bars
        self.q                 = False if v else q
        self.progress          = False if q else progress

        # shortcuts to the reference and target image objects held by the COREG instance
        self.ref   = self.COREG_obj.ref
        self.shift = self.COREG_obj.shift

        self.XY_points, self.XY_mapPoints = self._get_imXY__mapXY_points(self.grid_res)

        self._CoRegPoints_table = None  # set by self.CoRegPoints_table
        self._GCPList           = None  # set by self.to_GCPList()
        self.kriged             = None  # set by Raster_using_Kriging()
98
99


100
101
    @property
    def CoRegPoints_table(self):
102
103
104
105
        """Returns a GeoDataFrame with the columns 'geometry','POINT_ID','X_IM','Y_IM','X_UTM','Y_UTM','X_WIN_SIZE',
        'Y_WIN_SIZE','X_SHIFT_PX','Y_SHIFT_PX', 'X_SHIFT_M', 'Y_SHIFT_M', 'ABS_SHIFT' and 'ANGLE' containing all
        information containing all the results frm coregistration for all points in the geometric quality grid.
        """
106
107
108
109
110
111
        if self._CoRegPoints_table is not None:
            return self._CoRegPoints_table
        else:
            self._CoRegPoints_table = self.get_CoRegPoints_table()
            return self._CoRegPoints_table

112

113
114
115
116
117
118
119
    @CoRegPoints_table.setter
    def CoRegPoints_table(self, CoRegPoints_table):
        # overwrite the cached tie point table (also used internally by get_CoRegPoints_table())
        self._CoRegPoints_table = CoRegPoints_table


    @property
    def GCPList(self):
120
121
        """Returns a list of GDAL compatible GCP objects.
        """
122
123
124
125
        if self._GCPList:
            return self._GCPList
        else:
            self._GCPList = self.to_GCPList()
126
            return self._GCPList
127
128
129
130
131
132
133
134


    @GCPList.setter
    def GCPList(self, GCPList):
        # overwrite the cached GCP list (also used internally by to_GCPList())
        self._GCPList = GCPList


    def _get_imXY__mapXY_points(self, grid_res):
135
136
137
138
139
140
        """Returns a numpy array containing possible positions for coregistration tie points according to the given
        grid resolution.

        :param grid_res:
        :return:
        """
141
142
143
        if not self.q:
            print('Initializing geometric quality grid...')

144
145
146
147
        Xarr,Yarr       = np.meshgrid(np.arange(0,self.shift.shape[1],grid_res),
                                      np.arange(0,self.shift.shape[0],grid_res))

        ULmapYX, URmapYX, LRmapYX, LLmapYX = self.shift.box.boxMapYX
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162

        mapXarr,mapYarr = np.meshgrid(np.arange(ULmapYX[1],LRmapYX[1],     self.grid_res*self.COREG_obj.shift.xgsd),
                                      np.arange(ULmapYX[0],LRmapYX[0],-abs(self.grid_res*self.COREG_obj.shift.ygsd)))

        XY_points      = np.empty((Xarr.size,2),Xarr.dtype)
        XY_points[:,0] = Xarr.flat
        XY_points[:,1] = Yarr.flat

        XY_mapPoints      = np.empty((mapXarr.size,2),mapXarr.dtype)
        XY_mapPoints[:,0] = mapXarr.flat
        XY_mapPoints[:,1] = mapYarr.flat

        return XY_points,XY_mapPoints


163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
    def _exclude_bad_XYpos(self, GDF):
        """Excludes all points outside of the image overlap area and all points where the bad data mask is True (if given).

        :param GDF:     <geopandas.GeoDataFrame> must include the columns 'X_UTM' and 'Y_UTM'
        :return:        <geopandas.GeoDataFrame> the input GDF reduced to usable points
                        (columns 'REF_BADDATA'/'TGT_BADDATA' are added as a side effect)
        """

        # exclude all points outside of overlap area
        # NOTE(review): the boolean mask is computed from self.XY_mapPoints, so this assumes GDF
        # still has exactly one row per XY_mapPoints entry, in the same order — true at call time
        # in get_CoRegPoints_table(); confirm before reusing elsewhere
        inliers = points_in_poly(self.XY_mapPoints,
                                 np.swapaxes(np.array(self.COREG_obj.overlap_poly.exterior.coords.xy), 0, 1))
        GDF = GDF[inliers].copy()
        #GDF = GDF[GDF['geometry'].within(self.COREG_obj.overlap_poly.simplify(tolerance=15))] # works but much slower

        assert not GDF.empty, 'No coregistration point could be placed within the overlap area. Check your input data!' # FIXME track that


        # exclude all point where bad data mask is True (e.g. points on clouds etc.)
        # a missing mask yields the scalar False, which pandas broadcasts to the whole column
        orig_len_GDF       = len(GDF)
        mapXY              = np.array(GDF.loc[:,['X_UTM','Y_UTM']])
        GDF['REF_BADDATA'] = self.COREG_obj.ref  .mask_baddata.read_pointData(mapXY) \
                                if self.COREG_obj.ref  .mask_baddata is not None else False
        GDF['TGT_BADDATA'] = self.COREG_obj.shift.mask_baddata.read_pointData(mapXY)\
                                if self.COREG_obj.shift.mask_baddata is not None else False
        GDF                = GDF[(GDF['REF_BADDATA']==False) & (GDF['TGT_BADDATA']==False)]
        if self.COREG_obj.ref.mask_baddata is not None or self.COREG_obj.shift.mask_baddata is not None:
            print('According to the provided bad data mask(s) %s points of initially %s have been excluded.'
                  %(orig_len_GDF-len(GDF), orig_len_GDF))

        return GDF


194
195
196
197
198
    @staticmethod
    def _get_spatial_shifts(coreg_kwargs):
        """Run the co-registration for a single grid point and collect its results.

        :param coreg_kwargs:  keyword arguments for COREG; must contain a 'pointID' entry which
                              is removed before instantiation and prepended to the result list
        :return:  [pointID, X_WIN_SIZE, Y_WIN_SIZE, x_shift_px, y_shift_px, x_shift_map,
                   y_shift_map, vec_length_map, vec_angle_deg, ssim_orig, ssim_deshifted,
                   ssim_improved, shift_reliability, last_error]
        """
        pointID = coreg_kwargs.pop('pointID')

        # the images must have been published as module-level globals beforehand
        # (done by get_CoRegPoints_table(), so they are shared with worker processes)
        assert global_shared_imref    is not None
        assert global_shared_im2shift is not None

        CR = COREG(global_shared_imref, global_shared_im2shift, multiproc=False, **coreg_kwargs)
        CR.calculate_spatial_shifts()

        last_err           = CR.tracked_errors[-1] if CR.tracked_errors else None
        win_sz_y, win_sz_x = CR.matchBox.imDimsYX if CR.matchBox else (None, None)

        return [pointID,
                win_sz_x, win_sz_y,
                CR.x_shift_px, CR.y_shift_px,
                CR.x_shift_map, CR.y_shift_map,
                CR.vec_length_map, CR.vec_angle_deg,
                CR.ssim_orig, CR.ssim_deshifted, CR.ssim_improved,
                CR.shift_reliability, last_err]
214
215


216
    def get_CoRegPoints_table(self):
        """Calculate the geometric quality grid and return it as GeoDataFrame.

        Runs a point-wise co-registration for every grid position (in parallel unless
        self.CPUs == 1), merges the per-point results into the point table, applies the
        configured tie point filtering and caches the result in self.CoRegPoints_table.

        :return:  <geopandas.GeoDataFrame> one record per grid point, including shift and
                  quality columns ('X_SHIFT_PX', 'ABS_SHIFT', 'RELIABILITY', ...)
        """
        assert self.XY_points is not None and self.XY_mapPoints is not None

        # create a shapely point geometry for each map coordinate of the grid
        XYarr2PointGeom = np.vectorize(lambda X, Y: Point(X, Y), otypes=[Point])
        geomPoints      = np.array(XYarr2PointGeom(self.XY_mapPoints[:, 0], self.XY_mapPoints[:, 1]))

        # derive a CRS dictionary from the projection of the image to be shifted
        if isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'geographic':
            crs = dict(ellps='WGS84', datum='WGS84', proj='longlat')
        elif isProjectedOrGeographic(self.COREG_obj.shift.prj) == 'projected':
            UTMzone = abs(get_UTMzone(prj=self.COREG_obj.shift.prj))
            south   = get_UTMzone(prj=self.COREG_obj.shift.prj) < 0
            crs     = dict(ellps='WGS84', datum='WGS84', proj='utm', zone=UTMzone, south=south, units='m', no_defs=True)
            if not south:
                del crs['south']
        else:
            crs = None

        GDF = GeoDataFrame(index=range(len(geomPoints)), crs=crs,
                           columns=['geometry', 'POINT_ID', 'X_IM', 'Y_IM', 'X_UTM', 'Y_UTM'])
        GDF['geometry'] = geomPoints
        GDF['POINT_ID'] = range(len(geomPoints))
        GDF.loc[:, ['X_IM', 'Y_IM']]   = self.XY_points
        GDF.loc[:, ['X_UTM', 'Y_UTM']] = self.XY_mapPoints

        # exclude offsite points and points on bad data mask
        GDF = self._exclude_bad_XYpos(GDF)

        # choose a random subset of points if a maximum number has been given
        if self.max_points:
            GDF = GDF.sample(self.max_points).copy()

        # declare global variables needed for self._get_spatial_shifts()
        global global_shared_imref, global_shared_im2shift
        assert self.ref  .footprint_poly  # this also checks for mask_nodata and nodata value
        assert self.shift.footprint_poly
        if not self.ref.is_inmem:
            self.ref.cache_array_subset(self.ref[self.COREG_obj.ref.band4match])
        if not self.shift.is_inmem:
            # FIX: previously called self.ref.cache_array_subset() here (copy/paste bug) although
            # the subset of the image to be shifted was meant to be cached
            self.shift.cache_array_subset(self.shift[self.COREG_obj.shift.band4match])
        global_shared_imref    = self.ref
        global_shared_im2shift = self.shift

        # get all variations of kwargs for co-registration
        get_coreg_kwargs = lambda pID, wp: {
            'pointID'            : pID,
            'wp'                 : wp,
            'ws'                 : self.COREG_obj.win_size_XY,
            'resamp_alg_calc'    : self.rspAlg_calc,
            'footprint_poly_ref' : self.COREG_obj.ref.poly,
            'footprint_poly_tgt' : self.COREG_obj.shift.poly,
            'r_b4match'          : self.COREG_obj.ref.band4match + 1,   # band4match is internally saved as index, starting from 0
            's_b4match'          : self.COREG_obj.shift.band4match + 1, # band4match is internally saved as index, starting from 0
            'max_iter'           : self.COREG_obj.max_iter,
            'max_shift'          : self.COREG_obj.max_shift,
            'nodata'             : (self.COREG_obj.ref.nodata, self.COREG_obj.shift.nodata),
            'binary_ws'          : self.COREG_obj.bin_ws,
            'v'                  : False, # otherwise this would lead to massive console output
            'q'                  : True,  # otherwise this would lead to massive console output
            'ignore_errors'      : True
        }
        list_coreg_kwargs = (get_coreg_kwargs(i, self.XY_mapPoints[i]) for i in GDF.index)  # generator

        # run co-registration for whole grid
        if self.CPUs is None or self.CPUs > 1:
            if not self.q:
                cpus = self.CPUs if self.CPUs is not None else multiprocessing.cpu_count()
                print("Calculating geometric quality grid (%s points) using %s CPU cores..." % (len(GDF), cpus))

            with multiprocessing.Pool(self.CPUs) as pool:
                if self.q or not self.progress:
                    results = pool.map(self._get_spatial_shifts, list_coreg_kwargs)
                else:
                    results = pool.map_async(self._get_spatial_shifts, list_coreg_kwargs, chunksize=1)
                    bar     = ProgressBar(prefix='\tprogress:')
                    while True:
                        time.sleep(.1)
                        # this does not really represent the remaining tasks but the remaining chunks -> thus chunksize=1
                        numberDone = len(GDF) - results._number_left
                        if self.progress:
                            bar.print_progress(percent=numberDone / len(GDF) * 100)
                        if results.ready():
                            results = results.get()  # FIXME in some cases the code hangs here ==> x ==> seems to be fixed
                            break
        else:
            if not self.q:
                print("Calculating geometric quality grid (%s points) 1 CPU core..." % len(GDF))
            # FIX: allocate one row per remaining grid point (was len(geomPoints), which left
            # all-None rows whenever points had been excluded or subsampled); the deprecated
            # 'np.object' alias (removed in NumPy 1.24) is replaced by the builtin 'object'
            results = np.empty((len(GDF), 14), object)
            bar     = ProgressBar(prefix='\tprogress:')
            for i, coreg_kwargs in enumerate(list_coreg_kwargs):
                if self.progress:
                    bar.print_progress((i + 1) / len(GDF) * 100)
                results[i, :] = self._get_spatial_shifts(coreg_kwargs)

        # merge results with GDF
        records = GeoDataFrame(np.array(results, object),
                               columns=['POINT_ID', 'X_WIN_SIZE', 'Y_WIN_SIZE', 'X_SHIFT_PX', 'Y_SHIFT_PX', 'X_SHIFT_M',
                                        'Y_SHIFT_M', 'ABS_SHIFT', 'ANGLE', 'SSIM_BEFORE', 'SSIM_AFTER',
                                        'SSIM_IMPROVED', 'RELIABILITY', 'LAST_ERR'])

        GDF = GDF.merge(records, on='POINT_ID', how="inner")
        GDF = GDF.fillna(int(self.outFillVal))

        # filter tie points according to given filter level
        if self.tieP_filter_level > 0:
            TPR                   = Tie_Point_Refiner(GDF[GDF.ABS_SHIFT != self.outFillVal])
            GDF_filt, new_columns = TPR.run_filtering(level=self.tieP_filter_level)
            GDF                   = GDF.merge(GDF_filt[['POINT_ID'] + new_columns], on='POINT_ID', how="outer")

        GDF = GDF.fillna(int(self.outFillVal))
        self.CoRegPoints_table = GDF
        return self.CoRegPoints_table


339
340
341
342
343
344
345
    def dump_CoRegPoints_table(self, path_out=None):
        """Pickle the tie point table (CoRegPoints_table) to disk.

        :param path_out:  <str> output path of the pickle file; if not given, it is derived from
                          self.dir_out, the grid/window parameters and the image basenames
        """
        path_out = path_out if path_out else get_generic_outpath(dir_out=self.dir_out,
            fName_out="CoRegPoints_table_grid%s_ws(%s_%s)__T_%s__R_%s.pkl" % (self.grid_res, self.COREG_obj.win_size_XY[0],
                        self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))
        if not self.q:
            print('Writing %s ...' % path_out)
        self.CoRegPoints_table.to_pickle(path_out)
346
347


348
349
    def to_GCPList(self):
        """Convert the tie point table into a list of GDAL compatible GCP objects.

        All records without a valid match and all records flagged as outliers are skipped.
        The resulting list is also cached in self.GCPList.

        :return:  <list> of gdal.GCP instances (empty list if no valid tie point remains)
        """
        # get copy of quality grid without no data
        valid_mask = self.CoRegPoints_table.ABS_SHIFT != self.outFillVal
        GDF        = self.CoRegPoints_table.loc[valid_mask, :].copy()

        if getattr(GDF, 'empty'):  # GDF.empty returns AttributeError
            return []

        # exclude all points flagged as outliers
        if 'OUTLIER' in GDF.columns:
            GDF = GDF[GDF.OUTLIER == False].copy()

        # calculate GCPs: target map coordinates shifted to their reference position,
        # linked to the corresponding source image pixel coordinates
        GDF['X_UTM_new'] = GDF.X_UTM + GDF.X_SHIFT_M
        GDF['Y_UTM_new'] = GDF.Y_UTM + GDF.Y_SHIFT_M
        GDF['GCP']       = GDF.apply(
            lambda GDF_row: gdal.GCP(GDF_row.X_UTM_new, GDF_row.Y_UTM_new, 0,
                                     GDF_row.X_IM, GDF_row.Y_IM),
            axis=1)

        self.GCPList = GDF.GCP.tolist()
        if not self.q:
            print('Found %s valid tie points.' % len(self.GCPList))
        return self.GCPList
370
371


372
    def test_if_singleprocessing_equals_multiprocessing_result(self):
        """Compute the quality grid once with all CPUs and once with a single CPU and compare.

        :return:  <bool> True if both runs produced identical tables
        """
        # RANSAC filtering always produces different results because it includes random sampling
        self.tieP_filter_level = 1

        # multiprocessing run
        self.CPUs = None
        mp_table  = self.get_CoRegPoints_table()
        mp_out    = np.empty_like(mp_table.values)
        mp_out[:] = mp_table.values

        # singleprocessing run
        self.CPUs = 1
        sp_table  = self.get_CoRegPoints_table()
        sp_out    = np.empty_like(sp_table.values)
        sp_out[:] = sp_table.values

        return np.array_equal(sp_out, mp_out)


387
388
    def _get_line_by_PID(self, PID):
        return self.CoRegPoints_table.loc[PID, :]
389
390


391
    def _get_lines_by_PIDs(self, PIDs):
392
        assert isinstance(PIDs,list)
393
        lines = np.zeros((len(PIDs),self.CoRegPoints_table.shape[1]))
394
        for i,PID in enumerate(PIDs):
395
            lines[i,:] = self.CoRegPoints_table[self.CoRegPoints_table['POINT_ID'] == PID]
396
397
398
        return lines


399
400
401
402
403
404
405
406
407
408
    def to_PointShapefile(self, path_out=None, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        # type: (str, bool, str)
        """Writes the calculated geometric quality grid to a point shapefile containing
        Geom_Quality_Grid.CoRegPoints_table as attribute table. This shapefile can easily be displayed using GIS software.

        :param path_out:        <str> the output path. If not given, it is automatically defined.
        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> determines which column of Geom_Quality_Grid.CoRegPoints_table is used to
                                identify points where no valid match could be found
        """
        GDF = self.CoRegPoints_table
        if skip_nodata:
            GDF2pass = GDF[GDF[skip_nodata_col] != self.outFillVal]
        else:
            GDF2pass = GDF

        if not path_out:
            path_out = get_generic_outpath(
                dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                fName_out="CoRegPoints_grid%s_ws(%s_%s)__T_%s__R_%s.shp" % (
                    self.grid_res, self.COREG_obj.win_size_XY[0],
                    self.COREG_obj.win_size_XY[1], self.shift.basename, self.ref.basename))

        if not self.q:
            print('Writing %s ...' % path_out)
        GDF2pass.to_file(path_out)


420
    def _to_PointShapefile(self, skip_nodata=True, skip_nodata_col='ABS_SHIFT'):
        """Deprecated shapefile export writing each point record via IO.write_shp.

        :param skip_nodata:     <bool> whether to skip all points where no valid match could be found
        :param skip_nodata_col: <str> column used to identify points without a valid match
        """
        warnings.warn(DeprecationWarning("'_quality_grid_to_PointShapefile' is deprecated." # TODO delete if other method validated
                                         " 'quality_grid_to_PointShapefile' is much faster."))
        GDF = self.CoRegPoints_table
        if skip_nodata:
            GDF2pass = GDF[GDF[skip_nodata_col] != self.outFillVal]
        else:
            GDF2pass = GDF

        shapely_points = GDF2pass['geometry'].values.tolist()
        # one ordered attribute dict per record, preserving the table's column order
        attr_dicts = [collections.OrderedDict(zip(GDF2pass.columns, GDF2pass.loc[i].values))
                      for i in GDF2pass.index]

        fName_out = "CoRegPoints_grid%s_ws%s.shp" %(self.grid_res, self.COREG_obj.win_size_XY)
        path_out  = os.path.join(self.dir_out, fName_out)
        IO.write_shp(path_out, shapely_points, prj=self.COREG_obj.shift.prj, attrDict=attr_dicts)


434
435
436
    def to_Raster_using_KrigingOLD(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                                   path_out=None, tilepos=None):
        """Interpolate a column of the tie point table to a raster using ordinary kriging (old variant).

        :param attrName:        <str> column of CoRegPoints_table to interpolate
        :param skip_nodata:     skip all points where no valid match could be found
        :param skip_nodata_col: <str> column used to identify points without a valid match
        :param outGridRes:      output grid resolution in map units (auto-derived if not given)
        :param path_out:        <str> output GeoTIFF path (auto-derived if not given)
        :param tilepos:         ((rowStart, rowEnd), (colStart, colEnd)) image subset to process
        :return:                interpolated value grid
        """
        GDF = self.CoRegPoints_table
        if skip_nodata:
            GDF2pass = GDF[GDF[skip_nodata_col] != self.outFillVal]
        else:
            GDF2pass = GDF

        # subset if tilepos is given
        rows, cols = tilepos if tilepos else self.shift.shape
        GDF2pass = GDF2pass.loc[(GDF2pass['X_IM'] >= cols[0]) & (GDF2pass['X_IM'] <= cols[1]) &
                                (GDF2pass['Y_IM'] >= rows[0]) & (GDF2pass['Y_IM'] <= rows[1])]

        X_coords, Y_coords, ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'], GDF2pass[attrName]

        xmin, ymin, xmax, ymax = GDF2pass.total_bounds
        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x, grid_y = np.arange(xmin, xmax + grid_res, grid_res), np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y)#,backend='C',)

        if not path_out:
            path_out = get_generic_outpath(
                dir_out=os.path.join(self.dir_out, 'CoRegPoints'),
                fName_out="Kriging__%s__grid%s_ws(%s_%s).tif"  % (attrName, self.grid_res, self.COREG_obj.win_size_XY[0],
                            self.COREG_obj.win_size_XY[1]))
        print('Writing %s ...' % path_out)

        # add a half pixel grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res), prj=self.COREG_obj.shift.prj)

        return zvalues


468
469
    def Raster_using_Kriging(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                             fName_out=None, tilepos=None, tilesize=500, mp=None):
470

471
472
473
        mp = False if self.CPUs==1 else True
        self._Kriging_sp(attrName, skip_nodata=skip_nodata, skip_nodata_col=skip_nodata_col,
                         outGridRes=outGridRes, fName_out=fName_out, tilepos=tilepos)
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496

        # if mp:
        #     tilepositions = UTL.get_image_tileborders([tilesize,tilesize],self.tgt_shape)
        #     args_kwargs_dicts=[]
        #     for tp in tilepositions:
        #         kwargs_dict = {'skip_nodata':skip_nodata,'skip_nodata_col':skip_nodata_col,'outGridRes':outGridRes,
        #                        'fName_out':fName_out,'tilepos':tp}
        #         args_kwargs_dicts.append({'args':[attrName],'kwargs':kwargs_dict})
        #     # self.kriged=[]
        #     # for i in args_kwargs_dicts:
        #     #     res = self.Kriging_mp(i)
        #     #     self.kriged.append(res)
        #     #     print(res)
        #
        #     with multiprocessing.Pool() as pool:
        #        self.kriged = pool.map(self.Kriging_mp,args_kwargs_dicts)
        # else:
        #     self.Kriging_sp(attrName,skip_nodata=skip_nodata,skip_nodata_col=skip_nodata_col,
        #                     outGridRes=outGridRes,fName_out=fName_out,tilepos=tilepos)
        res = self.kriged if mp else None
        return res


497
498
499
    def _Kriging_sp(self, attrName, skip_nodata=1, skip_nodata_col='ABS_SHIFT', outGridRes=None,
                    fName_out=None, tilepos=None):
        """Run ordinary kriging for one column of the tie point table and write the result to disk.

        :param attrName:        <str> column of CoRegPoints_table to interpolate
        :param skip_nodata:     skip all points where no valid match could be found
        :param skip_nodata_col: <str> column used to identify points without a valid match
        :param outGridRes:      output grid resolution in map units (auto-derived if not given)
        :param fName_out:       <str> output filename (auto-derived if not given)
        :param tilepos:         tile position (currently only used within the output filename)
        :return:                interpolated value grid
        """
        GDF = self.CoRegPoints_table
        if skip_nodata:
            GDF2pass = GDF[GDF[skip_nodata_col] != self.outFillVal]
        else:
            GDF2pass = GDF

        X_coords, Y_coords, ABS_SHIFT = GDF2pass['X_UTM'], GDF2pass['Y_UTM'], GDF2pass[attrName]

        xmin, ymin, xmax, ymax = GDF2pass.total_bounds
        grid_res = outGridRes if outGridRes else int(min(xmax - xmin, ymax - ymin) / 250)
        grid_x, grid_y = np.arange(xmin, xmax + grid_res, grid_res), np.arange(ymax, ymin - grid_res, -grid_res)

        # Reference: P.K. Kitanidis, Introduction to Geostatistcs: Applications in Hydrogeology,
        #            (Cambridge University Press, 1997) 272 p.
        OK = OrdinaryKriging(X_coords, Y_coords, ABS_SHIFT, variogram_model='spherical', verbose=False)
        zvalues, sigmasq = OK.execute('grid', grid_x, grid_y, backend='C', n_closest_points=12)

        # multiprocessing runs include the tile position in the filename
        if self.CPUs is None or self.CPUs > 1:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s_%s.tif" %(attrName,self.grid_res, self.COREG_obj.win_size_XY,tilepos)
        else:
            fName_out = fName_out if fName_out else \
                "Kriging__%s__grid%s_ws%s.tif" %(attrName,self.grid_res, self.COREG_obj.win_size_XY)
        path_out = get_generic_outpath(dir_out=self.dir_out, fName_out=fName_out)
        print('Writing %s ...' % path_out)

        # add a half pixel grid points are centered on the output pixels
        xmin, ymin, xmax, ymax = xmin - grid_res / 2, ymin - grid_res / 2, xmax + grid_res / 2, ymax + grid_res / 2
        IO.write_numpy_to_image(zvalues, path_out, gt=(xmin, grid_res, 0, ymax, 0, -grid_res), prj=self.COREG_obj.shift.prj)

        return zvalues


540
    def _Kriging_mp(self, args_kwargs_dict):
541
542
543
        args   = args_kwargs_dict.get('args',[])
        kwargs = args_kwargs_dict.get('kwargs',[])

544
        return self._Kriging_sp(*args, **kwargs)
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620



class Tie_Point_Refiner(object):
    """Cascaded outlier filter for coregistration tie points.

    Flags unreliable tie point records in a GeoDataFrame using up to three levels:
        level 1: reliability thresholding
        level 2: SSIM improvement check
        level 3: RANSAC-based geometric outlier detection (experimental)
    """

    def __init__(self, GDF, q=False):
        """
        :param GDF: <geopandas.GeoDataFrame> tie point records; expected columns include
                    RELIABILITY and SSIM_IMPROVED, plus X_UTM, Y_UTM, X_SHIFT_M,
                    Y_SHIFT_M and POINT_ID for RANSAC filtering
        :param q:   <bool> quiet mode (suppress print output)
        """
        self.GDF = GDF.copy()  # work on a copy so the caller's GeoDataFrame stays untouched
        self.q   = q

        self.new_cols            = []    # names of the outlier columns added by run_filtering()
        self.ransac_model_robust = None  # transform model estimated by the RANSAC filter


    def run_filtering(self, level=2):
        """Apply all filter levels up to the given one and flag outliers in self.GDF.

        :param level:   <int> highest filter level to apply
                        (0: no filtering, 1: reliability, 2: + SSIM, 3: + RANSAC)
        :return:    tuple(GeoDataFrame with added outlier columns, list of added column names)
        """
        # TODO catch empty GDF

        if level > 0:
            marked_recs            = GeoSeries(self._reliability_thresholding())
            self.GDF['L1_OUTLIER'] = marked_recs
            self.new_cols.append('L1_OUTLIER')
            if not self.q:
                print('%s tie points flagged by level 1 filtering (reliability).'
                      % np.count_nonzero(marked_recs == True))

        if level > 1:
            marked_recs            = GeoSeries(self._SSIM_filtering())
            self.GDF['L2_OUTLIER'] = marked_recs
            self.new_cols.append('L2_OUTLIER')
            if not self.q:
                print('%s tie points flagged by level 2 filtering (SSIM).'
                      % np.count_nonzero(marked_recs == True))

        if level > 2:
            # BUGFIX: the warning previously told users they enabled RANSAC filtering via
            # 'tieP_filter_level=2', although this branch is only reached with level 3
            warnings.warn(
                "The currently implemented RANSAC outlier detection is still very experimental. You enabled it "
                "by passing 'tieP_filter_level=3' to COREG_LOCAL. Use it on your own risk!")

            marked_recs            = GeoSeries(self._RANSAC_outlier_detection())
            self.GDF['L3_OUTLIER'] = marked_recs
            self.new_cols.append('L3_OUTLIER')
            if not self.q:
                print('%s tie points flagged by level 3 filtering (RANSAC)'
                      % np.count_nonzero(marked_recs == True))

        # a record is an overall outlier if any filter level flagged it
        self.GDF['OUTLIER'] = self.GDF[self.new_cols].any(axis=1)
        self.new_cols.append('OUTLIER')

        return self.GDF, self.new_cols


    def _reliability_thresholding(self, min_reliability=30):
        """Exclude all records where estimated reliability of the calculated shifts is below the given threshold.

        :param min_reliability: <float, int> reliability threshold in percent
        :return:    boolean Series (True = outlier)
        """
        return self.GDF.RELIABILITY < min_reliability


    def _SSIM_filtering(self):
        """Exclude all records where SSIM decreased.

        :return:    boolean Series (True = outlier)
        """
        # NOTE: '== False' (instead of '~') deliberately lets non-boolean values
        #       (e.g. NaN) evaluate to False, so such records are NOT flagged here
        return self.GDF.SSIM_IMPROVED == False


    def _RANSAC_outlier_detection(self, max_outlier_percentage=10, tolerance=2.5, max_iter=15,
                                  exclude_previous_outliers=True, timeout=20):
        """Detect geometric outliers between point cloud of source and estimated coordinates using RANSAC algorithm.

        The RANSAC residual threshold is adapted iteratively (bisection-like) until the
        detected inlier percentage lies within the given tolerance around the target,
        or until max_iter / timeout is reached.

        :param max_outlier_percentage:      <float, int> maximum percentage of outliers to be detected
        :param tolerance:                   <float, int> percentage tolerance for max_outlier_percentage
        :param max_iter:                    <int> maximum iterations for finding the best RANSAC threshold
        :param exclude_previous_outliers:   <bool> whether to exclude points that have been flagged as outlier by
                                            earlier filtering
        :param timeout:                     <float, int> timeout for iteration loop in seconds
        :return:    boolean Series aligned with self.GDF (True = outlier)
        :raises ValueError: if max_outlier_percentage is greater than 100
        """
        GDF = self.GDF[self.GDF[self.new_cols].any(axis=1) == False].copy() \
            if exclude_previous_outliers else self.GDF

        src_coords = np.array(GDF[['X_UTM', 'Y_UTM']])
        xyShift    = np.array(GDF[['X_SHIFT_M', 'Y_SHIFT_M']])
        est_coords = src_coords + xyShift

        for co, n in zip([src_coords, est_coords], ['src_coords', 'est_coords']):
            assert co.ndim == 2 and co.shape[1] == 2, "'%s' must have shape [Nx2]. Got shape %s." % (n, co.shape)

        # BUGFIX: raise with an explanatory message instead of a bare ValueError
        if max_outlier_percentage > 100:
            raise ValueError("'max_outlier_percentage' must not be greater than 100. Got %s."
                             % max_outlier_percentage)
        min_inlier_percentage = 100 - max_outlier_percentage

        class PolyTF_1(PolynomialTransform):
            # currently unused alternative model (1st order polynomial transform)
            def estimate(*data):
                return PolynomialTransform.estimate(*data, order=1)

        # robustly estimate affine transform model with RANSAC;
        # eliminates not more than the given maximum outlier percentage of the tie points

        model_robust, inliers = None, None
        count_inliers         = None
        th                    = 5   # start RANSAC threshold
        th_checked            = {}  # thresholds already tried -> calculated inlier percentage
        th_substract          = 2   # current threshold increment (halved when a value is revisited)
        count_iter            = 0
        time_start            = time.time()
        ideal_count           = min_inlier_percentage * src_coords.shape[0] / 100

        while True:
            if th_checked:
                th_too_strict = count_inliers < ideal_count  # True if too few inliers remain

                # calculate new threshold using old increment
                th_new        = th + th_substract if th_too_strict else th - th_substract

                # check if the calculated new threshold has been used before
                th_already_checked = th_new in th_checked

                # if yes, decrease the increment and recalculate the new threshold
                th_substract       = th_substract if not th_already_checked else th_substract / 2
                th                 = th_new if not th_already_checked else \
                                        (th + th_substract if th_too_strict else th - th_substract)

            # model_robust, inliers = ransac((src, dst), PolynomialTransform, min_samples=3,
            model_robust, inliers = \
                ransac((src_coords, est_coords), AffineTransform,
                       min_samples        = 6,
                       residual_threshold = th,
                       max_trials         = 2000,
                       stop_sample_num    = int((min_inlier_percentage - tolerance) / 100 * src_coords.shape[0]),
                       stop_residuals_sum = int((max_outlier_percentage - tolerance) / 100 * src_coords.shape[0])
                       )
            count_inliers  = np.count_nonzero(inliers)
            th_checked[th] = count_inliers / src_coords.shape[0] * 100

            if min_inlier_percentage - tolerance < th_checked[th] < min_inlier_percentage + tolerance:
                break  # inlier percentage within tolerance -> threshold accepted
            if count_iter > max_iter or time.time() - time_start > timeout:
                break  # keep last values and break while loop

            count_iter += 1

        outliers = inliers == False

        if len(GDF) < len(self.GDF):
            # some records were excluded before RANSAC -> merge the result back into a
            # Series aligned with the full GeoDataFrame
            GDF['outliers'] = outliers
            fullGDF         = GeoDataFrame(self.GDF['POINT_ID'].copy())
            fullGDF         = fullGDF.merge(GDF[['POINT_ID', 'outliers']], on='POINT_ID', how="outer")
            # NOTE(review): records excluded from RANSAC end up as NaN here (neither
            # inlier nor outlier) - TODO confirm whether they should be filled with False
            gs              = fullGDF['outliers']
        else:
            gs              = GeoSeries(outliers)

        assert len(gs) == len(self.GDF), 'RANSAC output validation failed.'

        self.ransac_model_robust = model_robust

        return gs