Commit 3dda57ff authored by Sebastian Heimann's avatar Sebastian Heimann
Browse files

restructure problem.evaluate

parent ccbb734c
...@@ -76,7 +76,7 @@ class Analyser(object): ...@@ -76,7 +76,7 @@ class Analyser(object):
isok_mask = num.logical_not(isbad_mask) isok_mask = num.logical_not(isbad_mask)
else: else:
isok_mask = None isok_mask = None
ms = wproblem.evaluate(x, mask=isok_mask)[:, 1] ms = wproblem.misfits(x, mask=isok_mask)[:, 1]
mss[iiter, :] = ms mss[iiter, :] = ms
isbad_mask = num.isnan(ms) isbad_mask = num.isnan(ms)
......
...@@ -250,7 +250,7 @@ def forward(rundir_or_config_path, event_names): ...@@ -250,7 +250,7 @@ def forward(rundir_or_config_path, event_names):
events = [] events = []
for (problem, x) in payload: for (problem, x) in payload:
ds.empty_cache() ds.empty_cache()
_, results = problem.evaluate(x, result_mode='full') results = problem.evaluate(x)
event = problem.get_source(x).pyrocko_event() event = problem.get_source(x).pyrocko_event()
events.append(event) events.append(event)
...@@ -376,7 +376,7 @@ def check( ...@@ -376,7 +376,7 @@ def check(
if n_random_synthetics == 0: if n_random_synthetics == 0:
x = problem.pack(problem.base_source) x = problem.pack(problem.base_source)
sources.append(problem.base_source) sources.append(problem.base_source)
_, results = problem.evaluate(x, result_mode='full') results = problem.evaluate(x)
results_list.append(results) results_list.append(results)
else: else:
...@@ -391,7 +391,7 @@ def check( ...@@ -391,7 +391,7 @@ def check(
pass pass
sources.append(problem.get_source(x)) sources.append(problem.get_source(x))
_, results = problem.evaluate(x, result_mode='full') results = problem.evaluate(x)
results_list.append(results) results_list.append(results)
if show_waveforms: if show_waveforms:
......
...@@ -435,7 +435,7 @@ class HighScoreOptimizer(Optimizer): ...@@ -435,7 +435,7 @@ class HighScoreOptimizer(Optimizer):
else: else:
isok_mask = None isok_mask = None
misfits = problem.evaluate(x, mask=isok_mask) misfits = problem.misfits(x, mask=isok_mask)
isbad_mask_new = num.isnan(misfits[:, 0]) isbad_mask_new = num.isnan(misfits[:, 0])
if isbad_mask is not None and num.any( if isbad_mask is not None and num.any(
......
...@@ -903,7 +903,7 @@ def draw_fits_figures_statics(ds, history, optimizer, plt): ...@@ -903,7 +903,7 @@ def draw_fits_figures_statics(ds, history, optimizer, plt):
source = problem.get_source(xbest) source = problem.get_source(xbest)
_, results = problem.evaluate(xbest, result_mode='full') results = problem.evaluate(xbest)
figures = [] figures = []
...@@ -1059,7 +1059,7 @@ def draw_fits_ensemble_figures( ...@@ -1059,7 +1059,7 @@ def draw_fits_ensemble_figures(
model = models[imodel, :] model = models[imodel, :]
source = problem.get_source(model) source = problem.get_source(model)
_, results = problem.evaluate(model, result_mode='full') results = problem.evaluate(model)
dtraces.append([]) dtraces.append([])
...@@ -1420,7 +1420,7 @@ def draw_fits_figures(ds, history, optimizer, plt): ...@@ -1420,7 +1420,7 @@ def draw_fits_figures(ds, history, optimizer, plt):
target_to_result = {} target_to_result = {}
all_syn_trs = [] all_syn_trs = []
all_syn_specs = [] all_syn_specs = []
_, results = problem.evaluate(xbest, result_mode='full') results = problem.evaluate(xbest)
dtraces = [] dtraces = []
for target, result in zip(problem.waveform_targets, results): for target, result in zip(problem.waveform_targets, results):
......
...@@ -11,8 +11,8 @@ from pyrocko import gf, util, guts ...@@ -11,8 +11,8 @@ from pyrocko import gf, util, guts
from pyrocko.guts import Object, String, Bool, List, Dict, Int from pyrocko.guts import Object, String, Bool, List, Dict, Int
from ..meta import ADict, Parameter, GrondError, xjoin from ..meta import ADict, Parameter, GrondError, xjoin
from ..targets import MisfitTarget, TargetGroup, WaveformMisfitTarget, \ from ..targets import MisfitResult, MisfitTarget, TargetGroup, \
SatelliteMisfitTarget WaveformMisfitTarget, SatelliteMisfitTarget
guts_prefix = 'grond' guts_prefix = 'grond'
...@@ -353,7 +353,7 @@ class Problem(Object): ...@@ -353,7 +353,7 @@ class Problem(Object):
return self._family_mask return self._family_mask
def evaluate(self, x, mask=None, result_mode='sparse'): def evaluate(self, x, mask=None, result_mode='full'):
source = self.get_source(x) source = self.get_source(x)
engine = self.get_engine() engine = self.get_engine()
...@@ -371,16 +371,12 @@ class Problem(Object): ...@@ -371,16 +371,12 @@ class Problem(Object):
modelling_results = list(resp.results_list[0]) modelling_results = list(resp.results_list[0])
imt = 0 imt = 0
imisfit = 0
misfits = num.zeros((self.nmisfits, 2))
misfits.fill(None)
results = [] results = []
for itarget, target in enumerate(self.targets): for itarget, target in enumerate(self.targets):
nmt_this = len(t2m_map[target]) nmt_this = len(t2m_map[target])
if mask is None or mask[itarget]: if mask is None or mask[itarget]:
misfits[imisfit:imisfit+target.nmisfits, :], result = \ result = target.finalize_modelling(
target.finalize_modelling( modelling_results[imt:imt+nmt_this])
modelling_results[imt:imt+nmt_this])
imt += nmt_this imt += nmt_this
else: else:
...@@ -388,12 +384,21 @@ class Problem(Object): ...@@ -388,12 +384,21 @@ class Problem(Object):
'target was excluded from modelling') 'target was excluded from modelling')
results.append(result) results.append(result)
return results
def misfits(self, x, mask=None):
results = self.evaluate(x, mask=mask, result_mode='sparse')
imisfit = 0
misfits = num.zeros((self.nmisfits, 2))
misfits.fill(None)
for target, result in zip(self.targets, results):
if isinstance(result, MisfitResult):
misfits[imisfit:imisfit+target.nmisfits, :] = result.misfits
imisfit += target.nmisfits imisfit += target.nmisfits
if result_mode == 'full': return misfits
return misfits, results
else:
return misfits
class InvalidRundir(Exception): class InvalidRundir(Exception):
......
...@@ -3,6 +3,7 @@ import copy ...@@ -3,6 +3,7 @@ import copy
import numpy as num import numpy as num
from pyrocko import gf from pyrocko import gf
from pyrocko.guts_array import Array
from pyrocko.guts import Object, Float from pyrocko.guts import Object, Float
...@@ -29,7 +30,9 @@ class TargetAnalysisResult(Object): ...@@ -29,7 +30,9 @@ class TargetAnalysisResult(Object):
class MisfitResult(Object): class MisfitResult(Object):
pass misfits = Array.T(
shape=(None, 2),
dtype=num.float)
class MisfitTarget(Object): class MisfitTarget(Object):
...@@ -100,7 +103,7 @@ class MisfitTarget(Object): ...@@ -100,7 +103,7 @@ class MisfitTarget(Object):
def init_modelling(self): def init_modelling(self):
return [] return []
def finalize_modelling(self, results): def finalize_modelling(self, modelling_results):
raise NotImplemented('must be overloaded in subclass') raise NotImplemented('must be overloaded in subclass')
......
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
import numpy as num import numpy as num
from pyrocko import gf from pyrocko import gf
from pyrocko.guts import String, Bool, Dict, List, Object, Float from pyrocko.guts import String, Bool, Dict, List, Object
from grond.meta import Parameter from grond.meta import Parameter
...@@ -67,8 +67,6 @@ class SatelliteTargetGroup(TargetGroup): ...@@ -67,8 +67,6 @@ class SatelliteTargetGroup(TargetGroup):
class SatelliteMisfitResult(gf.Result, MisfitResult): class SatelliteMisfitResult(gf.Result, MisfitResult):
misfit_value = Float.T()
misfit_norm = Float.T()
statics_syn = Dict.T(optional=True) statics_syn = Dict.T(optional=True)
statics_obs = Dict.T(optional=True) statics_obs = Dict.T(optional=True)
...@@ -101,8 +99,7 @@ class SatelliteMisfitTarget(gf.SatelliteTarget, MisfitTarget): ...@@ -101,8 +99,7 @@ class SatelliteMisfitTarget(gf.SatelliteTarget, MisfitTarget):
self._target_ranges.pop(k) self._target_ranges.pop(k)
return self._target_ranges return self._target_ranges
@property def string_id(self):
def id(self):
return self.scene_id return self.scene_id
def set_dataset(self, ds): def set_dataset(self, ds):
...@@ -135,8 +132,7 @@ class SatelliteMisfitTarget(gf.SatelliteTarget, MisfitTarget): ...@@ -135,8 +132,7 @@ class SatelliteMisfitTarget(gf.SatelliteTarget, MisfitTarget):
num.sum((stat_obs * scene.covariance.weight_vector)**2)) num.sum((stat_obs * scene.covariance.weight_vector)**2))
result = SatelliteMisfitResult( result = SatelliteMisfitResult(
misfit_value=misfit_value, misfits=num.array([[misfit_value, misfit_norm]], dtype=num.float))
misfit_norm=misfit_norm)
if self._result_mode == 'full': if self._result_mode == 'full':
result.statics_syn = statics result.statics_syn = statics
......
...@@ -149,29 +149,10 @@ class WaveformTargetGroup(TargetGroup): ...@@ -149,29 +149,10 @@ class WaveformTargetGroup(TargetGroup):
targets.append(target) targets.append(target)
if self.limit: if self.limit:
return self.weed(origin, targets, self.limit)[0] return weed(origin, targets, self.limit)[0]
else: else:
return targets return targets
@staticmethod
def weed(origin, targets, limit, neighborhood=3):
azimuths = num.zeros(len(targets))
dists = num.zeros(len(targets))
for i, target in enumerate(targets):
_, azimuths[i] = target.azibazi_to(origin)
dists[i] = target.distance_to(origin)
badnesses = num.ones(len(targets), dtype=float)
deleted, meandists_kept = weeding.weed(
azimuths, dists, badnesses,
nwanted=limit,
neighborhood=neighborhood)
targets_weeded = [
target for (delete, target) in zip(deleted, targets) if not delete]
return targets_weeded, meandists_kept, deleted
class TraceSpectrum(Object): class TraceSpectrum(Object):
network = String.T() network = String.T()
...@@ -190,8 +171,6 @@ class TraceSpectrum(Object): ...@@ -190,8 +171,6 @@ class TraceSpectrum(Object):
class WaveformMisfitResult(gf.Result, MisfitResult): class WaveformMisfitResult(gf.Result, MisfitResult):
misfit_value = Float.T()
misfit_norm = Float.T()
processed_obs = Trace.T(optional=True) processed_obs = Trace.T(optional=True)
processed_syn = Trace.T(optional=True) processed_syn = Trace.T(optional=True)
filtered_obs = Trace.T(optional=True) filtered_obs = Trace.T(optional=True)
...@@ -217,15 +196,6 @@ class WaveformMisfitTarget(gf.Target, MisfitTarget): ...@@ -217,15 +196,6 @@ class WaveformMisfitTarget(gf.Target, MisfitTarget):
def string_id(self): def string_id(self):
return '.'.join(x for x in (self.path,) + self.codes if x) return '.'.join(x for x in (self.path,) + self.codes if x)
@property
def id(self):
return '.'.join(self.codes)
def get_plain_modelling_targets(self):
d = dict(
(k, getattr(self, k)) for k in gf.Target.T.propnames)
return [gf.Target(**d)]
def get_combined_weight(self, apply_balancing_weights): def get_combined_weight(self, apply_balancing_weights):
w = self.manual_weight w = self.manual_weight
if apply_balancing_weights: if apply_balancing_weights:
...@@ -369,16 +339,13 @@ class WaveformMisfitTarget(gf.Target, MisfitTarget): ...@@ -369,16 +339,13 @@ class WaveformMisfitTarget(gf.Target, MisfitTarget):
def prepare_modelling(self): def prepare_modelling(self):
return [self] return [self]
def finalize_modelling(self, results): def finalize_modelling(self, modelling_results):
result = results[0] return modelling_results[0]
if isinstance(result, gf.SeismosizerError):
misfits = num.array(
[[None, None]], dtype=num.float)
else: else:
misfits = num.array( return targets
[[result.misfit_value, result.misfit_norm]], dtype=num.float)
return misfits, result
def misfit( def misfit(
...@@ -472,8 +439,7 @@ tautoshift**2 / tautoshift_max**2`` ...@@ -472,8 +439,7 @@ tautoshift**2 / tautoshift_max**2``
if result_mode == 'full': if result_mode == 'full':
result = WaveformMisfitResult( result = WaveformMisfitResult(
misfit_value=float(m), misfits=num.array([[m, n]], dtype=num.float),
misfit_norm=float(n),
processed_obs=tr_proc_obs, processed_obs=tr_proc_obs,
processed_syn=tr_proc_syn, processed_syn=tr_proc_syn,
filtered_obs=tr_obs.copy(), filtered_obs=tr_obs.copy(),
...@@ -486,8 +452,7 @@ tautoshift**2 / tautoshift_max**2`` ...@@ -486,8 +452,7 @@ tautoshift**2 / tautoshift_max**2``
elif result_mode == 'sparse': elif result_mode == 'sparse':
result = WaveformMisfitResult( result = WaveformMisfitResult(
misfit_value=m, misfits=num.array([[m, n]], dtype=num.float))
misfit_norm=n)
else: else:
assert False assert False
...@@ -568,6 +533,25 @@ def float_or_none(x): ...@@ -568,6 +533,25 @@ def float_or_none(x):
return float(x) return float(x)
def weed(origin, targets, limit, neighborhood=3):
azimuths = num.zeros(len(targets))
dists = num.zeros(len(targets))
for i, target in enumerate(targets):
_, azimuths[i] = target.azibazi_to(origin)
dists[i] = target.distance_to(origin)
badnesses = num.ones(len(targets), dtype=float)
deleted, meandists_kept = weeding.weed(
azimuths, dists, badnesses,
nwanted=limit,
neighborhood=neighborhood)
targets_weeded = [
target for (delete, target) in zip(deleted, targets) if not delete]
return targets_weeded, meandists_kept, deleted
__all__ = ''' __all__ = '''
WaveformTargetGroup WaveformTargetGroup
WaveformMisfitConfig WaveformMisfitConfig
......
...@@ -93,7 +93,7 @@ class ToyProblem(Problem): ...@@ -93,7 +93,7 @@ class ToyProblem(Problem):
[t.obs_distance for t in self.targets], [t.obs_distance for t in self.targets],
dtype=num.float) dtype=num.float)
def evaluate(self, x, mask=None): def misfits(self, x, mask=None):
self._setup_modelling() self._setup_modelling()
distances = num.sqrt( distances = num.sqrt(
num.sum((x[num.newaxis, :]-self._xtargets)**2, axis=1)) num.sum((x[num.newaxis, :]-self._xtargets)**2, axis=1))
...@@ -104,7 +104,7 @@ class ToyProblem(Problem): ...@@ -104,7 +104,7 @@ class ToyProblem(Problem):
* num.mean(num.abs(self._obs_distances)) * num.mean(num.abs(self._obs_distances))
return misfits return misfits
def evaluate_many(self, xs): def misfits_many(self, xs):
self._setup_modelling() self._setup_modelling()
distances = num.sqrt( distances = num.sqrt(
num.sum( num.sum(
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment