Commit 9cda6ac5 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

bugfix for running raster2polygon without timeout

compatibility:
- added subpackage 'python' for handling compatibility issues between python versions

geo.raster.conversion:
- revised raster2polygon(): added timeout, progress bar, quiet mode

io.raster.GeoArray:
- added progress keyword
- nodata: bugfix
- footprint_poly: now returns outer box if raster2polygon timed out
- show(): bugfix for crash in case input image has only one value
- show_map(): bugfix for crash in case input image has only one value

processing.progress_mon:
- added function is_timed_out()
- printProgress(): added timeout keyword
parent 43a9e95f
__author__='Daniel Scheffler'
from . import gdal
__all__=['gdal',
from . import python
__all__=['python'
'gdal',
'gdalnumeric']
\ No newline at end of file
__author__='Daniel Scheffler'
from . import exceptions
__all__=['exceptions']
\ No newline at end of file
# -*- coding: utf-8 -*-
__author__ = "Daniel Scheffler"
class TimeoutError(OSError):
""" Timeout expired. """
pass
\ No newline at end of file
......@@ -14,16 +14,21 @@ except ImportError:
from osgeo import ogr
from osgeo import osr
from ...io.raster.gdal import get_GDAL_ds_inmem
from ...io.raster.gdal import get_GDAL_ds_inmem
from ...processing.progress_mon import printProgress, is_timed_out
from ...compatibility.python.exceptions import TimeoutError
def raster2polygon(array_or_GeoArray, gt=None, prj=None, exact=True):
def raster2polygon(array_or_GeoArray, gt=None, prj=None, exact=True, timeout=None, progress=True, q=False):
"""Calculates a footprint polygon for the given array or GeoArray.
:param array_or_GeoArray:
:param gt:
:param prj:
:param exact:
:param timeout: breaks the process after a given time in seconds
:param progress: show progress bars (default: True)
:param q: quiet mode (default: False)
:return:
"""
from ... import GeoArray
......@@ -44,11 +49,23 @@ def raster2polygon(array_or_GeoArray, gt=None, prj=None, exact=True):
fd = ogr.FieldDefn('DN', ogr.OFTInteger)
mem_layer.CreateField(fd)
# run the algorithm.
result = gdal.Polygonize(src_band, src_band.GetMaskBand(), mem_layer, 0, ["8CONNECTED=8"] if exact else [])
# set callback
progressBar = lambda percent01, message, user_data: \
printProgress(percent01 * 100, prefix='Polygonize progress ', suffix='Complete', barLength=50, timeout=3)
timeout_callback = lambda percent01, message, user_data: is_timed_out(3)
callback = progressBar if progress and not q else timeout_callback if timeout else None
# run the algorithm
result = gdal.Polygonize(src_band, src_band.GetMaskBand(), mem_layer, 0, ["8CONNECTED=8"] if exact else [],
callback=callback)
errMsg = gdal.GetLastErrorMsg()
if errMsg == 'User terminated':
raise TimeoutError('raster2polygon timed out!')
if result is None:
raise Exception(gdal.GetLastErrorMsg())
raise Exception(errMsg)
# extract polygon
mem_layer.SetAttributeFilter('DN = 1')
......
......@@ -21,15 +21,16 @@ except ImportError:
import gdalnumeric
from ...geo.coord_calc import get_corner_coordinates, calc_FullDataset_corner_positions
from ...geo.coord_grid import snap_bounds_to_pixGrid
from ...geo.coord_trafo import mapXY2imXY, imXY2mapXY, transform_any_prj, reproject_shapelyGeometry
from ...geo.projection import prj_equal, WKT2EPSG, EPSG2WKT
from ...geo.raster.conversion import raster2polygon
from ...geo.vector.topology import get_overlap_polygon, get_footprint_polygon
from ...geo.vector.geometry import boxObj
from ...io.raster.gdal import get_GDAL_ds_inmem
from ...numeric.array import find_noDataVal, get_outFillZeroSaturated
from ...geo.coord_calc import get_corner_coordinates, calc_FullDataset_corner_positions
from ...geo.coord_grid import snap_bounds_to_pixGrid
from ...geo.coord_trafo import mapXY2imXY, imXY2mapXY, transform_any_prj, reproject_shapelyGeometry
from ...geo.projection import prj_equal, WKT2EPSG, EPSG2WKT
from ...geo.raster.conversion import raster2polygon
from ...geo.vector.topology import get_overlap_polygon, get_footprint_polygon
from ...geo.vector.geometry import boxObj
from ...io.raster.gdal import get_GDAL_ds_inmem
from ...numeric.array import find_noDataVal, get_outFillZeroSaturated
from ...compatibility.python.exceptions import TimeoutError
......@@ -41,7 +42,8 @@ def _alias_property(key):
class GeoArray(object):
def __init__(self, path_or_array, geotransform=None, projection=None, bandnames=None, nodata=None, q=False):
def __init__(self, path_or_array, geotransform=None, projection=None, bandnames=None, nodata=None, progress=True,
q=False):
# type: (Any, tuple, str, list) -> GeoArray
"""
......@@ -52,6 +54,7 @@ class GeoArray(object):
:param bandnames: names of the bands within the input array, e.g. ['mask_1bit', 'mask_clouds'],
(default: ['B1', 'B2', 'B3', ...])
:param nodata: nodata value
:param progress: show progress bars (default: True)
:param q: quiet mode (default: False)
"""
......@@ -249,7 +252,7 @@ class GeoArray(object):
warnings.warn('Nodata value could not be clearly identified. It has been set to None.')
self._nodata = None
else:
if not self.q:
if self._nodata is not None and not self.q:
print("Automatically detected nodata value for %s '%s': %s"
%(self.__class__.__name__, self.basename, self._nodata))
return self._nodata
......@@ -280,7 +283,19 @@ class GeoArray(object):
return self._footprint_poly
else:
assert self.mask_nodata is not None, 'A nodata mask is needed for calculating the footprint polygon. '
self._footprint_poly = raster2polygon(self, exact=False)
if np.std(self.mask_nodata)==0:
# do not run raster2polygon if whole image is filled with data
self._footprint_poly = self.box.mapPoly
else:
try:
self._footprint_poly = raster2polygon(self, exact=False, progress=self.progress, q=self.q, timeout=10)
except TimeoutError:
if not self.q:
warnings.warn("\nCalculation of footprint polygon failed for %s '%s'. Using outer bounds. One "
"reason could be that the nodata value appears in the middle of the actual image. "
"To avoid this use another nodata value. Current nodata value is %s."
%(self.__class__.__name__, self.basename, self.nodata))
self._footprint_poly = self.box.mapPoly
return self._footprint_poly
......@@ -389,7 +404,7 @@ class GeoArray(object):
self._dtype = gdal_array.GDALTypeCodeToNumericTypeCode(ds.GetRasterBand(1).DataType)
self._geotransform = ds.GetGeoTransform()
self._projection = ds.GetProjection()
if not 'nodata' in self._initParams:
if not 'nodata' in self._initParams or self._initParams['nodata'] is None:
band = ds.GetRasterBand(1)
self._nodata = band.GetNoDataValue() # FIXME this does not support different nodata values within the same file
ds = band = None
......@@ -506,6 +521,7 @@ class GeoArray(object):
def save(self, out_path, fmt='ENVI', q=False):
q = self.q if not q else q
if not q: print('Writing GeoArray of size %s to %s.' %(self.shape, out_path))
assert self.ndim in [2,3], 'Only 2D- or 3D arrays are supported.'
......@@ -582,7 +598,7 @@ class GeoArray(object):
# set color palette
palette = cmap if cmap else plt.cm.gray
if nodataVal is not None: # do not show nodata
if nodataVal is not None and np.std(image2plot)!=0: # do not show nodata
image2plot = np.ma.masked_equal(image2plot, nodataVal)
vmin, vmax = np.percentile(image2plot.compressed(),2), np.percentile(image2plot.compressed(),98)
palette.set_bad('aqua', 0)
......@@ -637,7 +653,7 @@ class GeoArray(object):
# set color palette
palette = cmap if cmap else plt.cm.gray
if nodataVal is not None: # do not show nodata
if nodataVal is not None and np.std(image2plot)!=0: # do not show nodata
image2plot = np.ma.masked_equal(image2plot, nodataVal)
vmin, vmax = np.percentile(image2plot.compressed(), 2), np.percentile(image2plot.compressed(), 98)
palette.set_bad('aqua', 0)
......
......@@ -2,9 +2,25 @@
__author__ = "Daniel Scheffler"
import sys
from time import time
_time_start = None
def printProgress (percent, prefix = '', suffix = '', decimals = 1, barLength = 100):
def is_timed_out(timeout):
global _time_start
if _time_start is None and timeout:
_time_start = time()
if timeout and time() - _time_start >= timeout:
_time_start = None
return True
else:
return False
def printProgress (percent, prefix = '', suffix = '', decimals = 1, barLength = 100, timeout=None):
"""
Call in a loop to create terminal progress bar
:param percent: a value between 0 and 100
......@@ -12,9 +28,16 @@ def printProgress (percent, prefix = '', suffix = '', decimals = 1, barLength =
:param suffix: - Optional : suffix string (Str)
:param decimals: - Optional : positive number of decimals in percent complete (Int)
:param barLength: - Optional : character length of bar (Int)
:param timeout: - Optional : breaks the process after a given time in seconds
http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
"""
if timeout:
if is_timed_out(timeout):
sys.stdout.write('\n')
sys.stdout.flush()
raise KeyboardInterrupt
formatStr = "{0:." + str(decimals) + "f}"
percents = formatStr.format(percent)
filledLength = int(round(barLength * percent/100))
......@@ -22,8 +45,10 @@ def printProgress (percent, prefix = '', suffix = '', decimals = 1, barLength =
bar = '=' * filledLength + '-' * (barLength - filledLength)
sys.stdout.write('\r') # resets the cursor to the beginning of the line and allows to write over what was previously on the line
sys.stdout.write('%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)),
if percent>=100.:
sys.stdout.write('\n')
sys.stdout.flush()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment