Commit 5d1cbdbd authored by Sebastian Heimann
Browse files

restructuring for better handling of event datasets and synthetic tests

parent cb04b480
......@@ -34,9 +34,11 @@ subcommand_descriptions = {
subcommand_usages = {
'init': 'init [options]',
'check': 'check <configfile> [options]',
'go': 'go <configfile> [options]',
'forward': 'forward <rundir> [options]',
'check': 'check <configfile> <eventnames> ... [options]',
'go': 'go <configfile> <eventnames> ... [options]',
'forward': (
'forward <rundir> [options]',
'forward <configfile> <eventnames> ... [options]'),
'harvest': 'harvest <rundir> [options]',
'plot': 'plot <plotnames> <rundir> [options]',
'export': 'export (best|mean|ensemble|stats) <rundirs> ... [options]',
......@@ -173,19 +175,15 @@ def command_init(args):
def command_check(args):
def setup(parser):
parser.add_option(
'--event', dest='event_name', metavar='EVENTNAME',
help='process only event EVENTNAME')
pass
parser, options, args = cl_parse('check', args, setup)
if len(args) != 1:
help_and_die(parser, 'no config file given')
event_names = None
if options.event_name:
event_names = [options.event_name]
if len(args) < 2:
help_and_die(parser, 'missing arguments')
config_path = args[0]
event_names = args[1:]
config = grond.read_config(config_path)
grond.check(
......@@ -195,9 +193,6 @@ def command_check(args):
def command_go(args):
def setup(parser):
parser.add_option(
'--event', dest='event_name', metavar='EVENTNAME',
help='process only event EVENTNAME')
parser.add_option(
'--force', dest='force', action='store_true',
help='overwrite existing run directory')
......@@ -209,20 +204,18 @@ def command_go(args):
help='set number of events to process in parallel')
parser, options, args = cl_parse('go', args, setup)
if len(args) != 1:
help_and_die(parser, 'no config file given')
if len(args) < 2:
help_and_die(parser, 'missing arguments')
config_path = args[0]
event_names = args[1:]
config = grond.read_config(config_path)
if options.status == 'quiet':
status = ()
else:
status = tuple(options.status.split(','))
event_names = None
if options.event_name:
event_names = [options.event_name]
grond.go(
config,
event_names=event_names,
......@@ -233,19 +226,15 @@ def command_go(args):
def command_forward(args):
def setup(parser):
parser.add_option(
'--event', dest='event_name', metavar='EVENTNAME',
help='process only event EVENTNAME')
pass
parser, options, args = cl_parse('forward', args, setup)
if len(args) != 1:
help_and_die(parser, 'incorrect number of arguments')
if len(args) < 1:
help_and_die(parser, 'missing arguments')
event_names = None
if options.event_name:
event_names = [options.event_name]
event_names = args[1:]
run_path, = args
run_path = args[0]
grond.forward(
run_path,
event_names=event_names)
......
......@@ -37,7 +37,6 @@ class CMTProblem(core.Problem):
core.Parameter('rel_moment_iso', label='$M_{0}^{ISO}/M_{0}$'),
core.Parameter('rel_moment_clvd', label='$M_{0}^{CLVD}/M_{0}$')]
base_source = gf.Source.T()
targets = List.T(core.MisfitTarget.T())
ranges = Dict.T(String.T(), gf.Range.T())
......@@ -303,7 +302,7 @@ class CMTProblemConfig(core.ProblemConfig):
event_time=util.time_to_str(event.time))
problem = CMTProblem(
name=self.name_template % subs,
name=core.substitute_template(self.name_template, subs),
apply_balancing_weights=self.apply_balancing_weights,
base_source=base_source,
targets=targets,
......
......@@ -6,6 +6,7 @@ import time
import copy
import shutil
import os.path as op
from string import Template
import numpy as num
......@@ -122,6 +123,7 @@ class Problem(Object):
parameters = List.T(Parameter.T())
dependants = List.T(Parameter.T())
apply_balancing_weights = Bool.T(default=True)
base_source = gf.Source.T()
def __init__(self, **kwargs):
Object.__init__(self, **kwargs)
......@@ -666,45 +668,27 @@ class SyntheticWaveformNotAvailable(Exception):
class SyntheticTest(Object):
random_seed = Int.T(default=0)
inject_solution = Bool.T(default=False)
ignore_data_availability = Bool.T(default=False)
respect_data_availability = Bool.T(default=False)
add_real_noise = Bool.T(default=False)
toffset_real_noise = Float.T(default=-3600.)
x = Dict.T(String.T(), Float.T())
def __init__(self, **kwargs):
Object.__init__(self, **kwargs)
self._problem = None
self._synthetics = None
self._rstate = num.random.RandomState(self.random_seed)
def set_config(self, config):
self._config = config
def set_problem(self, problem):
self._problem = problem
self._synthetics = None
def get_problem(self):
ds = self._config.get_dataset()
events = ds.get_events()
event = events[0]
return self._config.get_problem(event)
def get_x_random(self):
problem = self.get_problem()
xbounds = num.array(problem.bounds(), dtype=num.float)
npar = xbounds.shape[0]
x = num.zeros(npar, dtype=num.float)
while True:
for i in xrange(npar):
x[i] = self._rstate.uniform(xbounds[i, 0], xbounds[i, 1])
try:
x = problem.preconstrain(x)
break
except Forbidden:
pass
if self._problem is None:
raise SyntheticWaveformNotAvailable(
'SyntheticTest.set_problem() has not been called yet')
return x
return self._problem
def get_x(self):
problem = self.get_problem()
......@@ -713,22 +697,16 @@ class SyntheticTest(Object):
problem.parameter_array(self.x))
else:
print problem.base_source
x = problem.preconstrain(
problem.pack(
problem.base_source))
print x
print problem.unpack(x)
return x
def get_synthetics(self):
problem = self.get_problem()
if self._synthetics is None:
problem = self.get_problem()
x = self.get_x()
results = problem.forward(x)
self._synthetics = results
......@@ -775,12 +753,12 @@ class DatasetConfig(HasPaths):
def __init__(self, *args, **kwargs):
HasPaths.__init__(self, *args, **kwargs)
self._ds = None
self._ds = {}
def get_dataset(self):
if self._ds is None:
def get_dataset(self, event_name):
if event_name not in self._ds:
fp = self.expand_path
ds = dataset.Dataset()
ds = dataset.Dataset(event_name)
ds.add_stations(
pyrocko_stations_filename=fp(self.stations_path),
stationxml_filenames=fp(self.stations_stationxml_paths))
......@@ -813,10 +791,10 @@ class DatasetConfig(HasPaths):
if self.whitelist:
ds.add_whitelist(self.whitelist)
ds.set_synthetic_test(self.synthetic_test)
self._ds = ds
ds.set_synthetic_test(copy.deepcopy(self.synthetic_test))
self._ds[event_name] = ds
return self._ds
return self._ds[event_name]
def weed(origin, targets, limit, neighborhood=3):
......@@ -961,15 +939,11 @@ class Config(HasPaths):
def __init__(self, *args, **kwargs):
HasPaths.__init__(self, *args, **kwargs)
def get_dataset(self):
ds = self.dataset_config.get_dataset()
if ds.synthetic_test:
ds.synthetic_test.set_config(self)
return ds
def get_dataset(self, event_name):
return self.dataset_config.get_dataset(event_name)
def get_targets(self, event):
ds = self.get_dataset()
ds = self.get_dataset(event.name)
targets = []
for igroup, target_config in enumerate(self.target_configs):
......@@ -978,10 +952,18 @@ class Config(HasPaths):
return targets
def setup_modelling_environment(self, problem):
    """Prepare *problem* for forward modelling.

    The engine obtained from ``self.engine_config`` is attached to the
    problem. If the dataset selected by ``problem.base_source.name`` has a
    synthetic test, the problem is registered with that test and the
    problem's ``base_source`` is replaced by the source unpacked from the
    test's model vector (``get_x()``).

    :param problem: problem object to be configured in place
    """
    engine = self.engine_config.get_engine()
    problem.set_engine(engine)

    dataset_for_event = self.get_dataset(problem.base_source.name)
    synthetic_test = dataset_for_event.synthetic_test
    if not synthetic_test:
        # Real-data case: nothing else to set up.
        return

    synthetic_test.set_problem(problem)
    x_synthetic = synthetic_test.get_x()
    problem.base_source = problem.unpack(x_synthetic)
def get_problem(self, event):
targets = self.get_targets(event)
problem = self.problem_config.get_problem(event, targets)
problem.set_engine(self.engine_config.get_engine())
self.setup_modelling_environment(problem)
return problem
......@@ -1449,7 +1431,7 @@ def forward(rundir_or_config_path, event_names=None):
ibest = num.argmin(gms)
xbest = xs[ibest, :]
ds = config.get_dataset()
ds = config.get_dataset(problem.base_source.name)
problem.set_engine(config.engine_config.get_engine())
for target in problem.targets:
......@@ -1459,11 +1441,11 @@ def forward(rundir_or_config_path, event_names=None):
else:
config = read_config(rundir_or_config_path)
ds = config.get_dataset()
events = ds.get_events(event_names=event_names)
payload = []
for event in events:
for event_name in event_names:
ds = config.get_dataset(event_name)
event = ds.get_event()
problem = config.get_problem(event)
xref = problem.preconstrain(
problem.pack(problem.base_source))
......@@ -1556,17 +1538,12 @@ g_state = {}
def check(config, event_names=None):
ds = config.get_dataset()
events = ds.get_events(event_names=event_names)
nevents = len(events)
from matplotlib import pyplot as plt
from grond.plot import colors
if nevents == 0:
raise GrondError('no events found')
for ievent, event in enumerate(events):
for ievent, event_name in enumerate(event_names):
ds = config.get_dataset(event_name)
event = ds.get_event()
try:
problem = config.get_problem(event)
check_problem(problem)
......@@ -1651,18 +1628,12 @@ def go(config, event_names=None, force=False, nparallel=1, status=('state',)):
status = tuple(status)
ds = config.get_dataset()
events = ds.get_events(event_names=event_names)
nevents = len(events)
if nevents == 0:
raise GrondError('no events found')
g_data = (config, force, status, nparallel, event_names)
g_state[id(g_data)] = g_data
nevents = len(event_names)
for x in parimap.parimap(
process_event,
xrange(nevents),
......@@ -1672,6 +1643,17 @@ def go(config, event_names=None, force=False, nparallel=1, status=('state',)):
pass
def substitute_template(template, d):
    """Fill the ``$name`` / ``${name}`` placeholders of *template* from *d*.

    Substitution is done with :py:class:`string.Template` in strict mode,
    so every placeholder must have an entry in *d*.

    :param template: template string containing ``$``-style placeholders
    :param d: mapping from placeholder names to replacement values
    :returns: the template with all placeholders substituted
    :raises GrondError: if a placeholder has no entry in *d* or if the
        template contains a malformed placeholder
    """
    try:
        return Template(template).substitute(d)
    except KeyError as e:
        # str() of a KeyError is the repr of the missing key, which would
        # put a second pair of quotes into the message; report the bare
        # placeholder name instead.
        raise GrondError(
            'invalid placeholder "%s" in template: "%s"' % (
                e.args[0], template))
    except ValueError:
        raise GrondError(
            'malformed placeholder in template: "%s"' % template)
def process_event(ievent, g_data_id):
config, force, status, nparallel, event_names = g_state[g_data_id]
......@@ -1679,28 +1661,27 @@ def process_event(ievent, g_data_id):
if nparallel > 1:
status = ()
ds = config.get_dataset()
events = ds.get_events(event_names=event_names)
nevents = len(events)
event_name = event_names[ievent]
event = events[ievent]
ds = config.get_dataset(event_name)
ds.empty_cache()
nevents = len(event_names)
tstart = time.time()
event = ds.get_event()
problem = config.get_problem(event)
# FIXME
synt = ds.synthetic_test
if synt and synt.inject_solution:
if synt:
problem.base_source = problem.unpack(synt.get_x())
check_problem(problem)
rundir = config.rundir_template % dict(
problem_name=problem.name)
rundir = substitute_template(
config.rundir_template,
dict(problem_name=problem.name))
if op.exists(rundir):
if force:
......
......@@ -53,7 +53,7 @@ def dump_station_corrections(station_corrections, filename):
class Dataset(object):
def __init__(self):
def __init__(self, event_name=None):
self.events = []
self.pile = pile.Pile()
self.stations = {}
......@@ -72,6 +72,7 @@ class Dataset(object):
self.synthetic_test = None
self._picks = None
self._cache = {}
self._event_name = event_name
def empty_cache(self):
self._cache = {}
......@@ -493,15 +494,16 @@ class Dataset(object):
syn_test = self.synthetic_test
toffset_noise_extract = 0.0
if syn_test:
if syn_test.ignore_data_availability:
if not syn_test.respect_data_availability:
if syn_test.add_real_noise:
raise DatasetError(
'ignore_data_availability=True and '
'respect_data_availability=False and '
'add_real_noise=True cannot be combined.')
tr = syn_test.get_waveform(
nslc, tmin, tmax,
tfade=tfade, freqlimits=freqlimits)
tfade=tfade,
freqlimits=freqlimits)
if cache is not None:
cache[tr.nslc_id, tmin, tmax] = tr
......@@ -595,7 +597,7 @@ class Dataset(object):
return evs
def get_event(self, t, magmin=None):
def get_event_by_time(self, t, magmin=None):
evs = self.get_events(magmin=magmin)
ev_x = None
for ev in evs:
......@@ -609,6 +611,16 @@ class Dataset(object):
return ev_x
def get_event(self):
    """Return the main event of this dataset.

    The main event is the first entry of ``self.events`` whose ``name``
    equals the event name stored in ``self._event_name``.

    :raises NotFound: if no event name is stored, or if none of the
        events carries that name
    """
    wanted = self._event_name
    if wanted is None:
        raise NotFound('no main event selected in dataset')

    # Sentinel distinguishes "nothing matched" from any real event object.
    nothing = object()
    match = next((ev for ev in self.events if ev.name == wanted), nothing)
    if match is nothing:
        raise NotFound('no such event: %s' % wanted)

    return match
def get_picks(self):
if self._picks is None:
hash_to_name = {}
......
......@@ -1458,8 +1458,9 @@ def plot_result(dirname, plotnames_want,
if plotname in plotnames_want:
config = guts.load(filename=op.join(dirname, 'config.yaml'))
config.set_basepath(dirname)
problem.set_engine(config.engine_config.get_engine())
ds = config.get_dataset()
config.setup_modelling_environment(problem)
event_name = problem.base_source.name
ds = config.get_dataset(event_name)
figs = plot_dispatch[plotname](ds, model, plt)
if save:
fns.extend(
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment