Commit 2842b108 authored by Sebastian Heimann's avatar Sebastian Heimann
Browse files

support l1norm fitting

parent 80237d45
...@@ -236,64 +236,95 @@ class CMTProblem(core.Problem): ...@@ -236,64 +236,95 @@ class CMTProblem(core.Problem):
return self._target_weights return self._target_weights
def raise_invalid_norm_exponent(self):
raise core.GrondError('invalid norm exponent' % self.norm_exponent)
def get_sqr_sqrt(self):
if self.norm_exponent == 2:
def sqr(x):
return x**2
return sqr, num.sqrt
elif self.norm_exponent == 1:
def noop(x):
return x
return noop, num.abs
else:
self.raise_invalid_norm_exponent()
def inter_group_weights(self, ns): def inter_group_weights(self, ns):
sqr, sqrt = self.get_sqr_sqrt()
group, ngroups = self.get_group_mask() group, ngroups = self.get_group_mask()
ws = num.zeros(self.ntargets) ws = num.zeros(self.ntargets)
for igroup in xrange(ngroups): for igroup in xrange(ngroups):
mask = group == igroup mask = group == igroup
ws[mask] = 1.0 / num.sqrt(num.nansum(ns[mask]**2)) ws[mask] = 1.0 / sqrt(num.nansum(sqr(ns[mask])))
return ws return ws
def inter_group_weights2(self, ns): def inter_group_weights2(self, ns):
sqr, sqrt = self.get_sqr_sqrt()
group, ngroups = self.get_group_mask() group, ngroups = self.get_group_mask()
ws = num.zeros(ns.shape) ws = num.zeros(ns.shape)
for igroup in xrange(ngroups): for igroup in xrange(ngroups):
mask = group == igroup mask = group == igroup
ws[:, mask] = (1.0 / num.sqrt( ws[:, mask] = (1.0 / sqrt(
num.nansum(ns[:, mask]**2, axis=1)))[:, num.newaxis] num.nansum(sqr(ns[:, mask]), axis=1)))[:, num.newaxis]
return ws return ws
def bootstrap_misfit(self, ms, ns, ibootstrap=None): def bootstrap_misfit(self, ms, ns, ibootstrap=None):
sqr, sqrt = self.get_sqr_sqrt()
w = self.get_bootstrap_weights(ibootstrap) * \ w = self.get_bootstrap_weights(ibootstrap) * \
self.get_target_weights() * self.inter_group_weights(ns) self.get_target_weights() * self.inter_group_weights(ns)
if ibootstrap is None: if ibootstrap is None:
return num.sqrt( return sqrt(
num.nansum((w*ms[num.newaxis, :])**2, axis=1) / num.nansum(sqr(w*ms[num.newaxis, :]), axis=1) /
num.nansum((w*ns[num.newaxis, :])**2, axis=1)) num.nansum(sqr(w*ns[num.newaxis, :]), axis=1))
else: else:
return num.sqrt(num.nansum((w*ms)**2) / num.nansum((w*ns)**2)) return sqrt(num.nansum(sqr(w*ms)) / num.nansum(sqr(w*ns)))
def bootstrap_misfits(self, misfits, ibootstrap): def bootstrap_misfits(self, misfits, ibootstrap):
sqr, sqrt = self.get_sqr_sqrt()
w = self.get_bootstrap_weights(ibootstrap)[num.newaxis, :] * \ w = self.get_bootstrap_weights(ibootstrap)[num.newaxis, :] * \
self.get_target_weights()[num.newaxis, :] * \ self.get_target_weights()[num.newaxis, :] * \
self.inter_group_weights2(misfits[:, :, 1]) self.inter_group_weights2(misfits[:, :, 1])
bms = num.sqrt(num.nansum((w*misfits[:, :, 0])**2, axis=1) / bms = sqrt(num.nansum(sqr(w*misfits[:, :, 0]), axis=1) /
num.nansum((w*misfits[:, :, 1])**2, axis=1)) num.nansum(sqr(w*misfits[:, :, 1]), axis=1))
return bms return bms
def global_misfit(self, ms, ns): def global_misfit(self, ms, ns):
sqr, sqrt = self.get_sqr_sqrt()
ws = self.get_target_weights() * self.inter_group_weights(ns) ws = self.get_target_weights() * self.inter_group_weights(ns)
m = num.sqrt(num.nansum((ws*ms)**2) / num.nansum((ws*ns)**2)) m = sqrt(num.nansum(sqr(ws*ms)) / num.nansum(sqr(ws*ns)))
return m return m
def global_misfits(self, misfits): def global_misfits(self, misfits):
sqr, sqrt = self.get_sqr_sqrt()
ws = self.get_target_weights()[num.newaxis, :] * \ ws = self.get_target_weights()[num.newaxis, :] * \
self.inter_group_weights2(misfits[:, :, 1]) self.inter_group_weights2(misfits[:, :, 1])
gms = num.sqrt(num.nansum((ws*misfits[:, :, 0])**2, axis=1) / gms = sqrt(num.nansum(sqr(ws*misfits[:, :, 0]), axis=1) /
num.nansum((ws*misfits[:, :, 1])**2, axis=1)) num.nansum(sqr(ws*misfits[:, :, 1]), axis=1))
return gms return gms
def global_contributions(self, misfits): def global_contributions(self, misfits):
sqr, sqrt = self.get_sqr_sqrt()
ws = self.get_target_weights()[num.newaxis, :] * \ ws = self.get_target_weights()[num.newaxis, :] * \
self.inter_group_weights2(misfits[:, :, 1]) self.inter_group_weights2(misfits[:, :, 1])
gcms = (ws*misfits[:, :, 0])**2 / \ gcms = sqr(ws*misfits[:, :, 0]) / \
num.nansum((ws*misfits[:, :, 1])**2, axis=1)[:, num.newaxis] num.nansum(sqr(ws*misfits[:, :, 1]), axis=1)[:, num.newaxis]
return gcms return gcms
...@@ -319,6 +350,7 @@ class CMTProblemConfig(core.ProblemConfig): ...@@ -319,6 +350,7 @@ class CMTProblemConfig(core.ProblemConfig):
problem = CMTProblem( problem = CMTProblem(
name=core.expand_template(self.name_template, subs), name=core.expand_template(self.name_template, subs),
apply_balancing_weights=self.apply_balancing_weights, apply_balancing_weights=self.apply_balancing_weights,
norm_exponent=self.norm_exponent,
base_source=base_source, base_source=base_source,
targets=targets, targets=targets,
ranges=self.ranges, ranges=self.ranges,
......
...@@ -124,6 +124,7 @@ class Problem(Object): ...@@ -124,6 +124,7 @@ class Problem(Object):
parameters = List.T(Parameter.T()) parameters = List.T(Parameter.T())
dependants = List.T(Parameter.T()) dependants = List.T(Parameter.T())
apply_balancing_weights = Bool.T(default=True) apply_balancing_weights = Bool.T(default=True)
norm_exponent = Int.T(default=2)
base_source = gf.Source.T(optional=True) base_source = gf.Source.T(optional=True)
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -254,6 +255,7 @@ class Problem(Object): ...@@ -254,6 +255,7 @@ class Problem(Object):
class ProblemConfig(Object): class ProblemConfig(Object):
name_template = String.T() name_template = String.T()
apply_balancing_weights = Bool.T(default=True) apply_balancing_weights = Bool.T(default=True)
norm_exponent = Int.T(default=2)
class Forbidden(Exception): class Forbidden(Exception):
...@@ -302,6 +304,10 @@ class InnerMisfitConfig(Object): ...@@ -302,6 +304,10 @@ class InnerMisfitConfig(Object):
help='Type of data characteristic to be fitted.\n\nAvailable choices ' help='Type of data characteristic to be fitted.\n\nAvailable choices '
'are: %s' % ', '.join("``'%s'``" % s 'are: %s' % ', '.join("``'%s'``" % s
for s in DomainChoice.choices)) for s in DomainChoice.choices))
norm_exponent = Int.T(
default=2,
help='Exponent to use in norm (1: L1-norm, 2: L2-norm)')
tautoshift_max = Float.T( tautoshift_max = Float.T(
default=0.0, default=0.0,
help='If non-zero, allow synthetic and observed traces to be shifted ' help='If non-zero, allow synthetic and observed traces to be shifted '
...@@ -505,7 +511,7 @@ class MisfitTarget(gf.Target): ...@@ -505,7 +511,7 @@ class MisfitTarget(gf.Target):
tmax_fit, tmax_fit,
tmax_fit + tfade_taper), tmax_fit + tfade_taper),
domain=config.domain, domain=config.domain,
exponent=2, exponent=config.norm_exponent,
flip=self.flip_norm, flip=self.flip_norm,
result_mode=self._result_mode, result_mode=self._result_mode,
tautoshift_max=config.tautoshift_max, tautoshift_max=config.tautoshift_max,
......
...@@ -682,7 +682,7 @@ def draw_contributions_figure(model, plt): ...@@ -682,7 +682,7 @@ def draw_contributions_figure(model, plt):
imodels = num.arange(model.nmodels) imodels = num.arange(model.nmodels)
gms = problem.global_misfits(model.misfits)**2 gms = problem.global_misfits(model.misfits)**problem.norm_exponent
isort = num.argsort(gms)[::-1] isort = num.argsort(gms)[::-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment