Commit 6e1874b9 authored by jandi's avatar jandi
Browse files

imports an optional majority param

uses majority filter to smooth output
parent b5737097
Pipeline #2986 failed with stages
in 3 minutes and 15 seconds
......@@ -39,7 +39,7 @@ class COREG_LOCAL(object):
footprint_poly_ref=None, footprint_poly_tgt=None, data_corners_ref=None, data_corners_tgt=None,
outFillVal=-9999, nodata=(None, None), calc_corners=True, binary_ws=True, force_quadratic_win=True,
mask_baddata_ref=None, mask_baddata_tgt=None, CPUs=None, progress=True, v=False, q=False,
ignore_errors=True, dem=None):
ignore_errors=True, dem=None, majority_filter=None):
"""Applies the algorithm to detect spatial shifts to the whole overlap area of the input images. Spatial shifts
are calculated for each point in grid of which the parameters can be adjusted using keyword arguments. Shift
......@@ -146,6 +146,7 @@ class COREG_LOCAL(object):
:param q(bool): quiet mode (default: False)
:param ignore_errors(bool): Useful for batch processing. (default: False)
:param dem(str, GeoaArray): Optional DEM for reliability improvement
:param majority_filter(tuple): Optional (X,X)kernel majority filter of neighbouring pixels
# assertions / input validation
......@@ -193,6 +194,7 @@ class COREG_LOCAL(object):
self.progress = progress if not q else False # overridden by v
self.ignErr = ignore_errors # FIXME this is not yet implemented for COREG_LOCAL
self.dem = dem
self.majority_filter = majority_filter
assert self.tieP_filter_level in range(4), 'Invalid tie point filter level.'
assert isinstance(self.imref, GeoArray) and isinstance(self.im2shift, GeoArray), \
......@@ -310,7 +312,8 @@ class COREG_LOCAL(object):
if self.v:
......@@ -17,6 +17,7 @@ from geopandas import GeoDataFrame, GeoSeries
from shapely.geometry import Point
from skimage.measure import points_in_poly, ransac
from skimage.transform import AffineTransform, PolynomialTransform
from functools import partial
# internal modules
from .CoReg import COREG
......@@ -51,7 +52,7 @@ class Tie_Point_Grid(object):
def __init__(self, COREG_obj, grid_res, max_points=None, outFillVal=-9999, resamp_alg_calc='cubic',
tieP_filter_level=3, outlDetect_settings=None, dir_out=None, CPUs=None, progress=True, v=False,
q=False, majority_filter=False):
"""Applies the algorithm to detect spatial shifts to the whole overlap area of the input images. Spatial shifts
are calculated for each point in grid of which the parameters can be adjusted using keyword arguments. Shift
......@@ -91,6 +92,7 @@ class Tie_Point_Grid(object):
:param progress(bool): show progress bars (default: True)
:param v(bool): verbose mode (default: False)
:param q(bool): quiet mode (default: False)
:param majority_filter(tuple) Optional use of (X,X)kernel majority filter of the neighbourhood
if not isinstance(COREG_obj, COREG):
......@@ -107,6 +109,7 @@ class Tie_Point_Grid(object):
self.CPUs = CPUs
self.v = v
self.q = q if not v else False # overridden by v
self.majority_filter = majority_filter
self.progress = progress if not q else False # overridden by q
self.ref = self.COREG_obj.ref # type: GeoArray_CoReg
......@@ -393,10 +396,95 @@ class Tie_Point_Grid(object):
GDF = GDF.merge(GDF_filt[['POINT_ID'] + new_columns], on='POINT_ID', how="outer")
GDF = GDF.fillna(int(self.outFillVal))
if self.majority_filter is not None:
GDF = self.use_majority_filter(GDF)
self.CoRegPoints_table = GDF
return self.CoRegPoints_table
def use_majority_filter(self, coreg_table):
if self.CPUs is None or self.CPUs > 1:
cpus = multiprocessing.cpu_count()
print('Use majority filter with %s CPUs...' % cpus)
pool = multiprocessing.Pool(processes=cpus)
shift = partial(self._find_shift,
grid_res=(self.COREG_obj.win_size_XY[0], self.COREG_obj.win_size_XY[1]),
points = ((row.X_IM, row.Y_IM) for i, row in coreg_table.iterrows()) # generator
if self.q or not self.progress:
results =, points)
results = pool.map_async(shift, points, chunksize=1)
bar = ProgressBar(prefix='\tprogress:')
while True:
numberDone = len(coreg_table) - results._number_left
if self.progress:
bar.print_progress(percent=numberDone / len(coreg_table) * 100)
if results.ready():
# <= this is the line where multiprocessing can freeze if an exception appears within
# COREG ans is not raised
results = results.get()
coreg_table['ABS_SHIFT'] = [i[0] for i in results]
coreg_table['RELIABILITY'] = [i[1] for i in results]
print('Use majority filter with 1 CPU...')
bar = ProgressBar(prefix='\tprogress:')
for i, row in coreg_table.iterrows():
point = (row.X_IM, row.Y_IM)
bar.print_progress(percent=i / len(coreg_table) * 100)
shift, reliability = self._find_shift(point, coreg_table, self.majority_filter)
(coreg_table.X_IM == point[0]) & (coreg_table.Y_IM == point[1]), 'ABS_SHIFT'] = shift
(coreg_table.X_IM == point[0]) & (coreg_table.Y_IM == point[1]), 'RELIABILITY'] = reliability
return coreg_table
def _find_shift(point, coreg_table, grid_res, majority_kernel):
shifts = []
reli = []
for i in range(-majority_kernel[0], majority_kernel[0], 1):
for j in range(-majority_kernel[1], majority_kernel[1], 1):
x_coord = point[0] + grid_res[0]*i
y_coord = point[1] + grid_res[1]*j
if x_coord in coreg_table[coreg_table.Y_IM == y_coord].X_IM.values:
if coreg_table[(coreg_table.X_IM == x_coord) & (coreg_table.Y_IM == y_coord)]\
['ABS_SHIFT'].item() >= 0:
shifts.append(coreg_table[(coreg_table.X_IM == x_coord) & (coreg_table.Y_IM == y_coord)]\
reli.append(coreg_table[(coreg_table.X_IM == x_coord) & (coreg_table.Y_IM == y_coord)]\
return np.mean(shifts), np.mean(reli)
def calc_rmse(self, include_outliers=False):
# type: (bool) -> float
"""Calculates root mean square error of absolute shifts from the tie point grid.
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment