DeShifter.py 19 KB
Newer Older
1
2
3
4
5
6
7
8
# -*- coding: utf-8 -*-
__author__='Daniel Scheffler'

import collections
import time
import warnings

# custom
9
10
11
12
try:
    import gdal
except ImportError:
    from osgeo import gdal
13
14
15
16
17
18
19
20
21

# internal modules
from py_tools_ds.ptds                      import GeoArray
from py_tools_ds.ptds.geo.map_info         import mapinfo2geotransform, geotransform2mapinfo
from py_tools_ds.ptds.geo.coord_grid       import is_coord_grid_equal
from py_tools_ds.ptds.geo.projection       import prj_equal
from py_tools_ds.ptds.geo.raster.reproject import warp_ndarray
from py_tools_ds.ptds.numeric.vector       import find_nearest

22
23
_dict_rspAlg_rsp_Int = {'nearest': 0, 'bilinear': 1, 'cubic': 2, 'cubic_spline': 3, 'lanczos': 4, 'average': 5,
                        'mode': 6, 'max': 7, 'min': 8 , 'med': 9, 'q1':10, 'q2':11}
24
25

class DESHIFTER(object):
    """Apply previously computed co-registration results to an image.

    Depending on the given co-registration info, the image is either only re-georeferenced (global shift without
    resampling) or warped (resampling and/or GCP based correction). See the __init__ docstring for all parameters.
    """

    def __init__(self, im2shift, coreg_results, **kwargs):
        """
        Deshift an image array or one of its products by applying the coregistration info calculated by COREG class.

        :param im2shift:            <path,GeoArray> path of an image to be de-shifted or alternatively a GeoArray object
        :param coreg_results:       <dict> the results of the co-registration as given by COREG.coreg_info or
                                    COREG_LOCAL.coreg_info respectively

        :Keyword Arguments:
            - path_out(str):        /output/directory/filename for coregistered results
            - fmt_out (str):        raster file format for output file. ignored if path_out is None. can be any GDAL
                                    compatible raster file format (e.g. 'ENVI', 'GTiff'; default: ENVI)
            - band2process (int):   The index of the band to be processed within the given array (starts with 1),
                                    default = None (all bands are processed)
            - nodata(int, float):   no data value of an image to be de-shifted
            - out_gsd (float):      output pixel size in units of the reference coordinate system (default = pixel size
                                    of the input array). May be a single number or an [x, y] pair. If given,
                                    match_gsd is ignored.
            - align_grids (bool):   True: align the input coordinate grid to the reference (does not affect the
                                    output pixel size as long as input and output pixel sizes are compatible
                                    (5:30 or 10:30 but not 4:30), default = False
            - match_gsd (bool):     True: match the input pixel size to the reference pixel size,
                                    default = False
            - target_xyGrid(list):  a list with an x-grid and a y-grid like [[15,45], [15,45]]
            - resamp_alg(str)       the resampling algorithm to be used if necessary
                                    (valid algorithms: nearest, bilinear, cubic, cubic_spline, lanczos, average, mode,
                                                       max, min, med, q1, q3)
            - cliptoextent (bool):  True: clip the input image to its actual bounds while deleting possible no data
                                    areas outside of the actual bounds, default = True
            - clipextent (list):    xmin, ymin, xmax, ymax - if given the calculation of the actual bounds is skipped.
                                    The given coordinates are automatically snapped to the output grid.
            - CPUs(int):            number of CPUs to use (default: None, which means 'all CPUs available')
            - progress(bool):       show progress bars (default: True)
            - v(bool):              verbose mode (default: False)
            - q(bool):              quiet mode (default: False)

        """
        # unpack args
        self.im2shift  = im2shift if isinstance(im2shift, GeoArray) else GeoArray(im2shift)
        self.shift_prj = self.im2shift.projection
        self.shift_gt  = list(self.im2shift.geotransform)
        self.GCPList   = coreg_results.get('GCPList')

        if not self.GCPList:
            # global shift: the shifted geocoding is given by the updated map info of the co-registration
            mapI                   = coreg_results['updated map info']
            self.updated_map_info  = mapI if mapI else geotransform2mapinfo(self.shift_gt, self.shift_prj)
            self.original_map_info = coreg_results['original map info']
            # copy, so that updated_gt never aliases shift_gt
            self.updated_gt        = mapinfo2geotransform(self.updated_map_info) if mapI else list(self.shift_gt)
        else:
            # GCP based correction: the output geocoding is only known after warping (set by correct_shifts)
            self.updated_map_info  = None
            self.original_map_info = coreg_results.get('original map info')
            self.updated_gt        = None

        self.ref_gt             = coreg_results['reference geotransform']
        self.ref_grid           = coreg_results['reference grid']
        self.ref_prj            = coreg_results['reference projection']
        self.updated_projection = self.ref_prj

        # unpack kwargs
        self.path_out     = kwargs.get('path_out', None)
        self.fmt_out      = kwargs.get('fmt_out', 'ENVI')
        self.band2process = kwargs.get('band2process', None)  # starts with 1 # FIXME why?
        self.nodata       = kwargs.get('nodata', self.im2shift.nodata)
        self.align_grids  = kwargs.get('align_grids', False)
        self.rspAlg       = kwargs.get('resamp_alg', 'cubic')
        self.cliptoextent = kwargs.get('cliptoextent', True)
        self.clipextent   = kwargs.get('clipextent', None)
        self.CPUs         = kwargs.get('CPUs', None)
        self.v            = kwargs.get('v', False)
        self.q            = kwargs.get('q', False) if not self.v else False  # overridden by v
        self.progress     = kwargs.get('progress', True) if not self.q else False  # overridden by q
        self.out_grid     = self._get_out_grid(kwargs)  # needs self.ref_grid, self.ref_gt, self.im2shift
        self.out_gsd      = [abs(self.out_grid[0][1] - self.out_grid[0][0]),  # xgsd
                             abs(self.out_grid[1][1] - self.out_grid[1][0])]  # ygsd

        # assertions
        assert self.rspAlg in _dict_rspAlg_rsp_Int.keys(), \
            "'%s' is not a supported resampling algorithm." % self.rspAlg

        # set defaults for general class attributes
        self.is_shifted       = False  # this is not included in COREG.coreg_info
        self.is_resampled     = False  # this is not included in COREG.coreg_info
        self.tracked_errors   = []
        self.arr_shifted      = None  # set by self.correct_shifts
        self.GeoArray_shifted = None  # set by self.correct_shifts

    def _get_out_grid(self, init_kwargs):
        """Return the output coordinate grid as [[x0, x1], [y0, y1]].

        Precedence: 'target_xyGrid' > 'out_gsd' > 'match_gsd' > pixel size of the input image.
        Reads self.ref_grid, self.ref_gt, self.align_grids and self.im2shift.

        :param init_kwargs: the keyword arguments given to __init__
        """
        # parse given params
        out_gsd   = init_kwargs.get('out_gsd', None)
        match_gsd = init_kwargs.get('match_gsd', False)
        out_grid  = init_kwargs.get('target_xyGrid', None)

        # assertions (a scalar out_gsd is allowed and applies to both axes)
        assert out_grid is None or (isinstance(out_grid, (list, tuple)) and len(out_grid) == 2)
        assert out_gsd is None or isinstance(out_gsd, (int, float)) or \
            (isinstance(out_gsd, (tuple, list)) and len(out_gsd) == 2)

        # pixel sizes of the reference grid; abs() because the y-grid is descending (negative step)
        ref_xgsd = abs(self.ref_grid[0][1] - self.ref_grid[0][0])
        ref_ygsd = abs(self.ref_grid[1][1] - self.ref_grid[1][0])

        def get_grid(gt, xgsd, ygsd):
            # x-grid ascending, y-grid descending, starting at the origin of the given geotransform
            return [[gt[0], gt[0] + xgsd], [gt[3], gt[3] - ygsd]]

        # get out_grid
        if out_grid:
            # output grid is given
            return out_grid

        elif out_gsd:
            out_xgsd, out_ygsd = [out_gsd, out_gsd] if isinstance(out_gsd, (int, float)) else out_gsd

            if match_gsd and (out_xgsd, out_ygsd) != (ref_xgsd, ref_ygsd):
                warnings.warn("\nThe parameter 'match_gsd' is ignored because another output ground sampling "
                              "distance was explicitly given.")
            if self.align_grids and self._grids_alignable(self.im2shift.xgsd, self.im2shift.ygsd, out_xgsd, out_ygsd):
                # use grid of reference image with the given output gsd
                return get_grid(self.ref_gt, out_xgsd, out_ygsd)
            else:  # no grid alignment
                # use grid of input image with the given output gsd
                return get_grid(self.im2shift.geotransform, out_xgsd, out_ygsd)

        elif match_gsd:
            if self.align_grids:
                # use reference grid
                return self.ref_grid
            else:
                # use grid of input image and reference gsd
                return get_grid(self.im2shift.geotransform, ref_xgsd, ref_ygsd)

        else:
            if self.align_grids and self._grids_alignable(self.im2shift.xgsd, self.im2shift.ygsd, ref_xgsd, ref_ygsd):
                # use origin of reference image and gsd of input image
                return get_grid(self.ref_gt, self.im2shift.xgsd, self.im2shift.ygsd)
            else:
                # use input image grid
                return get_grid(self.im2shift.geotransform, self.im2shift.xgsd, self.im2shift.ygsd)

    @staticmethod
    def _grids_alignable(in_xgsd, in_ygsd, out_xgsd, out_ygsd):
        """Return True if input and output pixel sizes are exact multiples of each other in X and Y, else warn."""
        is_alignable = lambda gsd1, gsd2: max(gsd1, gsd2) % min(gsd1, gsd2) == 0  # checks if pixel sizes are divisible
        if not is_alignable(in_xgsd, out_xgsd) or not is_alignable(in_ygsd, out_ygsd):
            warnings.warn("\nThe targeted output coordinate grid is not alignable with the image to be shifted because "
                          "their pixel sizes are not exact multiples of each other (input [X/Y]: "
                          "%s %s; output [X/Y]: %s %s). Therefore the targeted output grid is "
                          "chosen for the resampled output image. If you don´t like that you can use the '-out_gsd' "
                          "parameter to set an appropriate output pixel size.\n"
                          % (in_xgsd, in_ygsd, out_xgsd, out_ygsd))
            return False
        else:
            return True

    def _get_out_extent(self):
        """Return the output extent (xmin, ymin, xmax, ymax) snapped to self.out_grid.

        A user-given self.clipextent is used as is. Otherwise it is derived (and cached in self.clipextent) from the
        image footprint (cliptoextent=True) or from the outer image bounds (cliptoextent=False).
        """
        if self.clipextent is None:
            if self.cliptoextent:
                # actual image bounds (excludes no data areas outside of the footprint)
                self.clipextent = self.im2shift.footprint_poly.bounds
            else:
                # outer bounds of the image array
                xmin, xmax, ymin, ymax = self.im2shift.box.boundsMap
                self.clipextent = xmin, ymin, xmax, ymax

        # snap clipextent to output grid (in case of odd input coords the output coords are moved INSIDE the input array)
        xgrid, ygrid = self.out_grid
        xmin, ymin, xmax, ymax = self.clipextent
        xmin = find_nearest(xgrid, xmin, roundAlg='on', extrapolate=True)
        ymin = find_nearest(ygrid, ymin, roundAlg='on', extrapolate=True)
        xmax = find_nearest(xgrid, xmax, roundAlg='off', extrapolate=True)
        ymax = find_nearest(ygrid, ymax, roundAlg='off', extrapolate=True)
        return xmin, ymin, xmax, ymax

    def correct_shifts(self):
        # type: (DESHIFTER) -> collections.OrderedDict
        """Apply the co-registration results to self.im2shift and return self.deshift_results.

        Without GCPs, with equal projections, with an unchanged coordinate grid and without grid alignment, only the
        geocoding is updated (no resampling). Otherwise the image is warped via warp_ndarray. If path_out was given,
        the result is written to disk.
        """
        if not self.q:
            print('Correcting geometric shifts...')

        t_start   = time.time()
        equal_prj = prj_equal(self.ref_prj, self.shift_prj)

        if equal_prj and not self.GCPList and is_coord_grid_equal(self.shift_gt, *self.out_grid) \
                and not self.align_grids:
            # FIXME buggy condition:
            # reconstructable with correct_spatial_shifts from GMS
            # DS = DESHIFTER(geoArr, self.coreg_info,
            #                target_xyGrid=[usecase.spatial_ref_gridx, usecase.spatial_ref_gridy],
            #                cliptoextent=True, clipextent=mapBounds, align_grids=False) => align grids False

            # NO RESAMPLING NEEDED
            self.is_shifted   = True
            self.is_resampled = False
            xmin, ymin, xmax, ymax = self._get_out_extent()

            if self.cliptoextent:  # TODO validate results!
                # get shifted array
                shifted_geoArr = GeoArray(self.im2shift[:], tuple(self.updated_gt), self.shift_prj)

                # clip with target extent
                self.arr_shifted, self.updated_gt, self.updated_projection = \
                    shifted_geoArr.get_mapPos((xmin, ymin, xmax, ymax), self.shift_prj, fillVal=self.nodata)
                self.updated_map_info = geotransform2mapinfo(self.updated_gt, self.updated_projection)
            else:
                # array stays the same; updated gt and prj are taken from coreg_info
                self.arr_shifted = self.im2shift[:]

            # the shifted image carries the UPDATED geotransform (not the one of the input image)
            self.GeoArray_shifted = GeoArray(self.arr_shifted, tuple(self.updated_gt), self.updated_projection)

            if self.path_out:
                self.GeoArray_shifted.save(self.path_out, fmt=self.fmt_out)

        else:  # FIXME the case equal_prj == False is NOT fully implemented yet
            # RESAMPLING NEEDED
            # (a former, untested gdalwarp command line variant of this branch has been removed)

            # NOTE(review): band2process is documented as 1-based; confirm that GeoArray integer indexing matches.
            in_arr = self.im2shift[self.band2process] if self.band2process else self.im2shift[:]

            # work on a copy so that self.shift_gt is not modified if warping fails
            in_gt = list(self.shift_gt)
            if not self.GCPList:
                # apply XY-shifts by moving the origin of the input geotransform to the updated position
                in_gt[0], in_gt[3] = self.updated_gt[0], self.updated_gt[3]

            # get resampled array
            out_arr, out_gt, out_prj = \
                warp_ndarray(in_arr, in_gt, self.shift_prj, self.ref_prj,
                             rspAlg     = _dict_rspAlg_rsp_Int[self.rspAlg],
                             in_nodata  = self.nodata,
                             out_nodata = self.nodata,
                             out_gsd    = self.out_gsd,
                             out_bounds = self._get_out_extent(),
                             gcpList    = self.GCPList,
                             CPUs       = self.CPUs,
                             progress   = self.progress,
                             q          = self.q)

            self.updated_projection = out_prj
            self.arr_shifted        = out_arr
            self.updated_map_info   = geotransform2mapinfo(out_gt, out_prj)
            self.shift_gt           = mapinfo2geotransform(self.updated_map_info)
            self.updated_gt         = self.shift_gt
            self.GeoArray_shifted   = GeoArray(self.arr_shifted, tuple(self.shift_gt), self.updated_projection)
            self.is_shifted         = True
            self.is_resampled       = True

            if self.path_out:
                GeoArray(out_arr, out_gt, out_prj).save(self.path_out, fmt=self.fmt_out)

        if self.v:
            print('Time for shift correction: %.2fs' % (time.time() - t_start))

        return self.deshift_results

    @property
    def deshift_results(self):
        """Return the current de-shifting results as an OrderedDict (filled by correct_shifts)."""
        return collections.OrderedDict([
            ('band',                 self.band2process),
            ('is shifted',           self.is_shifted),
            ('is resampled',         self.is_resampled),
            ('updated map info',     self.updated_map_info),
            ('updated geotransform', self.updated_gt),
            ('updated projection',   self.updated_projection),
            ('arr_shifted',          self.arr_shifted),
            ('GeoArray_shifted',     self.GeoArray_shifted),
        ])



def deshift_image_using_coreg_info(im2shift, coreg_results, path_out=None, fmt_out='ENVI', q=False):
    """Correct a geometrically distorted image using previously calculated co-registration info.

    This function can be used, for example, to correct spatial shifts of mask files using the same transformation
    parameters that have been used to correct their source images.

    :param im2shift:      <path,GeoArray> path of an image to be de-shifted or alternatively a GeoArray object
    :param coreg_results: <dict> the results of the co-registration as given by COREG.coreg_info or
                          COREG_LOCAL.coreg_info respectively
    :param path_out:      /output/directory/filename for coregistered results. If None, no output is written - only
                          the shift corrected results are returned.
    :param fmt_out:       raster file format for output file. ignored if path_out is None. can be any GDAL
                          compatible raster file format (e.g. 'ENVI', 'GTiff'; default: ENVI)
    :param q:             quiet mode (default: False)
    :return:              the OrderedDict returned by DESHIFTER.correct_shifts() (keys: 'band', 'is shifted',
                          'is resampled', 'updated map info', 'updated geotransform', 'updated projection',
                          'arr_shifted', 'GeoArray_shifted')
    """
    # DESHIFTER handles quiet mode and writes the output itself (GeoArray.save(path, fmt=...)), so both are
    # delegated to it instead of saving here with a differently named keyword argument.
    return DESHIFTER(im2shift, coreg_results, path_out=path_out, fmt_out=fmt_out, q=q).correct_shifts()