Commit 36b3b9e7 authored by Sebastian Heimann
Browse files

chunked stacking to reduce memory hunger

parent 3ff5f56e
...@@ -21,7 +21,7 @@ def d2u(d): ...@@ -21,7 +21,7 @@ def d2u(d):
def str_to_time(s): def str_to_time(s):
try: try:
return util.str_to_time(s) return util.str_to_time(s)
except util.TimeStrError, e: except util.TimeStrError as e:
raise lassie.LassieError(str(e)) raise lassie.LassieError(str(e))
...@@ -356,7 +356,7 @@ def command_search(args): ...@@ -356,7 +356,7 @@ def command_search(args):
nparallel=nparallel, nparallel=nparallel,
bark=options.bark) bark=options.bark)
except lassie.LassieError, e: except lassie.LassieError as e:
die(str(e)) die(str(e))
......
import logging import logging
import os.path as op import os.path as op
from pyrocko.guts import Object, String, Float, Timestamp, List, Bool from pyrocko.guts import String, Float, Timestamp, List, Bool, Int
from pyrocko import model, guts from pyrocko import model, guts
from pyrocko.fdsn import station as fs from pyrocko.fdsn import station as fs
from pyrocko.gf import TPDef from pyrocko.gf import TPDef
...@@ -120,6 +120,18 @@ class Config(HasPaths): ...@@ -120,6 +120,18 @@ class Config(HasPaths):
default='lassie_phases.cache', default='lassie_phases.cache',
help='directory where lassie stores tabulated phases etc.') help='directory where lassie stores tabulated phases etc.')
stacking_blocksize = Int.T(
optional=True,
help='enable chunked stacking to reduce memory usage. Setting this to '
'e.g. 64 will use ngridpoints * 64 * 8 bytes of memory to hold '
'the stacking results, instead of computing the whole processing '
'time window in one shot. Setting this to a very small number '
'may lead to bad performance. If this is enabled together with '
'plotting, the cutout of the image function seen in the map '
'image must be stacked again just for plotting (redundantly and '
'memory greedy) because it may intersect more than one '
'processing chunk.')
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
HasPaths.__init__(self, *args, **kwargs) HasPaths.__init__(self, *args, **kwargs)
self._receivers = None self._receivers = None
...@@ -131,8 +143,8 @@ class Config(HasPaths): ...@@ -131,8 +143,8 @@ class Config(HasPaths):
''' '''
Post-init setup of image function contributors. Post-init setup of image function contributors.
''' '''
for ifc in self.image_function_contributions: for ifc_ in self.image_function_contributions:
ifc.setup(self) ifc_.setup(self)
def set_config_name(self, config_name): def set_config_name(self, config_name):
self._config_name = config_name self._config_name = config_name
......
...@@ -2,6 +2,7 @@ import math ...@@ -2,6 +2,7 @@ import math
import logging import logging
import os.path as op import os.path as op
import shutil import shutil
import time
from collections import defaultdict from collections import defaultdict
...@@ -312,10 +313,14 @@ def search( ...@@ -312,10 +313,14 @@ def search(
if config.fill_incomplete_with_zeros: if config.fill_incomplete_with_zeros:
trs = zero_fill(trs, wmin - tpad, wmax + tpad) trs = zero_fill(trs, wmin - tpad, wmax + tpad)
frames = None t0 = math.floor(wmin / deltat_cf) * deltat_cf
iwmin = int(round((wmin-tpeaksearch-t0) / deltat_cf))
iwmax = int(round((wmax+tpeaksearch-t0) / deltat_cf))
lengthout = iwmax - iwmin + 1
pdata = [] pdata = []
trs_debug = [] trs_debug = []
shift_maxs = [] parstack_params = []
for iifc, ifc in enumerate(ifcs): for iifc, ifc in enumerate(ifcs):
dataset = ifc.preprocess( dataset = ifc.preprocess(
trs, wmin-tpeaksearch, wmax+tpeaksearch, trs, wmin-tpeaksearch, wmax+tpeaksearch,
...@@ -333,8 +338,6 @@ def search( ...@@ -333,8 +338,6 @@ def search(
trs_debug.extend(trs + list(trs_selected)) trs_debug.extend(trs + list(trs_selected))
t0 = (wmin / deltat_cf) * deltat_cf
istations_selected = num.array( istations_selected = num.array(
[station_index[nsl] for nsl in nsls_selected], [station_index[nsl] for nsl in nsls_selected],
dtype=num.int) dtype=num.int)
...@@ -366,35 +369,58 @@ def search( ...@@ -366,35 +369,58 @@ def search(
pdata.append((list(trs_selected), shift_table, ifc)) pdata.append((list(trs_selected), shift_table, ifc))
iwmin = int(round((wmin-tpeaksearch-t0) / deltat_cf)) parstack_params.append((arrays, offsets, shifts, weights))
iwmax = int(round((wmax+tpeaksearch-t0) / deltat_cf))
if config.stacking_blocksize is not None:
lengthout = iwmax - iwmin + 1 ipstep = config.stacking_blocksize
frames = None
shift_maxs.append(num.max(-shifts) * deltat_cf) else:
ipstep = lengthout
frames, ioff = parstack( frames = num.zeros((ngridpoints, lengthout))
arrays, offsets, shifts, weights, 0,
offsetout=iwmin, twall_start = time.time()
lengthout=lengthout, frame_maxs = num.zeros(lengthout)
result=frames, frame_argmaxs = num.zeros(lengthout, dtype=num.int)
nparallel=nparallel, ipmin = iwmin
impl='openmp') while ipmin < iwmin+lengthout:
ipsize = min(ipstep, iwmin + lengthout - ipmin)
shift_max = max(shift_maxs) if ipstep == lengthout:
frames_p = frames
if config.sharpness_normalization: else:
frame_maxs = frames.max(axis=0) frames_p = num.zeros((ngridpoints, ipsize))
frame_means = num.abs(frames).mean(axis=0)
frames *= (frame_maxs / frame_means)[num.newaxis, :] for (arrays, offsets, shifts, weights) in parstack_params:
frames *= norm_map[:, num.newaxis] frames_p, _ = parstack(
arrays, offsets, shifts, weights, 0,
if config.ifc_count_normalization: offsetout=ipmin,
frames *= 1.0 / len(ifcs) lengthout=ipsize,
result=frames_p,
frame_maxs = frames.max(axis=0) nparallel=nparallel,
impl='openmp')
tmin_frames = t0 + ioff * deltat_cf
if config.sharpness_normalization:
frame_p_maxs = frames_p.max(axis=0)
frame_p_means = num.abs(frames_p).mean(axis=0)
frames_p *= (frame_p_maxs / frame_p_means)[num.newaxis, :]
frames_p *= norm_map[:, num.newaxis]
if config.ifc_count_normalization:
frames_p *= 1.0 / len(ifcs)
frame_maxs[ipmin-iwmin:ipmin-iwmin+ipsize] = \
frames_p.max(axis=0)
frame_argmaxs[ipmin-iwmin:ipmin-iwmin+ipsize] = \
pargmax(frames_p)
ipmin += ipstep
del frames_p
twall_end = time.time()
logger.info('wallclock time for stacking: %g s' % (
twall_end - twall_start))
tmin_frames = t0 + iwmin * deltat_cf
tr_stackmax = trace.Trace( tr_stackmax = trace.Trace(
'', 'SMAX', '', '', '', 'SMAX', '', '',
...@@ -422,17 +448,13 @@ def search( ...@@ -422,17 +448,13 @@ def search(
wmin <= tpeak and tpeak < wmax]) or ([], []) wmin <= tpeak and tpeak < wmax]) or ([], [])
tr_stackmax_indx = tr_stackmax.copy(data=False) tr_stackmax_indx = tr_stackmax.copy(data=False)
tr_stackmax_indx.set_ydata(frame_argmaxs.astype(num.int32))
imaxs = pargmax(frames)
tr_stackmax_indx.set_ydata(imaxs.astype(num.int32))
tr_stackmax_indx.set_location('i') tr_stackmax_indx.set_location('i')
for (tpeak, apeak) in zip(tpeaks, apeaks): for (tpeak, apeak) in zip(tpeaks, apeaks):
iframe = int(round(((tpeak-t0) - ioff*deltat_cf) / deltat_cf)) iframe = int(round((tpeak - tmin_frames) / deltat_cf))
frame = frames[:, iframe] imax = frame_argmaxs[iframe]
imax = imaxs[iframe]
latpeak, lonpeak, xpeak, ypeak, zpeak = \ latpeak, lonpeak, xpeak, ypeak, zpeak = \
grid.index_to_location(imax) grid.index_to_location(imax)
...@@ -481,23 +503,60 @@ def search( ...@@ -481,23 +503,60 @@ def search(
util.ensuredirs(fn) util.ensuredirs(fn)
try: if frames is not None:
plot.plot_detection( frames_p = frames
grid, receivers, frames, tmin_frames, tmin_frames_p = tmin_frames
deltat_cf, imax, iframe, fsmooth_min, xpeak, ypeak, iframe_p = iframe
zpeak,
tr_stackmax, tpeaks, apeaks, else:
config.detector_threshold, iframe_min = max(
wmin, wmax, 0,
pdata, trs, fmin, fmax, idetection, int(round(iframe - tpeaksearch/deltat_cf)))
grid_station_shift_max=shift_max, iframe_max = min(
movie=show_movie, lengthout-1,
show=show_detections, int(round(iframe + tpeaksearch/deltat_cf)))
save_filename=fn)
except AttributeError as e: ipsize = iframe_max - iframe_min + 1
logger.warn(e) frames_p = num.zeros((ngridpoints, ipsize))
tmin_frames_p = tmin_frames + iframe_min*deltat_cf
del frame iframe_p = iframe - iframe_min
for (arrays, offsets, shifts, weights) \
in parstack_params:
frames_p, _ = parstack(
arrays, offsets, shifts, weights, 0,
offsetout=iwmin+iframe_min,
lengthout=ipsize,
result=frames_p,
nparallel=nparallel,
impl='openmp')
if config.sharpness_normalization:
frame_p_maxs = frames_p.max(axis=0)
frame_p_means = num.abs(frames_p).mean(axis=0)
frames_p *= (
frame_p_maxs/frame_p_means)[num.newaxis, :]
frames_p *= norm_map[:, num.newaxis]
if config.ifc_count_normalization:
frames_p *= 1.0 / len(ifcs)
plot.plot_detection(
grid, receivers, frames_p, tmin_frames_p,
deltat_cf, imax, iframe_p, xpeak, ypeak,
zpeak,
tr_stackmax, tpeaks, apeaks,
config.detector_threshold,
wmin, wmax,
pdata, trs, fmin, fmax, idetection,
tpeaksearch,
movie=show_movie,
show=show_detections,
save_filename=fn)
del frames_p
if stop_after_first: if stop_after_first:
return return
......
import os
import numpy as num import numpy as num
from pyrocko import automap, plot, util from pyrocko import automap, plot, util
...@@ -107,9 +106,9 @@ def plot_geometry_carthesian(grid, receivers): ...@@ -107,9 +106,9 @@ def plot_geometry_carthesian(grid, receivers):
def plot_detection( def plot_detection(
grid, receivers, frames, tmin_frames, deltat_cf, imax, iframe, grid, receivers, frames, tmin_frames, deltat_cf, imax, iframe,
fsmooth_min, xpeak, ypeak, zpeak, tr_stackmax, tpeaks, apeaks, xpeak, ypeak, zpeak, tr_stackmax, tpeaks, apeaks,
detector_threshold, wmin, wmax, pdata, trs_raw, fmin, fmax, detector_threshold, wmin, wmax, pdata, trs_raw, fmin, fmax,
idetection, grid_station_shift_max, idetection, tpeaksearch,
movie=False, save_filename=None, show=True): movie=False, save_filename=None, show=True):
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
...@@ -122,11 +121,10 @@ def plot_detection( ...@@ -122,11 +121,10 @@ def plot_detection(
distances = grid.distances(receivers) distances = grid.distances(receivers)
plot.mpl_init() plot.mpl_init(fontsize=9)
fig = plt.figure(figsize=plot.mpl_papersize('a4', 'landscape')) fig = plt.figure(figsize=plot.mpl_papersize('a4', 'landscape'))
axes = plt.subplot2grid((2, 3), (0, 2), aspect=1.0) axes = plt.subplot2grid((2, 3), (0, 2), aspect=1.0)
plot.mpl_labelspace(axes) plot.mpl_labelspace(axes)
...@@ -160,7 +158,7 @@ def plot_detection( ...@@ -160,7 +158,7 @@ def plot_detection(
tpeak_current = tmin_frames + deltat_cf * iframe tpeak_current = tmin_frames + deltat_cf * iframe
t0 = tpeak_current t0 = tpeak_current
tduration = 2.0*grid_station_shift_max + 1./fsmooth_min tduration = 2.0*tpeaksearch
axes2.axvspan( axes2.axvspan(
tr_stackmax.tmin - t0, wmin - t0, tr_stackmax.tmin - t0, wmin - t0,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment