Commit bea75bf5 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Bugfix for returning wrong array shape when warping a 3D array

compatibility.gdalnumeric:
- OpenNumPyArray(): Bugfix for expecting (rows,columns,bands) instead of GDAL-like (bands,rows,columns)
- edited docstring
- added datatype assertion

geo.raster.reproject:
- moved availability check for resampling algorithm 'average' here

io.raster.gdal:
- get_GDAL_ds_inmem(): added docstring
- added get_GDAL_driverList()

io.raster.GeoArray:
- save(): bugfix for writing wrong array dimensions in case of 3D array
- show(): nodataVal is now excluded from vmin/vmax calculation when showing image

processing.progress_mon:
- printProgress(): changed bar symbol due to incompatibility to csh shell output stream
parent 113d3c9d
......@@ -7,7 +7,7 @@ try:
from osgeo import gdalconst
except ImportError:
import gdal
import gdalnumeric
import gdalnumeric # FIXME this will import this __module__
import gdalconst
......@@ -15,12 +15,21 @@ def OpenNumPyArray(array):
"""This function emulates the functionality of gdalnumeric.OpenNumPyArray() which is not available in GDAL versions
below 2.1.0 (?).
:param array:
:param array: <numpy.ndarray> in the shape (bands, rows, columns)
:return:
"""
rows, cols = array.shape[:2]
bands = array.shape[2] if array.ndim == 3 else 1
gdal_dtype = gdalnumeric.NumericTypeCodeToGDALTypeCode(array.dtype)
if array.ndim==2:
rows, cols = array.shape
bands = 1
elif array.ndim==3:
bands,rows,cols=array.shape
else:
raise ValueError('OpenNumPyArray() currently only supports 2D and 3D arrays. Given array shape is %s.'
%str(array.shape))
# get output datatype
gdal_dtype = gdalnumeric.NumericTypeCodeToGDALTypeCode(array.dtype) # FIXME not all datatypes can be translated
assert gdal_dtype is not None, 'Datatype %s is currently not supported by OpenNumPyArray().' %array.dtype
mem_drv = gdal.GetDriverByName('MEM')
mem_ds = mem_drv.Create('/vsimem/tmp/memfile.mem', cols, rows, bands, gdal_dtype)
......@@ -39,7 +48,7 @@ def OpenNumPyArray(array):
def get_gdalnumeric_func(funcName):
try:
return getattr(gdal, funcName)
return getattr(gdalnumeric, funcName)
except AttributeError:
if funcName in globals():
return globals()[funcName]
......
......@@ -268,7 +268,15 @@ def warp_ndarray(ndarray, in_gt, in_prj, out_prj=None, out_dtype=None, out_gsd=(
# TODO test if this function delivers the exact same output like console version, otherwise implment error_threshold=0.125
# how to implement: https://svn.osgeo.org/gdal/trunk/autotest/utilities/test_gdalwarp_lib.py
# assertions
assert str(np.dtype(ndarray.dtype)) in dTypeDic_NumPy2GDAL, "Unknown target datatype '%s'." %ndarray.dtype
if rspAlg=='average':
is_avail_rsp_average = int(gdal.VersionInfo()[0]) >= 2
if not is_avail_rsp_average:
warnings.warn("The GDAL version on this machine does not yet support the resampling algorithm 'average'. "
"'cubic' is used instead. To avoid this please update GDAL to a version above 2.0.0!")
rspAlg='cubic'
get_SRS = lambda prjArg: prjArg if isinstance(prjArg,str) and prjArg.startswith('EPSG:') else \
'EPSG:%s'%prjArg if isinstance(prjArg,int) else \
......@@ -279,14 +287,12 @@ def warp_ndarray(ndarray, in_gt, in_prj, out_prj=None, out_dtype=None, out_gsd=(
progressBarWarp = (lambda percent01, message, user_data: printProgress(percent01 * 100,
**{'prefix': 'Warping progress ', 'suffix': 'Complete', 'barLength': 50})) if progress and not q else None
# not yet implemented
cutlineDSName = 'data/cutline.vrt' #'/vsimem/cutline.shp' TODO cutline from OGR datasource. => implement input shapefile or Geopandas dataframe
cutlineLayer = 'cutline'
cropToCutline = False
cutlineSQL = 'SELECT * FROM cutline'
cutlineWhere = '1 = 1'
callback_data = [0]
rpc = [
"HEIGHT_OFF=1466.05894327379",
"HEIGHT_SCALE=144.837606185489",
......@@ -343,7 +349,7 @@ def warp_ndarray(ndarray, in_gt, in_prj, out_prj=None, out_dtype=None, out_gsd=(
metadataConflictValue --- metadata data conflict value
setColorInterpretation --- whether to force color interpretation of input bands to output bands
callback --- callback method
callback_data --- user data for callback
callback_data --- user data for callback # value for last parameter of progress callback
"""
......@@ -394,6 +400,7 @@ def warp_ndarray(ndarray, in_gt, in_prj, out_prj=None, out_dtype=None, out_gsd=(
polynomialOrder = polynomialOrder,
warpMemoryLimit = warpMemoryLimit,
callback = progressBarWarp,
callback_data = [0],
errorThreshold = 0.125, # this is needed to get exactly the same output like the console version of GDAL warp
)
......
......@@ -31,6 +31,7 @@ from ...io.raster.gdal import get_GDAL_ds_inmem
class GeoArray(object):
# TODO automatic nodataVal detection => add that value to __init__, adjust calls within CoRegSat
def __init__(self, path_or_array, geotransform=None, projection=None, bandnames=None):
# type: (Any, tuple, str, list) -> GeoArray
"""
......@@ -401,9 +402,9 @@ class GeoArray(object):
assert self.ndim in [2,3], 'Only 2D- or 3D arrays are supported.'
if self.is_inmem:
out_arr = self.arr if self.ndim==2 else np.swapaxes(np.swapaxes(self.arr,0,2),1,2) # rows, columns, bands => bands, rows, columns
ds = get_GDAL_ds_inmem(out_arr,self.geotransform, self.projection)
gdalnumeric.SaveArray(out_arr, out_path, format=fmt, prototype=ds)
ds = get_GDAL_ds_inmem(self.arr,self.geotransform, self.projection) # expects rows,columns,bands
out_arr = self.arr if self.ndim == 2 else np.swapaxes(np.swapaxes(self.arr, 0, 2), 1, 2) # rows, columns, bands => bands, rows, columns
gdalnumeric.SaveArray(out_arr, out_path, format=fmt, prototype=ds) # expects bands,rows,columns
ds = None
else:
src_ds = gdal.Open(self.filePath)
......@@ -469,16 +470,19 @@ class GeoArray(object):
# set color palette
palette = cmap if cmap else plt.cm.gray
if nodataVal is not None: # do not show nodata
if nodataVal is not None: # do not show nodata # add auto-detection # TODO
image2plot = np.ma.masked_equal(image2plot, nodataVal)
vmin, vmax = np.percentile(image2plot.compressed(),2), np.percentile(image2plot.compressed(),98)
palette.set_bad('aqua', 0)
else:
vmin, vmax = np.percentile(image2plot, 2), np.percentile(image2plot, 98)
palette.set_over ('1')
palette.set_under('0')
# show image
plt.figure(figsize=figsize)
plt.imshow(image2plot, palette, interpolation=interpolation,extent=(0,self.cols,self.rows,0),
vmin=np.percentile(image2plot,2),vmax=np.percentile(image2plot,98),)
vmin=vmin,vmax=vmax,) # compressed excludes nodata values
plt.show()
......
......@@ -3,10 +3,26 @@ __author__ = "Daniel Scheffler"
import numpy as np
from pandas import DataFrame
try:
from osgeo import gdal
except ImportError:
import gdal # FIXME this will import this __module__
from ...compatibility.gdalnumeric import get_gdalnumeric_func
def get_GDAL_ds_inmem(array, gt, prj):
"""
:param array: <numpy.ndarray> in the shape (rows, columns, bands)
:param gt:
:param prj:
:return:
"""
if len(array.shape) == 3:
array = np.rollaxis(array, 2) # rows,cols,bands => bands,rows,cols
......@@ -15,4 +31,21 @@ def get_GDAL_ds_inmem(array, gt, prj):
ds.SetGeoTransform(gt)
ds.SetProjection(prj)
ds.FlushCache() # Write to disk.
return ds
\ No newline at end of file
return ds
def get_GDAL_driverList():
count = gdal.GetDriverCount()
df = DataFrame(np.full((count,5), np.nan),columns=['drvCode','drvLongName', 'ext1', 'ext2', 'ext3'])
for i in range(count):
drv = gdal.GetDriver(i)
if drv.GetMetadataItem(gdal.DCAP_RASTER):
meta = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
extensions = meta.split() if meta else []
df.ix[i]=[drv.GetDescription(),
drv.GetMetadataItem(gdal.DMD_LONGNAME),
extensions[0] if len(extensions)>0 else np.nan,
extensions[1] if len(extensions)>1 else np.nan,
extensions[2] if len(extensions)>2 else np.nan]
df = df.dropna(how='all')
return df
\ No newline at end of file
......@@ -19,8 +19,10 @@ def printProgress (percent, prefix = '', suffix = '', decimals = 1, barLength =
percents = formatStr.format(percent)
filledLength = int(round(barLength * percent/100))
bar = '█' * filledLength + '-' * (barLength - filledLength)
sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)),
if percent==100.:
bar = '=' * filledLength + '-' * (barLength - filledLength)
sys.stdout.write('\r') # resets the cursor to the beginning of the line and allows to write over what was previously on the line
sys.stdout.write('%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)),
if percent>=100.:
sys.stdout.write('\n')
sys.stdout.flush()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment