# io.py
1
2
# -*- coding: utf-8 -*-

3
4
5
6
7
import ctypes
import multiprocessing
import os
import time
import numpy as np
8
9
import ogr
import osr
10
11
12
13
14

try:
    import gdal
except ImportError:
    from osgeo import gdal
15
from spectral.io import envi
16

17
# internal modules
18
from .utilities import get_image_tileborders, convertGdalNumpyDataType
19
20
21
from py_tools_ds.geo.map_info import geotransform2mapinfo
from py_tools_ds.geo.projection import EPSG2WKT
from py_tools_ds.dtypes.conversion import get_dtypeStr
22
23


24
def wait_if_used(path_file, lockfile, timeout=100, try_kill=0):
    """Block until 'path_file' is neither referenced by a GDAL dataset of this module nor protected by 'lockfile'.

    The function polls two conditions: (1) module-level variables of this module that are gdal.Dataset instances
    whose description equals 'path_file' and (2) the existence of 'lockfile'.

    :param path_file: path of the file to wait for (compared against gdal.Dataset.GetDescription())
    :param lockfile:  path of a lock file; waiting continues as long as it exists
    :param timeout:   maximum waiting time in seconds
    :param try_kill:  if true, the module-level dataset references are set to None after the timeout
                      instead of raising directly
    :raises TimeoutError: if the timeout is exceeded and the file could not be freed
    """
    module_globals = globals()
    same_gdalRefs = [name for name, obj in list(module_globals.items())
                     if isinstance(obj, gdal.Dataset) and obj.GetDescription() == path_file]
    t0 = time.time()

    def update_same_gdalRefs(sRs):
        # keep only those references that still point to something
        return [sR for sR in sRs if module_globals.get(sR) is not None]

    while same_gdalRefs or os.path.exists(lockfile):
        # The timeout is checked first so that a stale lock file cannot keep the loop alive forever.
        if time.time() - t0 > timeout:
            if try_kill and same_gdalRefs:
                for sR in same_gdalRefs:
                    module_globals[sR] = None
                    print('had to kill %s' % sR)
            else:
                # remove a possibly stale lock file; it may vanish between the check and the removal
                try:
                    os.remove(lockfile)
                except FileNotFoundError:
                    pass

                raise TimeoutError('The file %s is permanently used by another variable.' % path_file)
        else:
            # avoid burning a full CPU core while waiting
            time.sleep(0.05)

        same_gdalRefs = update_same_gdalRefs(same_gdalRefs)


50
51
52
53
54
55
56
57
def write_envi(arr, outpath, gt=None, prj=None):
    """Write a numpy array to an ENVI file with band sequential (BSQ) interleave via spectral.io.envi.

    :param arr:     2D (rows, cols) or 3D (rows, cols, bands) numpy array to be written
    :param outpath: output path handed to envi.create_image; existing files are overwritten (force=True)
    :param gt:      geotransform used to build the 'map info' header entry (must be given together with prj)
    :param prj:     projection string written to the 'coordinate system string' header entry
                    (must be given together with gt)
    :raises AssertionError: if only one of gt and prj is provided
    :raises ValueError:     if arr is neither 2D nor 3D
    """
    if gt or prj:
        assert gt and prj, 'gt and prj must be provided together or left out.'

    if arr.ndim not in (2, 3):
        raise ValueError('Expected a 2D or 3D array. Got %s dimensions.' % arr.ndim)

    meta = {'map info': geotransform2mapinfo(gt, prj), 'coordinate system string': prj} if gt else None

    # envi.create_image needs a (rows, cols, bands) shape -> add a band axis to 2D input
    arr3d = arr if arr.ndim == 3 else arr[:, :, np.newaxis]
    out = envi.create_image(outpath, metadata=meta, shape=arr3d.shape, dtype=arr.dtype, interleave='bsq',
                            ext='.bsq', force=True)
    out_mm = out.open_memmap(writable=True)
    out_mm[:] = arr3d
60
61


62
def wfa(p, c):  # pragma: no cover
    """Append the string 'c' to the file at path 'p' (best effort).

    Any exception raised while opening or writing is deliberately swallowed, so a failing write never
    interrupts the caller.

    :param p: path of the file to append to (created if it does not exist)
    :param c: text to append
    """
    try:
        with open(p, 'a') as of:
            of.write(c)
    except Exception:
        # intentional: writing is optional and must not raise
        pass
68
69
70


# Module-level array handle: rebound by init_SharedArray_in_globals(), filled by fill_arr() and returned by
# gdal_ReadAsArray_mp().
shared_array = None
71
72


73
74
75
def init_SharedArray_in_globals(dims):
    """Allocate a zero-initialised 2D float64 array in shared memory and bind it to the global 'shared_array'.

    The buffer is a multiprocessing.Array of ctypes.c_double that is wrapped (not copied) by numpy.

    :param dims: (rows, cols) of the array to create; numpy integers are accepted as well
    """
    global shared_array

    # multiprocessing.Array only treats real Python ints as a size -> convert numpy integers explicitly
    rows, cols = (int(dim) for dim in dims)
    shared_array_base = multiprocessing.Array(ctypes.c_double, rows * cols)
    shared_array = np.ctypeslib.as_array(shared_array_base.get_obj()).reshape(rows, cols)
79
80


81
82
def gdal_read_subset(fPath, pos, bandNr):
    """Read a rectangular subset of one band of a GDAL readable image.

    :param fPath:  path of the image file
    :param pos:    ((row_start, row_end), (col_start, col_end)); end indices are inclusive
    :param bandNr: number of the band to read (passed to GetRasterBand)
    :return:       the array returned by ReadAsArray for the requested window
    :raises OSError: if the file cannot be opened by GDAL
    """
    (rS, rE), (cS, cE) = pos
    ds = gdal.Open(fPath)
    if ds is None:
        raise OSError('GDAL could not open %s.' % fPath)

    # ReadAsArray expects (xoff, yoff, xsize, ysize)
    data = ds.GetRasterBand(bandNr).ReadAsArray(cS, rS, cE - cS + 1, rE - rS + 1)
    del ds  # close the dataset

    return data


def fill_arr(argDict, def_param=shared_array):
    """Call a function and write its result into a window of the global 'shared_array'.

    :param argDict:   dictionary with the keys
                        'pos':         ((row_start, row_end), (col_start, col_end)), end indices inclusive
                        'func2call':   callable producing the data for that window
                        'func_args':   positional arguments for the callable (optional)
                        'func_kwargs': keyword arguments for the callable (optional)
    :param def_param: not used within the function; its default is evaluated once at definition time
    :raises KeyError:     if 'pos' or 'func2call' is missing
    :raises RuntimeError: if the global shared array has not been initialised in this process
    """
    if shared_array is None:
        raise RuntimeError('shared_array is not initialised. Call init_SharedArray_in_globals() first.')

    pos = argDict['pos']
    func = argDict['func2call']
    args = argDict.get('func_args', [])
    kwargs = argDict.get('func_kwargs', {})

    (rS, rE), (cS, cE) = pos
    shared_array[rS:rE + 1, cS:cE + 1] = func(*args, **kwargs)
97
98


99
def gdal_ReadAsArray_mp(fPath, bandNr, tilesize=1500):
    """Read one band of a GDAL readable image tile-wise using a multiprocessing pool.

    The tiles are read by gdal_read_subset() and written into the shared array created by
    init_SharedArray_in_globals().

    :param fPath:    path of the image file
    :param bandNr:   number of the band to read
    :param tilesize: edge length of the square tiles handed to get_image_tileborders
    :return:         the global shared array holding the band data
    :raises OSError: if the file cannot be opened by GDAL
    """
    ds = gdal.Open(fPath)
    if ds is None:
        raise OSError('GDAL could not open %s.' % fPath)

    rows, cols = ds.RasterYSize, ds.RasterXSize
    del ds  # only the dimensions are needed here; the workers open the file themselves

    init_SharedArray_in_globals((rows, cols))

    tilepos = get_image_tileborders([tilesize, tilesize], (rows, cols))
    fill_arr_argDicts = [{'pos': pos, 'func2call': gdal_read_subset, 'func_args': (fPath, pos, bandNr)}
                         for pos in tilepos]

    with multiprocessing.Pool() as pool:
        pool.map(fill_arr, fill_arr_argDicts)

    return shared_array


116
117
118
def write_shp(path_out, shapely_geom, prj=None, attrDict=None):
    """Write one or more shapely geometries (and optional attributes) to an ESRI shapefile.

    :param path_out:     output path of the shapefile; an existing file at this path is removed first
    :param shapely_geom: a shapely geometry or a list of shapely geometries sharing the same geometry type
                         (Point, LineString, Polygon or the respective Multi* type)
    :param prj:          projection as WKT string or as integer EPSG code (converted via EPSG2WKT);
                         no spatial reference is assigned if None
    :param attrDict:     a dictionary or a list of dictionaries (one per geometry) holding the attributes;
                         field names and field types are derived from the first dictionary
    :raises AssertionError:      if the numbers of geometries and attribute dictionaries differ, the output
                                 directory does not exist, the geometry types are mixed or a field name is
                                 longer than 10 characters
    :raises NotImplementedError: if the geometry type is not supported
    :raises TypeError:           if an attribute value has a data type without a matching OGR field type
    :raises OSError:             if the datasource or the layer cannot be created
    """
    geoms = shapely_geom if isinstance(shapely_geom, list) else [shapely_geom]

    if attrDict is None:
        attrs = [None] * len(geoms)  # no attributes: one placeholder per geometry
    elif isinstance(attrDict, list):
        attrs = attrDict
    else:
        attrs = [attrDict]

    assert len(geoms) == len(attrs), "'shapely_geom' and 'attrDict' must have the same length."

    # abspath makes sure that a bare file name resolves to the current working directory
    dir_out = os.path.dirname(os.path.abspath(path_out))
    assert os.path.exists(dir_out), 'Directory %s does not exist.' % dir_out

    geom_type = sorted({gm.geom_type for gm in geoms})
    assert len(geom_type) == 1, 'All shapely geometries must belong to the same type. Got %s.' % geom_type

    ogr_geom_types = {
        'Point': ogr.wkbPoint,
        'LineString': ogr.wkbLineString,
        'Polygon': ogr.wkbPolygon,
        'MultiPoint': ogr.wkbMultiPoint,
        'MultiLineString': ogr.wkbMultiLineString,
        'MultiPolygon': ogr.wkbMultiPolygon,
    }
    if geom_type[0] not in ogr_geom_types:
        raise NotImplementedError("Geometries of type '%s' are not supported." % geom_type[0])

    print('Writing %s ...' % path_out)
    if os.path.exists(path_out):
        os.remove(path_out)

    ds = ogr.GetDriverByName("Esri Shapefile").CreateDataSource(path_out)
    if ds is None:
        raise OSError('Could not create the shapefile %s.' % path_out)

    if prj is not None:
        prj = prj if not isinstance(prj, int) else EPSG2WKT(prj)
        srs = osr.SpatialReference()
        srs.ImportFromWkt(prj)
    else:
        srs = None

    layer = ds.CreateLayer('', srs, ogr_geom_types[geom_type[0]])
    if layer is None:
        raise OSError('Could not create a layer within %s.' % path_out)

    # the first attribute dictionary defines the fields of the layer
    field_names = list(attrs[0]) if isinstance(attrs[0], dict) else []
    ogr_field_types = (('int', ogr.OFTInteger), ('float', ogr.OFTReal),
                       ('str', ogr.OFTString), ('date', ogr.OFTDateTime))

    for attr in field_names:
        assert len(attr) <= 10, "ogr does not support fieldnames longer than 10 digits. '%s' is too long" % attr
        DTypeStr = get_dtypeStr(attrs[0][attr])
        FieldType = next((ftype for prefix, ftype in ogr_field_types if DTypeStr.startswith(prefix)), None)
        if FieldType is None:
            raise TypeError("Attribute '%s' has the unsupported data type '%s'." % (attr, DTypeStr))

        FieldDefn = ogr.FieldDefn(attr, FieldType)
        if DTypeStr.startswith('float'):
            FieldDefn.SetPrecision(6)
        layer.CreateField(FieldDefn)  # add one attribute field

    layer_defn = layer.GetLayerDefn()
    for geom, attr_vals in zip(geoms, attrs):
        # create a new feature (attribute and geometry)
        feat = ogr.Feature(layer_defn)
        feat.SetGeometry(ogr.CreateGeometryFromWkb(geom.wkb))  # make an OGR geometry from the shapely object

        for attr in field_names:
            val = attr_vals[attr]
            DTypeStr = get_dtypeStr(val)
            # convert to plain Python types before handing the value to OGR
            if DTypeStr.startswith('int'):
                val = int(val)
            elif DTypeStr.startswith('float'):
                val = float(val)
            elif DTypeStr.startswith('str'):
                val = str(val)
            feat.SetField(attr, val)

        layer.CreateFeature(feat)
        feat.Destroy()

    # save and close everything (release the layer before its datasource)
    del layer, ds
177
178


179
180
181
182
def write_numpy_to_image(array, path_out, outFmt='GTIFF', gt=None, prj=None):
    """Write a numpy array to a raster image using GDAL.

    :param array:    2D (rows, cols) or 3D (rows, cols, bands) numpy array
    :param path_out: output path of the image
    :param outFmt:   GDAL driver name of the output format
    :param gt:       geotransform to be set on the output dataset (skipped if empty or None)
    :param prj:      projection to be set on the output dataset (skipped if empty or None)
    :raises ValueError: if the array is neither 2D nor 3D or the GDAL driver is unknown
    :raises OSError:    if the output dataset cannot be created
    """
    if array.ndim == 2:
        (rows, cols), bands = array.shape, 1
    elif array.ndim == 3:
        rows, cols, bands = array.shape
    else:
        raise ValueError('Expected a 2D or 3D array. Got %s dimensions.' % array.ndim)

    gdal_dtype = gdal.GetDataTypeByName(convertGdalNumpyDataType(array.dtype))

    driver = gdal.GetDriverByName(outFmt)
    if driver is None:
        raise ValueError("Unknown GDAL driver '%s'." % outFmt)

    outDs = driver.Create(path_out, cols, rows, bands, gdal_dtype)
    if outDs is None:
        raise OSError('Could not create the output image %s.' % path_out)

    for b in range(bands):
        band = outDs.GetRasterBand(b + 1)  # GDAL band numbers start at 1
        band.WriteArray(array if array.ndim == 2 else array[:, :, b])
        del band

    if gt:
        outDs.SetGeoTransform(gt)
    if prj:
        outDs.SetProjection(prj)

    del outDs  # close the dataset
193
194


195
# def get_tempfile(ext=None,prefix=None,tgt_dir=None):
#    """Returns the path to a tempfile.mkstemp() file that can be passed to any function that expects a physical path.
#    The tempfile has to be deleted manually.
#    :param ext:     file extension (no suffix if None)
#    :param prefix:  optional file prefix
#    :param tgt_dir: target directory (automatically set if None)
#     """
#    prefix   = 'danschef__CoReg__' if prefix is None else prefix
#    fd, path = tempfile.mkstemp(prefix=prefix,suffix=ext,dir=tgt_dir)
#    os.close(fd)
#    return path


# Module-level memmap handle: set by init_SharedArray_on_disk() and written to by fill_arr_on_disk().
shared_array_on_disk__memmap = None
209
210
211


def init_SharedArray_on_disk(out_path, dims, gt=None, prj=None, dtype='uint16'):
    """Create an ENVI BSQ image on disk and bind its writable memmap to 'shared_array_on_disk__memmap'.

    :param out_path: output path; if it ends with '.bsq' the corresponding '.hdr' path is handed to
                     envi.create_image, otherwise the path is used as given
    :param dims:     (rows, cols) or (rows, cols, bands) of the image; a single band is assumed for 2 values
    :param gt:       geotransform used to build the 'map info' header entry (only used together with prj)
    :param prj:      projection string written to the 'coordinate system string' header entry
                     (only used together with gt)
    :param dtype:    data type of the image on disk
    :raises ValueError: if dims does not contain 2 or 3 values
    """
    global shared_array_on_disk__memmap

    base, ext = os.path.splitext(out_path)
    path = base + '.hdr' if ext == '.bsq' else out_path

    # envi.create_image needs rows, cols AND bands
    dims = tuple(dims)
    if len(dims) == 2:
        shape = dims + (1,)
    elif len(dims) == 3:
        shape = dims
    else:
        raise ValueError("'dims' must contain 2 or 3 values. Got %s." % len(dims))

    Meta = {}
    if gt and prj:
        Meta['map info'] = geotransform2mapinfo(gt, prj)
        Meta['coordinate system string'] = prj

    shared_array_on_disk__obj = envi.create_image(path, metadata=Meta, shape=shape, dtype=dtype,
                                                  interleave='bsq', ext='.bsq', force=True)
    shared_array_on_disk__memmap = shared_array_on_disk__obj.open_memmap(writable=True)


225
def fill_arr_on_disk(argDict):
    """Read an image window with GDAL and write it into band 0 of the global on-disk memmap.

    :param argDict: dictionary with the keys
                      'pos':     ((row_start, row_end), (col_start, col_end)), end indices inclusive
                      'in_path': path of the input image
                      'band':    number of the band to read (passed to GetRasterBand)
    :raises KeyError:     if one of the keys is missing
    :raises RuntimeError: if the global memmap has not been initialised in this process
    :raises OSError:      if the input image cannot be opened by GDAL
    """
    if shared_array_on_disk__memmap is None:
        raise RuntimeError('shared_array_on_disk__memmap is not initialised. '
                           'Call init_SharedArray_on_disk() first.')

    pos = argDict['pos']
    in_path = argDict['in_path']
    band_nr = argDict['band']

    (rS, rE), (cS, cE) = pos
    ds = gdal.Open(in_path)
    if ds is None:
        raise OSError('GDAL could not open %s.' % in_path)

    # ReadAsArray expects (xoff, yoff, xsize, ysize)
    band = ds.GetRasterBand(band_nr)
    data = band.ReadAsArray(cS, rS, cE - cS + 1, rE - rS + 1)
    del band, ds  # release the band before its dataset

    shared_array_on_disk__memmap[rS:rE + 1, cS:cE + 1, 0] = data
236
237


238
def convert_gdal_to_bsq__mp(in_path, out_path, band=1):
    """Convert one band of a GDAL readable image to an ENVI BSQ file using a multiprocessing pool.

    The output image is created by init_SharedArray_on_disk() and filled in tiles of 512 x 512 pixels by
    fill_arr_on_disk(). Geotransform and projection are taken over from the input image.

    Usage:
        ref_ds,tgt_ds = gdal.Open(self.path_imref),gdal.Open(self.path_im2shift)
        ref_pathTmp, tgt_pathTmp = None,None
        if ref_ds.GetDriver().ShortName!='ENVI':
            ref_pathTmp = IO.get_tempfile(ext='.bsq')
            IO.convert_gdal_to_bsq__mp(self.path_imref,ref_pathTmp)
            self.path_imref = ref_pathTmp
        if tgt_ds.GetDriver().ShortName!='ENVI':
            tgt_pathTmp = IO.get_tempfile(ext='.bsq')
            IO.convert_gdal_to_bsq__mp(self.path_im2shift,tgt_pathTmp)
            self.path_im2shift = tgt_pathTmp
        ref_ds=tgt_ds=None

    :param in_path:  path of the input image
    :param out_path: path of the output image
    :param band:     number of the band to convert (passed to GetRasterBand)
    :return:         None
    :raises OSError: if the input image cannot be opened by GDAL
    """
    ds = gdal.Open(in_path)
    if ds is None:
        raise OSError('GDAL could not open %s.' % in_path)

    rows, cols = ds.RasterYSize, ds.RasterXSize
    gt, prj = ds.GetGeoTransform(), ds.GetProjection()
    del ds  # the workers open the file themselves

    # the ENVI image needs rows, cols and bands whereas the tiling only refers to rows and cols
    init_SharedArray_on_disk(out_path, (rows, cols, 1), gt, prj)
    positions = get_image_tileborders([512, 512], (rows, cols))
    argDicts = [{'pos': pos, 'in_path': in_path, 'band': band} for pos in positions]

    with multiprocessing.Pool() as pool:
        pool.map(fill_arr_on_disk, argDicts)