Commit bcd35c46 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Bugfix for footprint_poly; revised progress bar handling; bugfix for not...

Bugfix for footprint_poly; revised progress bar handling; bugfix for not resetting timeout start time; some further developments

compatibility.python.exceptions:
- FileNotFoundError: revised docstring

compatibility.gdal:
- refactored TranslateOptions to Translate

geo.raster.conversion:
- raster2polygon(): updated calls for progress bar and timeout

geo.raster.reproject:
- warp_ndarray(): updated calls for progress bar and timeout

geo.vector.topology:
- added fill_holes_within_poly()

io.raster.gdal:
- get_GDAL_ds_inmem(): implemented keyword 'nodata'

io.raster.GeoArray:
- GeoArray:
    - added many docstrings
    - projection.setter: revised assertion
    - footprint_poly(): bugfix for not consistently returning shapely.geometry.Polygon instances
    - save(): implemented keyword 'creationOptions': allows to pass creation options to GDAL writer

processing.progress_mon:
- replaced function is_timed_out() with new class 'Timer'
- replaced function printProgress() with new class 'ProgressBar'
parent 7f161116
......@@ -194,7 +194,7 @@ def Warp(destNameOrDestDS, srcDSOrSrcDSTab, options = '', format = 'GTiff',
return mem_ds
def TranslateOptions(destNameOrDestDS, srcDSOrSrcDSTab, options = '', format = 'GTiff',
def Translate(destNameOrDestDS, srcDSOrSrcDSTab, options = '', format = 'GTiff',
outputType = gdal.GDT_Unknown, bandList = None, maskBand = None,
width = 0, height = 0, widthPct = 0.0, heightPct = 0.0,
xRes = 0.0, yRes = 0.0,
......
......@@ -2,10 +2,11 @@
__author__ = "Daniel Scheffler"
class TimeoutError(OSError):
    """ Timeout expired.

    OSError subclass defined in the compatibility package; other modules shown in this commit import it
    under the alias TimeoutError_comp.
    NOTE(review): presumably a stand-in for interpreters lacking the built-in TimeoutError - confirm.
    """
    pass
class FileNotFoundError(OSError):
""" Timeout expired. """
pass
\ No newline at end of file
""" File not found. """
pass
......@@ -15,7 +15,7 @@ except ImportError:
from osgeo import osr
from ...io.raster.gdal import get_GDAL_ds_inmem
from ...processing.progress_mon import printProgress, is_timed_out
from ...processing.progress_mon import ProgressBar, Timer
from ...compatibility.python.exceptions import TimeoutError as TimeoutError_comp
......@@ -54,11 +54,8 @@ def raster2polygon(array_or_GeoArray, gt=None, prj=None, DN2extract=1, exact=Tru
mem_layer.CreateField(fd)
# set callback
progressBar = lambda percent01, message, user_data: \
printProgress(percent01 * 100, prefix='Polygonize progress ', suffix='Complete', barLength=50, timeout=timeout)
timeout_callback = lambda percent01, message, user_data: is_timed_out(timeout)
callback = progressBar if progress and not q else timeout_callback if timeout else None
callback = ProgressBar(prefix='Polygonize progress ', suffix='Complete', barLength=50, timeout=timeout) \
if progress and not q else Timer(timeout) if timeout else None
# run the algorithm
result = gdal.Polygonize(src_band, src_band.GetMaskBand(), mem_layer, 0, ["8CONNECTED=8"] if exact else [],
......@@ -100,8 +97,9 @@ def raster2polygon(array_or_GeoArray, gt=None, prj=None, DN2extract=1, exact=Tru
#GDF['geometry'] = GDF.apply(get_shplyPoly, axis=1)
GDF = GeoDataFrame(columns=['geometry', 'DN'])
timer = Timer(timeout)
for i in range(featCount):
if not is_timed_out(timeout):
if not timer.timed_out:
element = mem_layer.GetNextFeature()
GDF.loc[i] = [loads(element.GetGeometryRef().ExportToWkb()).buffer(0), DN2extract]
element = None
......
......@@ -21,7 +21,7 @@ from rasterio.warp import Resampling
from ..projection import WKT2EPSG, isProjectedOrGeographic, prj_equal
from ..coord_trafo import pixelToLatLon
from ...io.raster.gdal import get_GDAL_ds_inmem
from ...processing.progress_mon import printProgress
from ...processing.progress_mon import ProgressBar
from ...compatibility.gdal import get_gdal_func
......@@ -282,10 +282,6 @@ def warp_ndarray(ndarray, in_gt, in_prj, out_prj=None, out_dtype=None, out_gsd=(
'EPSG:%s'%prjArg if isinstance(prjArg,int) else \
prjArg
get_GDT = lambda DT: dTypeDic_NumPy2GDAL[str(np.dtype(DT))]
progressBarTran = (lambda percent01, message, user_data: printProgress(percent01 * 100,
**{'prefix': 'Translating progress', 'suffix': 'Complete', 'barLength': 50})) if progress and not q else None
progressBarWarp = (lambda percent01, message, user_data: printProgress(percent01 * 100,
**{'prefix': 'Warping progress ', 'suffix': 'Complete', 'barLength': 50})) if progress and not q else None
# not yet implemented
cutlineDSName = 'data/cutline.vrt' #'/vsimem/cutline.shp' TODO cutline from OGR datasource. => implement input shapefile or Geopandas dataframe
......@@ -377,7 +373,7 @@ def warp_ndarray(ndarray, in_gt, in_prj, out_prj=None, out_dtype=None, out_gsd=(
'', in_ds, format='MEM',
outputSRS = get_SRS(out_prj),
GCPs = gcpList,
callback = progressBarTran
callback = ProgressBar(prefix='Translating progress', timeout=None) if progress and not q else None
)
# NOTE: options = ['SPARSE_OK=YES'] ## => what is that for?
......@@ -405,7 +401,7 @@ def warp_ndarray(ndarray, in_gt, in_prj, out_prj=None, out_dtype=None, out_gsd=(
tps = True if gcpList else False,
polynomialOrder = polynomialOrder,
warpMemoryLimit = warpMemoryLimit,
callback = progressBarWarp,
callback = ProgressBar(prefix='Warping progress ', timeout=None) if progress and not q else None,
callback_data = [0],
errorThreshold = 0.125, # this is needed to get exactly the same output like the console version of GDAL warp
)
......
......@@ -106,4 +106,26 @@ def polyVertices_outside_poly(inner_poly, outer_poly):
return False in GDF.apply(lambda GDF_row: Point(GDF_row.X, GDF_row.Y).intersects(outer_poly), axis=1).values
else:
# inner_poly does not intersect out_poly -> all vertices are outside
return True
\ No newline at end of file
return True
def fill_holes_within_poly(poly):
    """Fill the holes within a shapely Polygon or MultiPolygon and return a Polygon with only the outer boundary.

    For a Polygon the interior rings are dropped. For a MultiPolygon the member whose area, measured
    WITHOUT its holes (i.e. the area enclosed by its exterior ring), is largest is selected and returned
    as a hole-free Polygon.

    :param poly:    <shapely.geometry.Polygon, shapely.geometry.MultiPolygon>
    :return:        <shapely.geometry.Polygon> without interior rings. Any other geometry type falls
                    through and yields None (same as before).
    """
    if poly.geom_type == 'Polygon':
        if len(poly.interiors) == 0:
            # nothing to fill (this also covers empty polygons) -> hand back the input untouched
            return poly
        # rebuild from the exterior ring only; returning `poly` itself would keep its holes
        return Polygon(poly.exterior.coords)

    elif poly.geom_type == 'MultiPolygon':
        # hole-free version of every member polygon of the multipolygon
        filled_parts = [Polygon(part.exterior.coords) for part in poly.geoms]
        # max() returns the first maximum, which matches the previous idxmax()-based selection
        return max(filled_parts, key=lambda part: part.area)
\ No newline at end of file
......@@ -28,12 +28,14 @@ from ...geo.coord_grid import snap_bounds_to_pixGrid
from ...geo.coord_trafo import mapXY2imXY, imXY2mapXY, transform_any_prj, reproject_shapelyGeometry
from ...geo.projection import prj_equal, WKT2EPSG, EPSG2WKT
from ...geo.raster.conversion import raster2polygon
from ...geo.vector.topology import get_overlap_polygon, get_footprint_polygon, polyVertices_outside_poly
from ...geo.vector.topology import get_overlap_polygon, get_footprint_polygon, polyVertices_outside_poly, \
fill_holes_within_poly
from ...geo.vector.geometry import boxObj
from ...io.raster.gdal import get_GDAL_ds_inmem
from ...numeric.array import find_noDataVal, get_outFillZeroSaturated
from ...compatibility.python.exceptions import TimeoutError as TimeoutError_comp, \
FileNotFoundError as FileNotFoundError_comp
from ...compatibility.gdal import get_gdal_func
def _alias_property(key):
......@@ -121,11 +123,13 @@ class GeoArray(object):
@property
def is_inmem(self):
    """Check if associated image array is completely loaded into memory.

    :return:    <bool> True if self.arr is a numpy.ndarray, otherwise False
    """
    return isinstance(self.arr, np.ndarray)
@property
def shape(self):
"""Get the array shape of the associated image array."""
if self._shape:
return self._shape
else:
......@@ -135,16 +139,19 @@ class GeoArray(object):
@property
def ndim(self):
    """Get the number of dimensions of the associated image array (the length of self.shape)."""
    return len(self.shape)
@property
def rows(self):
    """Get the number of rows of the associated image array (first element of self.shape)."""
    return self.shape[0]
@property
def columns(self):
    """Get the number of columns of the associated image array (second element of self.shape)."""
    return self.shape[1]
......@@ -153,11 +160,13 @@ class GeoArray(object):
@property
def bands(self):
    """Get the number of bands of the associated image array.

    :return:    the third element of self.shape if the shape has more than two entries, otherwise 1
    """
    return self.shape[2] if len(self.shape)>2 else 1
@property
def dtype(self):
"""Get the numpy data type of the associated image array."""
if self._dtype:
return self._dtype
elif self.is_inmem:
......@@ -169,6 +178,7 @@ class GeoArray(object):
@property
def geotransform(self):
"""Get the GDAL GeoTransform of the associated image."""
if self._geotransform:
return self._geotransform
elif not self.is_inmem:
......@@ -191,16 +201,20 @@ class GeoArray(object):
@property
def xgsd(self):
    """Get the X resolution in units of the given or detected projection (element 1 of the geotransform)."""
    return self.geotransform[1]
@property
def ygsd(self):
    """Get the Y resolution in units of the given or detected projection.

    :return:    the absolute value of element 5 of the geotransform, so the result is never negative
    """
    return abs(self.geotransform[5])
@property
def projection(self):
"""Get the projection of the associated image. Setting the projection is only allowed if GeoArray has been
instanced from memory or the associated file on disk has no projection."""
if self._projection:
return self._projection
elif not self.is_inmem:
......@@ -213,7 +227,7 @@ class GeoArray(object):
@projection.setter
def projection(self, prj):
if self.filePath:
assert self.projection == prj, "Cannot set %s.projection to the given value because it does not " \
assert self.projection in [None,prj], "Cannot set %s.projection to the given value because it does not " \
"match the projection from the file on disk." %self.__class__.__name__
else:
self._projection = prj
......@@ -224,6 +238,7 @@ class GeoArray(object):
@property
def epsg(self):
    """Get the EPSG code of the projection of the GeoArray (self.projection passed through WKT2EPSG)."""
    return WKT2EPSG(self.projection)
......@@ -240,6 +255,10 @@ class GeoArray(object):
@property
def nodata(self):
"""Get the nodata value of the GeoArray. If GeoArray has been instanced with a file path the file is checked
for an existing nodata value. Otherwise (if no value is exlicitly given during object instanciation) the nodata
value is tried to be automatically detected.
"""
if self._nodata is not None:
return self._nodata
else:
......@@ -265,6 +284,7 @@ class GeoArray(object):
@property
def mask_nodata(self):
"""Get the nodata mask of the associated image array. It is calculated using all image bands."""
if self._mask_nodata is not None:
return self._mask_nodata
else:
......@@ -279,6 +299,7 @@ class GeoArray(object):
@property
def footprint_poly(self):
"""Get the footprint polygon of the associated image array (returns an instance of shapely.geometry.Polygon."""
if self._footprint_poly is not None:
return self._footprint_poly
else:
......@@ -288,8 +309,9 @@ class GeoArray(object):
self._footprint_poly = self.box.mapPoly
else:
try:
self._footprint_poly = raster2polygon(self, exact=False, progress=self.progress, q=self.q,
multipolygon = raster2polygon(self, exact=False, progress=self.progress, q=self.q,
maxfeatCount=10, timeout=3)
self._footprint_poly = fill_holes_within_poly(multipolygon)
except (RuntimeError, TimeoutError, TimeoutError_comp):
if not self.q:
warnings.warn("\nCalculation of footprint polygon failed for %s '%s'. Using outer bounds. One "
......@@ -542,25 +564,37 @@ class GeoArray(object):
return out_arr
def save(self, out_path, fmt='ENVI', q=False):
q = self.q if not q else q
if not q: print('Writing GeoArray of size %s to %s.' %(self.shape, out_path))
def save(self, out_path, fmt='ENVI', creationOptions=None):
"""Write the raster data to disk.
:param out_path: <str> output path
:param fmt: <str> the output format / GDAL driver code to be used for output creation, e.g. 'ENVI'
:param creationOptions: <list> GDAL creation options, e.g. ["QUALITY=20", "REVERSIBLE=YES", "WRITE_METADATA=YES"]
:return:
"""
if not self.q: print('Writing GeoArray of size %s to %s.' %(self.shape, out_path))
assert self.ndim in [2,3], 'Only 2D- or 3D arrays are supported.'
if not os.path.isdir(os.path.dirname(out_path)): os.makedirs(os.path.dirname(out_path))
if self.is_inmem:
ds = get_GDAL_ds_inmem(self.arr,self.geotransform, self.projection) # expects rows,columns,bands
out_arr = self.arr if self.ndim == 2 else np.swapaxes(np.swapaxes(self.arr, 0, 2), 1, 2) # rows, columns, bands => bands, rows, columns
gdalnumeric.SaveArray(out_arr, out_path, format=fmt, prototype=ds) # expects bands,rows,columns
ds = get_GDAL_ds_inmem(self.arr,self.geotransform, self.projection, self.nodata) # expects rows,columns,bands
gdal.GetDriverByName(fmt).CreateCopy(out_path, ds, options=creationOptions if creationOptions else [])
#out_arr = self.arr if self.ndim == 2 else np.swapaxes(np.swapaxes(self.arr, 0, 2), 1, 2) # rows, columns, bands => bands, rows, columns
#gdalnumeric.SaveArray(out_arr, out_path, format=fmt, prototype=ds) # expects bands,rows,columns
ds = None
else:
src_ds = gdal.Open(self.filePath)
gdal.Translate(out_path, src_ds, format=fmt)
gdal_Translate = get_gdal_func('Translate')
gdal_Translate(out_path, src_ds, format=fmt, creationOptions=creationOptions)
src_ds = None
def dump(self, out_path):
    """Serialize the whole object instance to disk using dill.

    :param out_path:    <str> path of the file the serialized instance is written to
    :return:
    """
    import dill

    # dill (like pickle) emits bytes, so the target file must be opened in binary mode;
    # text mode ('w') fails on Python 3 and can corrupt the stream on Windows.
    with open(out_path, 'wb') as outF:
        dill.dump(self, outF)
......@@ -825,10 +859,13 @@ class GeoArray(object):
def cache_array_subset(self, subarray):
    """Sets the array cache of the GeoArray instance to the given array in order to speed up calculations
    afterwards.

    :param subarray:    the array to be stored in self._arr_cache (any previously cached value is replaced)
    """
    self._arr_cache = subarray
def flush_cache(self):
    """Clear the array cache of the GeoArray instance by resetting self._arr_cache to None."""
    self._arr_cache = None
......
......@@ -14,14 +14,16 @@ except ImportError:
from ...compatibility.gdalnumeric import get_gdalnumeric_func
def get_GDAL_ds_inmem(array, gt=None, prj=None):
def get_GDAL_ds_inmem(array, gt=None, prj=None, nodata=None):
"""
:param array: <numpy.ndarray> in the shape (rows, columns, bands)
:param gt:
:param prj:
:param nodata: nodata value to be set
:return:
"""
# FIXME does not respect different nodata values for each band
if len(array.shape) == 3:
array = np.rollaxis(array, 2) # rows,cols,bands => bands,rows,cols
......@@ -30,6 +32,11 @@ def get_GDAL_ds_inmem(array, gt=None, prj=None):
ds = OpenNumPyArray(array)
if gt: ds.SetGeoTransform(gt)
if prj: ds.SetProjection(prj)
if nodata:
for i in range(ds.RasterCount):
band = ds.GetRasterBand(i+1)
band.SetNoDataValue(nodata)
band=None
ds.FlushCache() # Write to disk.
return ds
......
......@@ -4,57 +4,98 @@ __author__ = "Daniel Scheffler"
import sys
from time import time
_time_start = None
class Timer(object):
def __init__(self, timeout=None):
self.starttime = time()
self.endtime = self.starttime+timeout if timeout else None
@property
def timed_out(self):
if self.endtime:
if time() > self.endtime:
return True
else:
return False
else:
return False
@property
def elapsed(self):
return '%.2f sek' %(time()-self.starttime)
def is_timed_out(timeout):
if timeout is not None:
global _time_start
if _time_start is None:
_time_start = time()
if time() - _time_start >= timeout:
_time_start = None
return True
else:
return False
else:
return False
def __call__(self, percent01, message, user_data):
"""This allows that Timer instances are callable and thus can be used as callback function,
e.g. for GDAL.
:param percent01: this is not used but expected when used as GDAL callback
:param message: this is not used but expected when used as GDAL callback
:param user_data: this is not used but expected when used as GDAL callback
:return:
"""
return self.timed_out
def printProgress (percent, prefix = '', suffix = '', decimals = 1, barLength = 100, timeout=None):
"""
Call in a loop to create terminal progress bar
:param percent: a value between 0 and 100
:param prefix: - Optional : prefix string (Str)
:param suffix: - Optional : suffix string (Str)
:param decimals: - Optional : positive number of decimals in percent complete (Int)
:param barLength: - Optional : character length of bar (Int)
:param timeout: - Optional : breaks the process after a given time in seconds
http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
"""
if timeout:
if is_timed_out(timeout):
sys.stdout.write('\n')
class ProgressBar(object):
def __init__(self, prefix = '', suffix = 'Complete', decimals = 1, barLength = 50, show_elapsed=True,
timeout=None):
"""Call an instance of this class in a loop to create terminal progress bar. This class can also be used as
callback function, e.g. for GDAL. Just pass an instance of ProgressBar to the respective callback keyword.
:param prefix: prefix string (Str)
:param suffix: suffix string (Str)
:param decimals: positive number of decimals in percent complete (Int)
:param barLength: character length of bar (Int)
:param show_elapsed: displays the elapsed time right after the progress bar (bool)
:param timeout: breaks the process after a given time in seconds (float)
http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
"""
self.prefix = prefix
self.suffix = suffix
self.decimals = decimals
self.barLength = barLength
self.show_elapsed = show_elapsed
self.Timer = Timer(timeout=timeout)
def print_progress(self, percent):
"""Based on http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
:param percent: <float> a number between 0 and 100
:return:
"""
if self.Timer.timed_out:
sys.stdout.flush()
raise KeyboardInterrupt
formatStr = "{0:." + str(decimals) + "f}"
percents = formatStr.format(percent)
filledLength = int(round(barLength * percent/100))
#bar = '█' * filledLength + '-' * (barLength - filledLength) # this is not compatible to shell console
bar = '=' * filledLength + '-' * (barLength - filledLength)
sys.stdout.write('\r') # resets the cursor to the beginning of the line and allows to write over what was previously on the line
sys.stdout.write('%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)),
formatStr = "{0:." + str(self.decimals) + "f}"
percents = formatStr.format(percent)
filledLength = int(round(self.barLength * percent / 100))
# bar = '█' * filledLength + '-' * (barLength - filledLength) # this is not compatible to shell console
bar = '=' * filledLength + '-' * (self.barLength - filledLength)
sys.stdout.write('\r') # resets the cursor to the beginning of the line and allows to write over what was previously on the line
# [%s/%s] numberDone
suffix = self.suffix if not self.show_elapsed else '%s => %s' %(self.suffix, self.Timer.elapsed)
sys.stdout.write('%s |%s| %s%s %s' % (self.prefix, bar, percents, '%', suffix))
if percent >= 100.:
sys.stdout.write('\n')
sys.stdout.flush()
if percent>=100.:
sys.stdout.write('\n')
global _time_start
_time_start = None
def __call__(self, percent01, message, user_data):
"""This allows that ProgressBar instances are callable and thus can be used as callback function,
e.g. for GDAL.
:param percent01: a float number between 0 and 1
:param message: this is not used but expected when used as GDAL callback
:param user_data: this is not used but expected when used as GDAL callback
:return:
"""
self.print_progress(percent01*100)
sys.stdout.flush()
def tqdm_hook(t):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment