Commit 40bd715d authored by Sebastian Heimann's avatar Sebastian Heimann
Browse files

refactored misfit calculation, improved fits plots

parent 3891d012
......@@ -167,8 +167,6 @@ class CMTProblem(core.Problem):
def evaluate(self, x, return_traces=False):
source = self.unpack(x)
engine = self.get_engine()
for target in self.targets:
target.set_return_traces(return_traces)
resp = engine.process(source, self.targets)
data = []
......
......@@ -13,6 +13,7 @@ from pyrocko.guts import load, Object, String, Float, Int, Bool, List, \
StringChoice, Dict, Timestamp
from pyrocko import orthodrome as od, gf, trace, guts, util, weeding
from pyrocko import parimap, model, gui_util
from pyrocko.guts_array import Array
from grond import dataset
......@@ -21,6 +22,27 @@ logger = logging.getLogger('grond.core')
guts_prefix = 'grond'
def float_or_none(x):
if x is None:
return x
else:
return float(x)
class Trace(Object):
pass
class TraceSpectrum(Object):
network = String.T()
station = String.T()
location = String.T()
channel = String.T()
deltaf = Float.T(default=1.0)
fmin = Float.T(default=0.0)
ydata = Array.T(shape=(None,), dtype=num.complex, serialize_as='list')
def mahalanobis_distance(xs, mx, cov):
imask = num.diag(cov) != 0.
icov = num.linalg.inv(cov[imask, :][:, imask])
......@@ -201,6 +223,15 @@ class GrondError(Exception):
pass
class DomainChoice(StringChoice):
choices = [
'time_domain',
'frequency_domain',
'envelope',
'absolute',
'cc_max_norm']
class InnerMisfitConfig(Object):
fmin = Float.T()
fmax = Float.T()
......@@ -209,7 +240,7 @@ class InnerMisfitConfig(Object):
tmax = gf.Timing.T()
pick_synthetic_traveltime = gf.Timing.T(optional=True)
pick_phasename = String.T(optional=True)
domain = trace.DomainChoice.T(default='time_domain')
domain = DomainChoice.T(default='time_domain')
class TargetAnalysisResult(Object):
......@@ -220,17 +251,37 @@ class NoAnalysisResults(Exception):
pass
class MisfitResult(gf.Result):
misfit_value = Float.T()
misfit_norm = Float.T()
processed_obs = Trace.T(optional=True)
processed_syn = Trace.T(optional=True)
filtered_obs = Trace.T(optional=True)
filtered_syn = Trace.T(optional=True)
spectrum_obs = TraceSpectrum.T(optional=True)
spectrum_syn = TraceSpectrum.T(optional=True)
taper = trace.Taper.T(optional=True)
tobs_shift = Float.T(optional=True)
tsyn_pick = Timestamp.T(optional=True)
cc_shift = Float.T(optional=True)
cc = Trace.T(optional=True)
class MisfitTarget(gf.Target):
misfit_config = InnerMisfitConfig.T()
flip_norm = Bool.T(default=False)
manual_weight = Float.T(default=1.0)
analysis_result = TargetAnalysisResult.T(optional=True)
groupname = gf.StringID.T(optional=True)
super_group = gf.StringID.T()
group = gf.StringID.T()
def __init__(self, **kwargs):
gf.Target.__init__(self, **kwargs)
self._ds = None
self._return_traces = False
def string_id(self):
return '.'.join(x for x in (
self.super_group, self.group) + self.codes if x)
def get_plain_target(self):
d = dict(
......@@ -243,9 +294,6 @@ class MisfitTarget(gf.Target):
def set_dataset(self, ds):
self._ds = ds
def set_return_traces(self, return_traces):
self._return_traces = return_traces
def get_combined_weight(self, apply_balancing_weights):
w = self.manual_weight
if apply_balancing_weights:
......@@ -339,56 +387,147 @@ class MisfitTarget(gf.Target):
tr_obs = tr_obs.copy()
tr_obs.shift(-tobs_shift)
ms = trace.MisfitSetup(
norm=2,
domain=config.domain,
filter=trace.PoleZeroResponse(),
mr = misfit(
tr_obs, tr_syn,
taper=trace.CosTaper(
tmin_fit - tfade,
tmin_fit,
tmax_fit,
tmax_fit + tfade))
if not self.flip_norm:
mv, mn, tr_obs_proc, tr_syn_proc = tr_obs.misfit(
tr_syn, ms, nocache=True, debug=True)
else:
mv, mn, tr_syn_proc, tr_obs_proc = tr_syn.misfit(
tr_obs, ms, nocache=True, debug=True)
result = MisfitResult(
misfit_value=float(mv), misfit_norm=float(mn))
tmax_fit + tfade),
domain=config.domain,
exponent=2,
flip=self.flip_norm)
if self._return_traces:
result.filtered_obs = tr_obs
result.filtered_syn = tr_syn
result.processed_obs = tr_obs_proc
result.processed_syn = tr_syn_proc
result.taper = ms.taper
result.tobs_shift = float(tobs_shift)
result.tsyn_pick = float(tsyn)
mr.tobs_shift = float(tobs_shift)
mr.tsyn_pick = float_or_none(tsyn)
return result
return mr
except dataset.NotFound, e:
logger.debug(str(e))
raise gf.SeismosizerError('no waveform data, %s' % str(e))
class Trace(Object):
pass
def misfit(tr_obs, tr_syn, taper, domain, exponent, flip):
'''
Calculate misfit between observed and synthetic trace.
:param tr_obs: observed trace as :py:class:`pyrocko.trace.Trace`
:param tr_syn: synthetic trace as :py:class:`pyrocko.trace.Trace`
:param taper: taper applied in timedomain as
:py:class:`pyrocko.trace.Taper`
:param domain: how to calculate difference, see :py:class:`DomainChoice`
:param exponent: exponent of Lx type norms
:param flip: ``bool``, if set to ``True``, normalization factor is
computed against *tr_syn* rather than *tr_obs*
:returns: object of type :py:class:`MisfitResult`
'''
class MisfitResult(gf.Result):
misfit_value = Float.T()
misfit_norm = Float.T()
processed_obs = Trace.T(optional=True)
processed_syn = Trace.T(optional=True)
filtered_obs = Trace.T(optional=True)
filtered_syn = Trace.T(optional=True)
taper = trace.Taper.T(optional=True)
tobs_shift = Float.T(optional=True)
tsyn_pick = Timestamp.T(optional=True)
trace.assert_same_sampling_rate(tr_obs, tr_syn)
tmin, tmax = taper.time_span()
tr_proc_obs, trspec_proc_obs = _process(tr_obs, tmin, tmax, taper, domain)
tr_proc_syn, trspec_proc_syn = _process(tr_syn, tmin, tmax, taper, domain)
cc_shift = None
ctr = None
if domain in ('time_domain', 'envelope', 'absolute'):
a, b = tr_proc_syn.ydata, tr_proc_obs.ydata
if flip:
b, a = a, b
m, n = trace.Lx_norm(a, b, norm=exponent)
elif domain == 'cc_max_norm':
ctr = trace.correlate(
tr_proc_syn,
tr_proc_obs,
mode='same',
normalization='normal')
cc_shift, cc_max = ctr.max()
m = 0.5 - 0.5 * cc_max
n = 0.5
elif domain == 'frequency_domain':
a, b = trspec_proc_syn.ydata, trspec_proc_obs.ydata
if flip:
b, a = a, b
m, n = trace.Lx_norm(num.abs(a), num.abs(b), norm=exponent)
result = MisfitResult(
misfit_value=m,
misfit_norm=n,
processed_obs=tr_proc_obs,
processed_syn=tr_proc_syn,
filtered_obs=tr_obs,
filtered_syn=tr_syn,
spectrum_obs=trspec_proc_obs,
spectrum_syn=trspec_proc_syn,
taper=taper,
cc_shift=cc_shift,
cc=ctr)
return result
def _process(tr, tmin, tmax, taper, domain):
tr_proc = _extend_extract(tr, tmin, tmax)
tr_proc.taper(taper)
spectrum = None
df = None
if domain == 'envelope':
tr_proc = tr_proc.envelope(inplace=False)
elif domain == 'absolute':
tr_proc.set_ydata(num.abs(tr_proc.get_ydata()))
elif domain == 'frequency_domain':
ndata = tr_proc.ydata.size
nfft = trace.nextpow2(ndata)
padded = num.zeros(nfft, dtype=num.float)
padded[:ndata] = tr_proc.ydata
spectrum = num.fft.rfft(padded)
df = 1.0 / (tr_proc.deltat * nfft)
trspec_proc = TraceSpectrum(
network=tr.network,
station=tr.station,
location=tr.location,
channel=tr.channel,
deltaf=df,
fmin=0.0,
ydata=spectrum)
return tr_proc, trspec_proc
def _extend_extract(tr, tmin, tmax):
deltat = tr.deltat
itmin_frame = int(math.floor(tmin/deltat))
itmax_frame = int(math.ceil(tmax/deltat))
nframe = itmax_frame - itmin_frame
n = tr.data_len()
a = num.empty(nframe, dtype=num.float)
itmin_tr = int(round(tr.tmin / deltat))
itmax_tr = itmin_tr + n
icut1 = min(max(0, itmin_tr - itmin_frame), nframe)
icut2 = min(max(0, itmax_tr - itmin_frame), nframe)
icut1_tr = min(max(0, icut1 + itmin_frame - itmin_tr), n)
icut2_tr = min(max(0, icut2 + itmin_frame - itmin_tr), n)
a[:icut1] = tr.ydata[0]
a[icut1:icut2] = tr.ydata[icut1_tr:icut2_tr]
a[icut2:] = tr.ydata[-1]
tr = tr.copy(data=False)
tr.tmin = tmin
tr.set_ydata(a)
return tr
def xjoin(basepath, path):
......@@ -660,7 +799,8 @@ def weed(origin, targets, limit, neighborhood=3):
class TargetConfig(Object):
groupname = gf.StringID.T(optional=True)
super_group = gf.StringID.T(default='', optional=True)
group = gf.StringID.T(optional=True)
distance_min = Float.T(optional=True)
distance_max = Float.T(optional=True)
limit = Int.T(optional=True)
......@@ -670,7 +810,7 @@ class TargetConfig(Object):
store_id = gf.StringID.T()
weight = Float.T(default=1.0)
def get_targets(self, ds, event, default_groupname):
def get_targets(self, ds, event, default_group):
origin = event
......@@ -682,11 +822,13 @@ class TargetConfig(Object):
codes=st.nsl() + (cha,),
lat=st.lat,
lon=st.lon,
depth=st.depth,
interpolation=self.interpolation,
store_id=self.store_id,
misfit_config=self.inner_misfit_config,
manual_weight=self.weight,
groupname=self.groupname or default_groupname)
super_group=self.super_group,
group=self.group or default_group)
if self.distance_min is not None and \
target.distance_to(origin) < self.distance_min:
......@@ -899,6 +1041,18 @@ def analyse(problem, niter=1000, show_progress=False):
wtarget.weight = 1.0
wtargets.append(wtarget)
super_group_names = set()
groups = num.zeros(len(problem.targets), dtype=num.int)
ngroups = 0
for itarget, target in enumerate(problem.targets):
if target.super_group not in super_group_names:
super_group_names.add(target.super_group)
ngroups += 1
groups[itarget] = ngroups - 1
ngroups += 1
wproblem = problem.copy()
wproblem.targets = wtargets
......@@ -907,6 +1061,7 @@ def analyse(problem, niter=1000, show_progress=False):
mss = num.zeros((niter, problem.ntargets))
rstate = num.random.RandomState(123)
print groups
if show_progress:
pbar = util.progressbar('analysing problem', niter)
......@@ -936,8 +1091,12 @@ def analyse(problem, niter=1000, show_progress=False):
mean_ms = num.mean(mss, axis=0)
weights = 1.0 / mean_ms
weights /= (num.nansum(weights)/num.nansum(num.isfinite(weights)))
for igroup in xrange(ngroups):
weights[groups == igroup] /= (
num.nansum(weights[groups == igroup]) /
num.nansum(num.isfinite(weights[groups == igroup])))
for weight, target in zip(weights, problem.targets):
target.analysis_result = TargetAnalysisResult(
......@@ -1380,9 +1539,7 @@ def check(config, event_names=None):
fig = plt.figure()
axes = fig.add_subplot(1, 1, 1)
axes.set_ylim(0., 4.)
axes.set_title('%s %s' % (
'.'.join(x for x in target.codes if x),
target.groupname))
axes.set_title('%s' % target.string_id())
xdata = result.filtered_obs.get_xdata()
ydata = result.filtered_obs.get_ydata() / yabsmax
......
......@@ -378,7 +378,6 @@ class Dataset(object):
tr.deltat = deltat
resp = self.get_response(tr)
print resp
return tr.transfer(tfade=tfade, freqlimits=freqlimits,
transfer_function=resp, invert=True)
......@@ -458,6 +457,10 @@ class Dataset(object):
out_channels=('R', 'T', 'Z'),
backazimuth=backazimuth))
if not mios:
raise NotFound(
'cannot determine projection of data components')
try:
trs_projected = []
for matrix, in_channels, out_channels in mios:
......
......@@ -713,8 +713,7 @@ def draw_contributions_figure(model, plt):
poly_x, rel_poly_y,
alpha=0.5,
color=colors[ii % len(colors)],
label='%s.%s.%s.%s.%s (%.2g)' % (
target.codes + (target.groupname, num.mean(rel_ms[-1]),)))
label='%s (%.2g)' % (target.string_id, num.mean(rel_ms[-1])))
poly_y = num.concatenate(
[ms_smooth_sum[::-1], ms_smooth_sum + ms_smooth])
......@@ -814,17 +813,22 @@ def plot_taper(axes, t, taper, **kwargs):
axes.fill(t2, y2, **kwargs)
def plot_dtrace(axes, tr, **kwargs):
def plot_dtrace(axes, tr, space, mi, ma, **kwargs):
t = tr.get_xdata()
y = tr.get_ydata()
y2 = num.concatenate(((y*0.2), num.zeros(y.size))) - 1.0
y2 = (num.concatenate((y, num.zeros(y.size))) - mi) / \
(ma-mi) * space - (1.0 + space)
t2 = num.concatenate((t, t[::-1]))
return axes.fill(
axes.fill(
t2, y2,
clip_on=False,
**kwargs)
def plot_dtrace_vline(axes, t, space, **kwargs):
axes.plot([t, t], [-1.0 - space, -1.0], **kwargs)
def draw_fits_figures(ds, model, plt):
fontsize = 10
......@@ -858,7 +862,6 @@ def draw_fits_figures(ds, model, plt):
dtraces = []
for target, result in zip(problem.targets, results):
print target.misfit_config.domain
if result is None:
dtraces.append(None)
continue
......@@ -866,46 +869,72 @@ def draw_fits_figures(ds, model, plt):
itarget = target_index[target]
w = target.get_combined_weight(problem.apply_balancing_weights)
if target.misfit_config.domain != 'time_domain':
dtraces.append(None)
continue
if target.misfit_config.domain == 'cc_max_norm':
tref = (result.filtered_obs.tmin + result.filtered_obs.tmax) * 0.5
for tr_filt, tr_proc, tshift in (
(result.filtered_obs,
result.processed_obs,
0.),
(result.filtered_syn,
result.processed_syn,
result.cc_shift)):
norm = num.sum(num.abs(tr_proc.ydata)) / tr_proc.data_len()
tr_filt.ydata /= norm
tr_proc.ydata /= norm
tr_filt.shift(tshift)
tr_proc.shift(tshift)
ctr = result.cc
ctr.shift(tref)
dtrace = ctr
else:
for tr in (
result.filtered_obs,
result.filtered_syn,
result.processed_obs,
result.processed_syn):
for tr in (
result.filtered_obs,
result.filtered_syn,
result.processed_obs,
result.processed_syn):
tr.ydata *= w
tr.ydata *= w
dtrace = result.processed_syn.copy()
dtrace.set_ydata(
(
(result.processed_syn.get_ydata() -
result.processed_obs.get_ydata())**2))
target_to_result[target] = result
dtrace = result.processed_syn.copy()
dtrace.set_ydata(
(
(result.processed_syn.get_ydata() -
result.processed_obs.get_ydata())**2))
dtrace.meta = dict(super_group=target.super_group)
dtraces.append(dtrace)
result.processed_syn.meta = dict(super_group=target.super_group)
all_syn_trs.append(result.processed_syn)
if not all_syn_trs:
logger.warn('no traces to show')
return
amin, amax = trace.minmax(all_syn_trs, lambda tr: None)[None]
aminmaxs = trace.minmax(
all_syn_trs,
lambda tr: tr.meta['super_group'])
dmin, dmax = trace.minmax(
[x for x in dtraces if x is not None], lambda tr: None)[None]
dminmaxs = trace.minmax(
[x for x in dtraces if x is not None],
lambda tr: tr.meta['super_group'])
for tr in dtraces:
if tr:
tr.ydata /= dmax
absmax = max(abs(amin), abs(amax))
dmin, dmax = dminmaxs[tr.meta['super_group']]
tr.ydata /= max(abs(dmin), abs(dmax))
cg_to_targets = gather(
problem.targets, lambda t:
(t.codes[3], t.groupname), filter=lambda t: t in target_to_result)
problem.targets,
lambda t: (t.super_group, t.group, t.codes[3]),
filter=lambda t: t in target_to_result)
cgs = sorted(cg_to_targets.keys())
......@@ -989,11 +1018,9 @@ def draw_fits_figures(ds, model, plt):
fig = figures[iyy, ixx]
target = frame_to_target[iy, ix]
print target
print target.misfit_config.domain
if target.misfit_config.domain != 'time_domain':
continue
amin, amax = aminmaxs[target.super_group]
absmax = max(abs(amin), abs(amax))
ny_this = min(ny, nymax)
nx_this = min(nx, nxmax)
......@@ -1001,12 +1028,18 @@ def draw_fits_figures(ds, model, plt):
axes2 = fig.add_subplot(ny_this, nx_this, i_this)
space = 0.5
space_factor = 1.0 + space
axes2.set_axis_off()
axes2.set_ylim(-1.05, 1.05)
axes2.set_ylim(-1.05 * space_factor, 1.05)
axes = axes2.twinx()
axes.set_axis_off()
axes.set_ylim(-absmax*1.33, absmax*1.33)
if target.misfit_config.domain == 'cc_max_norm':
axes.set_ylim(-10. * space_factor, 10.)
else:
axes.set_ylim(-absmax*1.33 * space_factor, absmax*1.33)
itarget = target_index[target]
result = target_to_result[target]
......@@ -1018,7 +1051,7 @@ def draw_fits_figures(ds, model, plt):
tap_color_fill = (0.95, 0.95, 0.90)
plot_taper(
axes2, result.processed_syn.get_xdata(), result.taper,
axes2, result.processed_obs.get_xdata(), result.taper,
fc=tap_color_fill, ec=tap_color_edge)
obs_color = scolor('aluminium5')
......@@ -1030,10 +1063,25 @@ def draw_fits_figures(ds, model, plt):
misfit_color = scolor('scarletred2')
weight_color = scolor('chocolate2')
plot_dtrace(
axes2, dtrace,
fc=light(misfit_color, 0.5),
ec=misfit_color)
cc_color = scolor('aluminium5')
if target.misfit_config.domain == 'cc_max_norm':
tref = (result.filtered_obs.tmin +
result.filtered_obs.tmax) * 0.5
plot_dtrace(
axes2, dtrace, space, -1., 1.,
fc=light(cc_color, 0.5),
ec=cc_color)
plot_dtrace_vline(
axes2, tref, space, color=tap_color_annot)
else:
plot_dtrace(
axes2, dtrace, space, 0., 1.,
fc=light(misfit_color, 0.5),
ec=misfit_color)
plot_trace(
axes, result.filtered_syn,
......
Supports Markdown
0% or .