Commit a3dc87ad authored by Sebastian Heimann's avatar Sebastian Heimann
Browse files

improve speed when many ivalid/unused targets are present

parent be1a2aaf
......@@ -163,27 +163,45 @@ class CMTProblem(core.Problem):
return out
def evaluate(self, x, result_mode='sparse'):
def evaluate(self, x, result_mode='sparse', mask=None):
source = self.unpack(x)
engine = self.get_engine()
for target in self.targets:
target.set_result_mode(result_mode)
resp = engine.process(source, self.targets)
if mask is not None:
assert len(mask) == len(self.targets)
targets_ok = [
target for (target, ok) in zip(self.targets, mask) if ok]
else:
targets_ok = self.targets
resp = engine.process(source, targets_ok)
if mask is not None:
ires_ok = 0
results = []
for target, ok in zip(self.targets, mask):
if ok:
results.append(resp.results_list[0][ires_ok])
ires_ok += 1
else:
results.append(
gf.SeismosizerError(
'skipped because of previous failure'))
else:
results = list(resp.results_list[0])
data = []
results = []
for target, result in zip(self.targets, resp.results_list[0]):
for target, result in zip(self.targets, results):
if isinstance(result, gf.SeismosizerError):
logger.debug(
'%s.%s.%s.%s: %s' % (target.codes + (str(result),)))
data.append((None, None))
results.append(result)
else:
data.append((result.misfit_value, result.misfit_norm))
results.append(result)
ms, ns = num.array(data, dtype=num.float).T
if result_mode == 'full':
......
......@@ -1207,6 +1207,7 @@ def analyse(problem, niter=1000, show_progress=False):
if show_progress:
pbar = util.progressbar('analysing problem', niter)
isbad_mask = None
for iiter in xrange(niter):
while True:
x = []
......@@ -1221,9 +1222,16 @@ def analyse(problem, niter=1000, show_progress=False):
except Forbidden:
pass
_, ms = wproblem.evaluate(x)
if isbad_mask is not None and num.any(isbad_mask):
isok_mask = num.logical_not(isbad_mask)
else:
isok_mask = None
_, ms = wproblem.evaluate(x, mask=isok_mask)
mss[iiter, :] = ms
isbad_mask = num.isnan(ms)
if show_progress:
pbar.update(iiter)
......@@ -1391,7 +1399,12 @@ def solve(problem,
except Forbidden:
pass
ms, ns = problem.evaluate(x)
if isbad_mask is not None and num.any(isbad_mask):
isok_mask = num.logical_not(isbad_mask)
else:
isok_mask = None
ms, ns = problem.evaluate(x, mask=isok_mask)
isbad_mask_new = num.isnan(ms)
if isbad_mask is not None and num.any(isbad_mask != isbad_mask_new):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment