Commit 5d1cbdbd authored by Sebastian Heimann
Browse files

restructuring for better handling of event datasets and synthetic tests

parent cb04b480
...@@ -34,9 +34,11 @@ subcommand_descriptions = { ...@@ -34,9 +34,11 @@ subcommand_descriptions = {
subcommand_usages = { subcommand_usages = {
'init': 'init [options]', 'init': 'init [options]',
'check': 'check <configfile> [options]', 'check': 'check <configfile> <eventnames> ... [options]',
'go': 'go <configfile> [options]', 'go': 'go <configfile> <eventnames> ... [options]',
'forward': 'forward <rundir> [options]', 'forward': (
'forward <rundir> [options]',
'forward <configfile> <eventnames> ... [options]'),
'harvest': 'harvest <rundir> [options]', 'harvest': 'harvest <rundir> [options]',
'plot': 'plot <plotnames> <rundir> [options]', 'plot': 'plot <plotnames> <rundir> [options]',
'export': 'export (best|mean|ensemble|stats) <rundirs> ... [options]', 'export': 'export (best|mean|ensemble|stats) <rundirs> ... [options]',
...@@ -173,19 +175,15 @@ def command_init(args): ...@@ -173,19 +175,15 @@ def command_init(args):
def command_check(args): def command_check(args):
def setup(parser): def setup(parser):
parser.add_option( pass
'--event', dest='event_name', metavar='EVENTNAME',
help='process only event EVENTNAME')
parser, options, args = cl_parse('check', args, setup) parser, options, args = cl_parse('check', args, setup)
if len(args) != 1: if len(args) < 2:
help_and_die(parser, 'no config file given') help_and_die(parser, 'missing arguments')
event_names = None
if options.event_name:
event_names = [options.event_name]
config_path = args[0] config_path = args[0]
event_names = args[1:]
config = grond.read_config(config_path) config = grond.read_config(config_path)
grond.check( grond.check(
...@@ -195,9 +193,6 @@ def command_check(args): ...@@ -195,9 +193,6 @@ def command_check(args):
def command_go(args): def command_go(args):
def setup(parser): def setup(parser):
parser.add_option(
'--event', dest='event_name', metavar='EVENTNAME',
help='process only event EVENTNAME')
parser.add_option( parser.add_option(
'--force', dest='force', action='store_true', '--force', dest='force', action='store_true',
help='overwrite existing run directory') help='overwrite existing run directory')
...@@ -209,20 +204,18 @@ def command_go(args): ...@@ -209,20 +204,18 @@ def command_go(args):
help='set number of events to process in parallel') help='set number of events to process in parallel')
parser, options, args = cl_parse('go', args, setup) parser, options, args = cl_parse('go', args, setup)
if len(args) != 1: if len(args) < 2:
help_and_die(parser, 'no config file given') help_and_die(parser, 'missing arguments')
config_path = args[0] config_path = args[0]
event_names = args[1:]
config = grond.read_config(config_path) config = grond.read_config(config_path)
if options.status == 'quiet': if options.status == 'quiet':
status = () status = ()
else: else:
status = tuple(options.status.split(',')) status = tuple(options.status.split(','))
event_names = None
if options.event_name:
event_names = [options.event_name]
grond.go( grond.go(
config, config,
event_names=event_names, event_names=event_names,
...@@ -233,19 +226,15 @@ def command_go(args): ...@@ -233,19 +226,15 @@ def command_go(args):
def command_forward(args): def command_forward(args):
def setup(parser): def setup(parser):
parser.add_option( pass
'--event', dest='event_name', metavar='EVENTNAME',
help='process only event EVENTNAME')
parser, options, args = cl_parse('forward', args, setup) parser, options, args = cl_parse('forward', args, setup)
if len(args) != 1: if len(args) < 1:
help_and_die(parser, 'incorrect number of arguments') help_and_die(parser, 'missing arguments')
event_names = None event_names = args[1:]
if options.event_name:
event_names = [options.event_name]
run_path, = args run_path = args[0]
grond.forward( grond.forward(
run_path, run_path,
event_names=event_names) event_names=event_names)
......
...@@ -37,7 +37,6 @@ class CMTProblem(core.Problem): ...@@ -37,7 +37,6 @@ class CMTProblem(core.Problem):
core.Parameter('rel_moment_iso', label='$M_{0}^{ISO}/M_{0}$'), core.Parameter('rel_moment_iso', label='$M_{0}^{ISO}/M_{0}$'),
core.Parameter('rel_moment_clvd', label='$M_{0}^{CLVD}/M_{0}$')] core.Parameter('rel_moment_clvd', label='$M_{0}^{CLVD}/M_{0}$')]
base_source = gf.Source.T()
targets = List.T(core.MisfitTarget.T()) targets = List.T(core.MisfitTarget.T())
ranges = Dict.T(String.T(), gf.Range.T()) ranges = Dict.T(String.T(), gf.Range.T())
...@@ -303,7 +302,7 @@ class CMTProblemConfig(core.ProblemConfig): ...@@ -303,7 +302,7 @@ class CMTProblemConfig(core.ProblemConfig):
event_time=util.time_to_str(event.time)) event_time=util.time_to_str(event.time))
problem = CMTProblem( problem = CMTProblem(
name=self.name_template % subs, name=core.substitute_template(self.name_template, subs),
apply_balancing_weights=self.apply_balancing_weights, apply_balancing_weights=self.apply_balancing_weights,
base_source=base_source, base_source=base_source,
targets=targets, targets=targets,
......
...@@ -6,6 +6,7 @@ import time ...@@ -6,6 +6,7 @@ import time
import copy import copy
import shutil import shutil
import os.path as op import os.path as op
from string import Template
import numpy as num import numpy as num
...@@ -122,6 +123,7 @@ class Problem(Object): ...@@ -122,6 +123,7 @@ class Problem(Object):
parameters = List.T(Parameter.T()) parameters = List.T(Parameter.T())
dependants = List.T(Parameter.T()) dependants = List.T(Parameter.T())
apply_balancing_weights = Bool.T(default=True) apply_balancing_weights = Bool.T(default=True)
base_source = gf.Source.T()
def __init__(self, **kwargs): def __init__(self, **kwargs):
Object.__init__(self, **kwargs) Object.__init__(self, **kwargs)
...@@ -666,45 +668,27 @@ class SyntheticWaveformNotAvailable(Exception): ...@@ -666,45 +668,27 @@ class SyntheticWaveformNotAvailable(Exception):
class SyntheticTest(Object): class SyntheticTest(Object):
random_seed = Int.T(default=0)
inject_solution = Bool.T(default=False) inject_solution = Bool.T(default=False)
ignore_data_availability = Bool.T(default=False) respect_data_availability = Bool.T(default=False)
add_real_noise = Bool.T(default=False) add_real_noise = Bool.T(default=False)
toffset_real_noise = Float.T(default=-3600.) toffset_real_noise = Float.T(default=-3600.)
x = Dict.T(String.T(), Float.T()) x = Dict.T(String.T(), Float.T())
def __init__(self, **kwargs): def __init__(self, **kwargs):
Object.__init__(self, **kwargs) Object.__init__(self, **kwargs)
self._problem = None
self._synthetics = None self._synthetics = None
self._rstate = num.random.RandomState(self.random_seed)
def set_config(self, config): def set_problem(self, problem):
self._config = config self._problem = problem
self._synthetics = None
def get_problem(self): def get_problem(self):
ds = self._config.get_dataset() if self._problem is None:
events = ds.get_events() raise SyntheticWaveformNotAvailable(
event = events[0] 'SyntheticTest.set_problem() has not been called yet')
return self._config.get_problem(event)
def get_x_random(self):
problem = self.get_problem()
xbounds = num.array(problem.bounds(), dtype=num.float)
npar = xbounds.shape[0]
x = num.zeros(npar, dtype=num.float)
while True:
for i in xrange(npar):
x[i] = self._rstate.uniform(xbounds[i, 0], xbounds[i, 1])
try:
x = problem.preconstrain(x)
break
except Forbidden:
pass
return x return self._problem
def get_x(self): def get_x(self):
problem = self.get_problem() problem = self.get_problem()
...@@ -713,22 +697,16 @@ class SyntheticTest(Object): ...@@ -713,22 +697,16 @@ class SyntheticTest(Object):
problem.parameter_array(self.x)) problem.parameter_array(self.x))
else: else:
print problem.base_source
x = problem.preconstrain( x = problem.preconstrain(
problem.pack( problem.pack(
problem.base_source)) problem.base_source))
print x
print problem.unpack(x)
return x return x
def get_synthetics(self): def get_synthetics(self):
problem = self.get_problem()
if self._synthetics is None: if self._synthetics is None:
problem = self.get_problem()
x = self.get_x() x = self.get_x()
results = problem.forward(x) results = problem.forward(x)
self._synthetics = results self._synthetics = results
...@@ -775,12 +753,12 @@ class DatasetConfig(HasPaths): ...@@ -775,12 +753,12 @@ class DatasetConfig(HasPaths):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
HasPaths.__init__(self, *args, **kwargs) HasPaths.__init__(self, *args, **kwargs)
self._ds = None self._ds = {}
def get_dataset(self): def get_dataset(self, event_name):
if self._ds is None: if event_name not in self._ds:
fp = self.expand_path fp = self.expand_path
ds = dataset.Dataset() ds = dataset.Dataset(event_name)
ds.add_stations( ds.add_stations(
pyrocko_stations_filename=fp(self.stations_path), pyrocko_stations_filename=fp(self.stations_path),
stationxml_filenames=fp(self.stations_stationxml_paths)) stationxml_filenames=fp(self.stations_stationxml_paths))
...@@ -813,10 +791,10 @@ class DatasetConfig(HasPaths): ...@@ -813,10 +791,10 @@ class DatasetConfig(HasPaths):
if self.whitelist: if self.whitelist:
ds.add_whitelist(self.whitelist) ds.add_whitelist(self.whitelist)
ds.set_synthetic_test(self.synthetic_test) ds.set_synthetic_test(copy.deepcopy(self.synthetic_test))
self._ds = ds self._ds[event_name] = ds
return self._ds return self._ds[event_name]
def weed(origin, targets, limit, neighborhood=3): def weed(origin, targets, limit, neighborhood=3):
...@@ -961,15 +939,11 @@ class Config(HasPaths): ...@@ -961,15 +939,11 @@ class Config(HasPaths):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
HasPaths.__init__(self, *args, **kwargs) HasPaths.__init__(self, *args, **kwargs)
def get_dataset(self): def get_dataset(self, event_name):
ds = self.dataset_config.get_dataset() return self.dataset_config.get_dataset(event_name)
if ds.synthetic_test:
ds.synthetic_test.set_config(self)
return ds
def get_targets(self, event): def get_targets(self, event):
ds = self.get_dataset() ds = self.get_dataset(event.name)
targets = [] targets = []
for igroup, target_config in enumerate(self.target_configs): for igroup, target_config in enumerate(self.target_configs):
...@@ -978,10 +952,18 @@ class Config(HasPaths): ...@@ -978,10 +952,18 @@ class Config(HasPaths):
return targets return targets
def setup_modelling_environment(self, problem):
    """Attach the engine and, if configured, the synthetic test to *problem*.

    The engine obtained from ``self.engine_config`` is set on the problem.
    The dataset is looked up by the name of the problem's base source. If
    that dataset carries a synthetic test, the problem is registered with
    it and the problem's base source is replaced by the source unpacked
    from the synthetic test's parameter vector.
    """
    engine = self.engine_config.get_engine()
    problem.set_engine(engine)

    dataset = self.get_dataset(problem.base_source.name)
    synthetic_test = dataset.synthetic_test
    if not synthetic_test:
        # Nothing more to do when running on real data only.
        return

    # The synthetic test needs the problem before get_x() can be evaluated.
    synthetic_test.set_problem(problem)
    problem.base_source = problem.unpack(synthetic_test.get_x())
def get_problem(self, event): def get_problem(self, event):
targets = self.get_targets(event) targets = self.get_targets(event)
problem = self.problem_config.get_problem(event, targets) problem = self.problem_config.get_problem(event, targets)
problem.set_engine(self.engine_config.get_engine()) self.setup_modelling_environment(problem)
return problem return problem
...@@ -1449,7 +1431,7 @@ def forward(rundir_or_config_path, event_names=None): ...@@ -1449,7 +1431,7 @@ def forward(rundir_or_config_path, event_names=None):
ibest = num.argmin(gms) ibest = num.argmin(gms)
xbest = xs[ibest, :] xbest = xs[ibest, :]
ds = config.get_dataset() ds = config.get_dataset(problem.base_source.name)
problem.set_engine(config.engine_config.get_engine()) problem.set_engine(config.engine_config.get_engine())
for target in problem.targets: for target in problem.targets:
...@@ -1459,11 +1441,11 @@ def forward(rundir_or_config_path, event_names=None): ...@@ -1459,11 +1441,11 @@ def forward(rundir_or_config_path, event_names=None):
else: else:
config = read_config(rundir_or_config_path) config = read_config(rundir_or_config_path)
ds = config.get_dataset()
events = ds.get_events(event_names=event_names)
payload = [] payload = []
for event in events: for event_name in event_names:
ds = config.get_dataset(event_name)
event = ds.get_event()
problem = config.get_problem(event) problem = config.get_problem(event)
xref = problem.preconstrain( xref = problem.preconstrain(
problem.pack(problem.base_source)) problem.pack(problem.base_source))
...@@ -1556,17 +1538,12 @@ g_state = {} ...@@ -1556,17 +1538,12 @@ g_state = {}
def check(config, event_names=None): def check(config, event_names=None):
ds = config.get_dataset()
events = ds.get_events(event_names=event_names)
nevents = len(events)
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from grond.plot import colors from grond.plot import colors
if nevents == 0: for ievent, event_name in enumerate(event_names):
raise GrondError('no events found') ds = config.get_dataset(event_name)
event = ds.get_event()
for ievent, event in enumerate(events):
try: try:
problem = config.get_problem(event) problem = config.get_problem(event)
check_problem(problem) check_problem(problem)
...@@ -1651,18 +1628,12 @@ def go(config, event_names=None, force=False, nparallel=1, status=('state',)): ...@@ -1651,18 +1628,12 @@ def go(config, event_names=None, force=False, nparallel=1, status=('state',)):
status = tuple(status) status = tuple(status)
ds = config.get_dataset()
events = ds.get_events(event_names=event_names)
nevents = len(events)
if nevents == 0:
raise GrondError('no events found')
g_data = (config, force, status, nparallel, event_names) g_data = (config, force, status, nparallel, event_names)
g_state[id(g_data)] = g_data g_state[id(g_data)] = g_data
nevents = len(event_names)
for x in parimap.parimap( for x in parimap.parimap(
process_event, process_event,
xrange(nevents), xrange(nevents),
...@@ -1672,6 +1643,17 @@ def go(config, event_names=None, force=False, nparallel=1, status=('state',)): ...@@ -1672,6 +1643,17 @@ def go(config, event_names=None, force=False, nparallel=1, status=('state',)):
pass pass
def substitute_template(template, d):
    """Fill the ``$name`` / ``${name}`` placeholders of *template* from *d*.

    Substitution is done with :py:class:`string.Template`, so a literal
    dollar sign has to be written as ``$$``.

    :param template: template string containing placeholders
    :param d: mapping from placeholder names to replacement values
    :returns: the template with all placeholders substituted
    :raises GrondError: if the template uses a placeholder which is not a
        key of *d*, or if a placeholder is syntactically malformed
    """
    try:
        return Template(template).substitute(d)
    except KeyError as e:
        # str() of a KeyError is the repr of the key, which would put an
        # extra pair of quotes into the message; report the bare name.
        name = e.args[0] if e.args else str(e)
        raise GrondError(
            'invalid placeholder "%s" in template: "%s"' % (name, template))
    except ValueError:
        raise GrondError(
            'malformed placeholder in template: "%s"' % template)
def process_event(ievent, g_data_id): def process_event(ievent, g_data_id):
config, force, status, nparallel, event_names = g_state[g_data_id] config, force, status, nparallel, event_names = g_state[g_data_id]
...@@ -1679,28 +1661,27 @@ def process_event(ievent, g_data_id): ...@@ -1679,28 +1661,27 @@ def process_event(ievent, g_data_id):
if nparallel > 1: if nparallel > 1:
status = () status = ()
ds = config.get_dataset() event_name = event_names[ievent]
events = ds.get_events(event_names=event_names)
nevents = len(events)
event = events[ievent] ds = config.get_dataset(event_name)
ds.empty_cache() nevents = len(event_names)
tstart = time.time() tstart = time.time()
event = ds.get_event()
problem = config.get_problem(event) problem = config.get_problem(event)
# FIXME
synt = ds.synthetic_test synt = ds.synthetic_test
if synt and synt.inject_solution: if synt:
problem.base_source = problem.unpack(synt.get_x()) problem.base_source = problem.unpack(synt.get_x())
check_problem(problem) check_problem(problem)
rundir = config.rundir_template % dict( rundir = substitute_template(
problem_name=problem.name) config.rundir_template,
dict(problem_name=problem.name))
if op.exists(rundir): if op.exists(rundir):
if force: if force:
......
...@@ -53,7 +53,7 @@ def dump_station_corrections(station_corrections, filename): ...@@ -53,7 +53,7 @@ def dump_station_corrections(station_corrections, filename):
class Dataset(object): class Dataset(object):
def __init__(self): def __init__(self, event_name=None):
self.events = [] self.events = []
self.pile = pile.Pile() self.pile = pile.Pile()
self.stations = {} self.stations = {}
...@@ -72,6 +72,7 @@ class Dataset(object): ...@@ -72,6 +72,7 @@ class Dataset(object):
self.synthetic_test = None self.synthetic_test = None
self._picks = None self._picks = None
self._cache = {} self._cache = {}
self._event_name = event_name
def empty_cache(self): def empty_cache(self):
self._cache = {} self._cache = {}
...@@ -493,15 +494,16 @@ class Dataset(object): ...@@ -493,15 +494,16 @@ class Dataset(object):
syn_test = self.synthetic_test syn_test = self.synthetic_test
toffset_noise_extract = 0.0 toffset_noise_extract = 0.0
if syn_test: if syn_test:
if syn_test.ignore_data_availability: if not syn_test.respect_data_availability:
if syn_test.add_real_noise: if syn_test.add_real_noise:
raise DatasetError( raise DatasetError(
'ignore_data_availability=True and ' 'respect_data_availability=False and '
'add_real_noise=True cannot be combined.') 'add_real_noise=True cannot be combined.')
tr = syn_test.get_waveform( tr = syn_test.get_waveform(
nslc, tmin, tmax, nslc, tmin, tmax,
tfade=tfade, freqlimits=freqlimits) tfade=tfade,
freqlimits=freqlimits)
if cache is not None: if cache is not None:
cache[tr.nslc_id, tmin, tmax] = tr cache[tr.nslc_id, tmin, tmax] = tr
...@@ -595,7 +597,7 @@ class Dataset(object): ...@@ -595,7 +597,7 @@ class Dataset(object):
return evs return evs
def get_event(self, t, magmin=None): def get_event_by_time(self, t, magmin=None):
evs = self.get_events(magmin=magmin) evs = self.get_events(magmin=magmin)
ev_x = None ev_x = None
for ev in evs: for ev in evs:
...@@ -609,6 +611,16 @@ class Dataset(object): ...@@ -609,6 +611,16 @@ class Dataset(object):
return ev_x return ev_x
def get_event(self):
if self._event_name is None:
raise NotFound('no main event selected in dataset')
for ev in self.events:
if ev.name == self._event_name:
return ev