Commit 2097bab5 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Fixed issue #47 (COREG_LOCAL.view_CoRegPoints() raises KeyError: 'X_SHIFT_M'...

Fixed issue #47

 (COREG_LOCAL.view_CoRegPoints() raises KeyError: 'X_SHIFT_M' error when there are too many clouds). Increased default figsize of COREG_LOCAL.view_CoRegPoints().
Signed-off-by: Daniel Scheffler's avatarDaniel Scheffler <danschef@gfz-potsdam.de>
parent 9d235f12
Pipeline #15618 passed with stage
in 2 minutes and 24 seconds
......@@ -378,7 +378,7 @@ class COREG_LOCAL(object):
options are 'ref' and 'tgt' (default: 'tgt')
:param hide_filtered: <bool> hide all points that have been filtered out according to tie point filter
level
:param figsize: <tuple> size of the figure to be viewed, e.g. (10,10)
:param figsize: <tuple> size of the figure to be viewed, e.g. (10, 10)
:param title: <str> plot title
:param vector_scale: <float> scale factor for shift vector length (default: 1 -> no scaling)
:param savefigPath:
......@@ -390,6 +390,7 @@ class COREG_LOCAL(object):
:return:
"""
from matplotlib import pyplot as plt # noqa
from matplotlib.offsetbox import AnchoredText
from cartopy.crs import PlateCarree
from mpl_toolkits.axes_grid1 import make_axes_locatable
......@@ -397,9 +398,16 @@ class COREG_LOCAL(object):
if backgroundIm not in ['tgt', 'ref']:
raise ValueError('backgroundIm')
backgroundIm = self.im2shift if backgroundIm == 'tgt' else self.imref
fig, ax = backgroundIm.show_map(figsize=figsize, nodataVal=self.nodata[1], return_map=True,
fig, ax = backgroundIm.show_map(figsize=figsize,
nodataVal=self.nodata[1],
return_map=True,
band=self.COREG_obj.shift.band4match)
# make sure the output figure has a reasonable size, also if figsize is not given
if not figsize:
w, h = fig.get_size_inches()
fig.set_size_inches(w * 1.6, h * 1.6)
# set figure title
dict_attr_title = dict(
X_WIN_SIZE='size of the matching window in x-direction [pixels]',
......@@ -425,109 +433,121 @@ class COREG_LOCAL(object):
raise ValueError(attribute2plot, "Invalid value for 'attribute2plot'. Valid values are: %s."
% ", ".join(self.CoRegPoints_table.columns))
# get GeoDataFrame containing everything needed for plotting
outlierCols = [c for c in self.CoRegPoints_table.columns if 'OUTLIER' in c]
attr2include = ['geometry', attribute2plot] + outlierCols + ['X_SHIFT_M', 'Y_SHIFT_M']
GDF = self.CoRegPoints_table.loc[self.CoRegPoints_table.X_SHIFT_M != self.outFillVal, attr2include].copy()\
if exclude_fillVals else self.CoRegPoints_table.loc[:, attr2include]
# get LonLat coordinates for all points
XY = np.array([geom.coords.xy for geom in GDF.geometry]).reshape(-1, 2)
lon, lat = transform_any_prj(self.im2shift.projection, 4326, XY[:, 0], XY[:, 1])
GDF['Lon'], GDF['Lat'] = lon, lat
# get colors for all points
palette = cmap if cmap is not None else plt.cm.get_cmap('RdYlGn_r')
if cmap is None and attribute2plot == 'ANGLE':
import cmocean
palette = getattr(cmocean.cm, 'delta')
if hide_filtered:
if self.tieP_filter_level > 0:
GDF = GDF[GDF.L1_OUTLIER.__eq__(False)].copy()
if self.tieP_filter_level > 1:
GDF = GDF[GDF.L2_OUTLIER.__eq__(False)].copy()
if self.tieP_filter_level > 2:
GDF = GDF[GDF.L3_OUTLIER.__eq__(False)].copy()
else:
marker = 'o' if len(GDF) < 10000 else '.'
common_kw = dict(marker=marker, alpha=1.0, transform=PlateCarree())
if self.tieP_filter_level > 0:
# flag level 1 outliers
GDF_filt = GDF[GDF.L1_OUTLIER.__eq__(True)].copy()
ax.scatter(GDF_filt['Lon'], GDF_filt['Lat'], c='b', s=250, label='reliability',
**common_kw)
if self.tieP_filter_level > 1:
# flag level 2 outliers
GDF_filt = GDF[GDF.L2_OUTLIER.__eq__(True)].copy()
ax.scatter(GDF_filt['Lon'], GDF_filt['Lat'], c='r', s=150, label='SSIM',
**common_kw)
if self.tieP_filter_level > 2:
# flag level 3 outliers
GDF_filt = GDF[GDF.L3_OUTLIER.__eq__(True)].copy()
ax.scatter(GDF_filt['Lon'], GDF_filt['Lat'], c='y', s=250, label='RANSAC',
**common_kw)
if self.tieP_filter_level > 0:
ax.legend(loc=0, scatterpoints=1)
# plot all points or vectors on top
if not GDF.empty:
vmin_auto, vmax_auto = \
(np.percentile(GDF[attribute2plot], 0),
np.percentile(GDF[attribute2plot], 98)) \
if attribute2plot != 'ANGLE' else (0, 360)
vmin = vmin if vmin is not None else vmin_auto
vmax = vmax if vmax is not None else vmax_auto
if shapes2plot == 'vectors':
# plot shift vectors
# doc: https://matplotlib.org/devdocs/api/_as_gen/matplotlib.axes.Axes.quiver.html
mappable = ax.quiver(
GDF['Lon'].values, GDF['Lat'].values,
-GDF['X_SHIFT_M'].values,
-GDF['Y_SHIFT_M'].values, # invert absolute shifts to make arrows point to tgt
GDF[attribute2plot].clip(vmin, vmax), # sets the colors
scale=1200 / vector_scale, # larger values decrease the arrow length
width=.0015, # arrow width (in relation to plot width)
# linewidth=1, # maybe use this to mark outliers instead of scatter points
cmap=palette,
pivot='middle', # position the middle point of the arrows onto the tie point location
transform=PlateCarree()
)
elif shapes2plot == 'points':
# plot tie points
mappable = ax.scatter(
GDF['Lon'], GDF['Lat'],
c=GDF[attribute2plot],
lw=0,
cmap=palette,
marker='o' if len(GDF) < 10000 else '.',
s=50,
alpha=1.0,
vmin=vmin,
vmax=vmax,
transform=PlateCarree())
pass
if not self.CoRegPoints_table.empty:
# get GeoDataFrame containing everything needed for plotting
outlierCols = [c for c in self.CoRegPoints_table.columns if 'OUTLIER' in c]
attr2include = ['geometry', attribute2plot] + outlierCols + ['X_SHIFT_M', 'Y_SHIFT_M']
GDF = self.CoRegPoints_table.loc[self.CoRegPoints_table.X_SHIFT_M != self.outFillVal, attr2include].copy()\
if exclude_fillVals else self.CoRegPoints_table.loc[:, attr2include]
# get LonLat coordinates for all points
XY = np.array([geom.coords.xy for geom in GDF.geometry]).reshape(-1, 2)
lon, lat = transform_any_prj(self.im2shift.projection, 4326, XY[:, 0], XY[:, 1])
GDF['Lon'], GDF['Lat'] = lon, lat
# get colors for all points
palette = cmap if cmap is not None else plt.cm.get_cmap('RdYlGn_r')
if cmap is None and attribute2plot == 'ANGLE':
import cmocean
palette = getattr(cmocean.cm, 'delta')
if hide_filtered:
if self.tieP_filter_level > 0:
GDF = GDF[GDF.L1_OUTLIER.__eq__(False)].copy()
if self.tieP_filter_level > 1:
GDF = GDF[GDF.L2_OUTLIER.__eq__(False)].copy()
if self.tieP_filter_level > 2:
GDF = GDF[GDF.L3_OUTLIER.__eq__(False)].copy()
else:
raise ValueError("The parameter 'shapes2plot' must be set to 'vectors' or 'points'. "
"Received %s." % shapes2plot)
marker = 'o' if len(GDF) < 10000 else '.'
common_kw = dict(marker=marker, alpha=1.0, transform=PlateCarree())
if self.tieP_filter_level > 0:
# flag level 1 outliers
GDF_filt = GDF[GDF.L1_OUTLIER.__eq__(True)].copy()
ax.scatter(GDF_filt['Lon'], GDF_filt['Lat'], c='b', s=250, label='reliability',
**common_kw)
if self.tieP_filter_level > 1:
# flag level 2 outliers
GDF_filt = GDF[GDF.L2_OUTLIER.__eq__(True)].copy()
ax.scatter(GDF_filt['Lon'], GDF_filt['Lat'], c='r', s=150, label='SSIM',
**common_kw)
if self.tieP_filter_level > 2:
# flag level 3 outliers
GDF_filt = GDF[GDF.L3_OUTLIER.__eq__(True)].copy()
ax.scatter(GDF_filt['Lon'], GDF_filt['Lat'], c='y', s=250, label='RANSAC',
**common_kw)
if self.tieP_filter_level > 0:
ax.legend(loc=0, scatterpoints=1)
# plot all points or vectors on top
if not GDF.empty:
vmin_auto, vmax_auto = \
(np.percentile(GDF[attribute2plot], 0),
np.percentile(GDF[attribute2plot], 98)) \
if attribute2plot != 'ANGLE' else (0, 360)
vmin = vmin if vmin is not None else vmin_auto
vmax = vmax if vmax is not None else vmax_auto
if shapes2plot == 'vectors':
# plot shift vectors
# doc: https://matplotlib.org/devdocs/api/_as_gen/matplotlib.axes.Axes.quiver.html
mappable = ax.quiver(
GDF['Lon'].values, GDF['Lat'].values,
-GDF['X_SHIFT_M'].values,
-GDF['Y_SHIFT_M'].values, # invert absolute shifts to make arrows point to tgt
GDF[attribute2plot].clip(vmin, vmax), # sets the colors
scale=1200 / vector_scale, # larger values decrease the arrow length
width=.0015, # arrow width (in relation to plot width)
# linewidth=1, # maybe use this to mark outliers instead of scatter points
cmap=palette,
pivot='middle', # position the middle point of the arrows onto the tie point location
transform=PlateCarree()
)
elif shapes2plot == 'points':
# plot tie points
mappable = ax.scatter(
GDF['Lon'], GDF['Lat'],
c=GDF[attribute2plot],
lw=0,
cmap=palette,
marker='o' if len(GDF) < 10000 else '.',
s=50,
alpha=1.0,
vmin=vmin,
vmax=vmax,
transform=PlateCarree())
pass
else:
raise ValueError("The parameter 'shapes2plot' must be set to 'vectors' or 'points'. "
"Received %s." % shapes2plot)
# add colorbar
divider = make_axes_locatable(ax)
cax = divider.new_vertical(size="2%", pad=0.4, pack_start=True,
axes_class=plt.Axes # needed because ax is a GeoAxis instance
)
fig.add_axes(cax)
fig.colorbar(mappable, cax=cax, orientation="horizontal")
# hack to enlarge the figure on the top to avoid cutting off the title (everthing else has no effect)
divider.new_vertical(size="2%", pad=0.4, pack_start=False, axes_class=plt.Axes)
# add colorbar
divider = make_axes_locatable(ax)
cax = divider.new_vertical(size="2%", pad=0.4, pack_start=True,
axes_class=plt.Axes # needed because ax is a GeoAxis instance
)
fig.add_axes(cax)
fig.colorbar(mappable, cax=cax, orientation="horizontal")
else:
msg = "The map does not contain any tie points \n" \
"because all the found tie points were flagged as false-positives."
ax.add_artist(AnchoredText(msg, loc='lower center', prop=dict(c='r')))
# hack to enlarge the figure on the top to avoid cutting off the title (everthing else has no effect)
divider.new_vertical(size="2%", pad=0.4, pack_start=False, axes_class=plt.Axes)
if not self.q:
warnings.warn(msg)
else:
msg = "The map does not contain any tie points because no tie points were found at all."
ax.add_artist(AnchoredText(msg, loc='lower center', prop=dict(c='r')))
if not self.q:
warnings.warn('Cannot plot any tie point because none is left after tie point validation.')
warnings.warn(msg)
if savefigPath:
fig.savefig(savefigPath, dpi=savefigDPI)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment