Commit d04383de authored by Marius Isken's avatar Marius Isken
Browse files

Added TerminalListener // Cleaned HighScoreSolver

parent c4a8ae8b
......@@ -15,7 +15,8 @@ setup(
version='0.1',
author='Sebastian Heimann',
author_email='sebastian.heimann@gfz-potsdam.de',
packages=['grond', 'grond.baraddur', 'grond.problems', 'grond.solvers', 'grond.analysers'],
packages=['grond', 'grond.baraddur', 'grond.problems', 'grond.solvers',
'grond.analysers', 'grond.listeners'],
scripts=['apps/grond'],
package_dir={'grond': 'src'},
package_data={'grond': ['baraddur/templates/*.html',
......
......@@ -15,6 +15,7 @@ from .dataset import DatasetConfig, NotFound
from .problems.base import ProblemConfig, Problem
from .solvers.base import SolverConfig
from .analysers.base import AnalyserConfig
from .listeners import TerminalListener
from .targets import TargetConfig
from .meta import Path, HasPaths, expand_template, xjoin, GrondError, Notifier
......@@ -704,6 +705,7 @@ def process_event(ievent, g_data_id):
'start %i / %i' % (ievent+1, nevents))
notifier = Notifier()
notifier.add_listener(TerminalListener())
analyser = config.analyser_config.get_analyser()
analyser.analyse(problem, notifier=notifier)
......
from .curses import CursesListener # noqa
from .terminal import TerminalListener # noqa
class Listener(object):
def progress_start(self, name, niter):
raise NotImplementedError()
def progress_finish(self, name):
raise NotImplementedError()
def progress_update(self, name, iiter):
raise NotImplementedError()
def state(self, state):
raise NotImplementedError()
import curses
class State(object):
iiter = 0
niter = 0
iter_sec = 0.
problem_name = ''
parameter_names = []
column_names = []
values = []
text = ''
class _CursesPad(object):
def __init__(self, pad):
self.pad = pad
self.rows, self.cols = self.pad.getyx()
def resize_pad(self):
return
self.pad.resize(self.rows, self.cols)
class CursesListener(object):
class ParameterTable(_CursesPad):
value_fmt = '{0:8.4g}'
column_padding = 2
def update(self, state):
pad = self.pad
pad.clear()
if not state:
return
parameter_names = ['Parameters'] + state.parameters
col = 0
for icol in xrange(len(state.values) + 1):
row = 0
if icol == 0:
col_width = max([len(p) for p in parameter_names])
for name in parameter_names:
pad.addstr(
row, col,
'{:<{width}}'.format(
name, width=col_width),
curses.A_BOLD)
row += 1
else:
igroup = icol - 1
col_heading = state.column_names[igroup]
col_width = max(
len(col_heading),
len(self.value_fmt.format(0.)) + self.column_padding)
pad.addstr(row, col,
'{:>{width}}'.format(
col_heading, width=col_width),
curses.A_BOLD)
for iv, v in enumerate(state.values[igroup]):
row += 1
vstr = ' ' * self.column_padding +\
self.value_fmt.format(v)
pad.addstr(row, col, vstr)
col += col_width
self.rows = row
self.resize_pad()
pad.noutrefresh()
class Footer(_CursesPad):
def update(self, state):
pad = self.pad
pad.clear()
if not state:
return
pad.addstr(0, 0, 'Performance:')
pad.addstr(0, 14, '%.1f iter/s' % state.iter_sec)
pad.addstr(1, 0, state.text)
self.rows = 3
self.resize_pad()
pad.noutrefresh()
class Header(_CursesPad):
def update(self, state):
pad = self.pad
pad.clear()
if not state:
return
pad.addstr(0, 0, 'Problem Name:')
pad.addstr(0, 14, state.problem_name,
curses.A_BOLD)
pad.addstr(1, 0, 'Iteration:')
pad.addstr(1, 14, '%d / %d' % (state.iiter, state.niter),
curses.A_BOLD)
self.rows = 3
self.resize_pad()
pad.noutrefresh()
def __init__(self):
self.scr = None
self.state = None
curses.wrapper(self.set_screen)
self.header_pad = self.Header(self.scr.subpad(3, 100, 0, 0))
self.parameter_pad = self.ParameterTable(self.scr.subpad(3, 0))
self.footer_pad = self.Footer(self.scr.subpad(3, 100, 5, 0))
def set_screen(self, scr):
self.scr = scr
def set_state(self, state):
self.state = state
self.parameter_pad.update(state)
self.header_pad.update(state)
self.footer_pad.update(state)
self.footer_pad.pad.mvwin(
self.parameter_pad.rows + self.parameter_pad.pad.getparyx()[0] + 2,
0)
self.scr.refresh()
import progressbar as pbar
from .base import Listener
class color:
PURPLE = '\033[95m'
CYAN = '\033[96m'
DARKCYAN = '\033[36m'
BLUE = '\033[94m'
GREEN = '\033[92m'
YELLOW = '\033[93m'
RED = '\033[91m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
END = '\033[0m'
class TerminalListener(Listener):
col_width = 15
row_name = color.BOLD + '{:<{col_param_width}s}' + color.END
parameter_fmt = '{:>{col_width}{type}}'
def __init__(self):
self.current_state = None
self.pbars = {}
def progress_start(self, name, niter):
self.pbars[name] = pbar.start(name, niter)
def progress_update(self, name, iiter):
self.pbars[name].update(iiter)
def progress_finish(self, name):
self.pbars[name].finish()
def state(self, state):
lines = []
self.current_state = state
def l(t):
lines.append(t)
out_ln = self.row_name +\
''.join([self.parameter_fmt] * len(state.parameter_values))
col_param_width = max([len(p) for p in state.parameter_names]) + 2
l('Problem name: {s.problem_name}'
'\t({s.runtime:s} - remaining {s.runtime_remaining})'
.format(s=state))
l('Iteration {s.iiter} / {s.niter} ({s.iter_per_second:.1f} iter/s)'
.format(s=state))
l(out_ln.format(
*['Parameter'] + state.column_names,
col_param_width=col_param_width,
col_width=self.col_width,
type='s'))
for ip, parameter_name in enumerate(state.parameter_names):
l(out_ln.format(
parameter_name,
*[v[ip] for v in state.parameter_values],
col_param_width=col_param_width,
col_width=self.col_width,
type='.4g'))
l(state.extra_text.format(
col_param_width=col_param_width,
col_width=self.col_width,))
lines[0:0] = ['\033[2J']
l('')
print '\n'.join(lines)
......@@ -198,9 +198,9 @@ class Notifier(object):
def emit(self, signal_name, *args, **kwargs):
for listener in self._listeners:
try:
getattr(listener, signal_name)(*args, **kwargs)
except AttributeError:
if not hasattr(listener, signal_name):
logger.warn(
'signal name %s not implemented in listener' % signal_name)
'signal name \'%s\' not implemented in listener %s'
% (signal_name, type(listener)))
continue
getattr(listener, signal_name)(*args, **kwargs)
import logging
import time
import numpy as num
from datetime import timedelta
from pyrocko.guts import Object
......@@ -7,10 +10,62 @@ guts_prefix = 'grond'
logger = logging.getLogger('grond.solver')
class RingBuffer(num.ndarray):
def __init__(self, *args, **kwargs):
num.ndarray.__init__(self, *args, **kwargs)
self.fill(0.)
self.pos = 0
def put(self, value):
self[self.pos] = value
self.pos += 1
self.pos %= self.size
class SolverState(object):
problem_name = ''
parameter_names = []
parameter_values = []
column_names = []
extra_text = ''
niter = 0
_iiter = 0
iter_per_second = 0.
_iter_buffer = RingBuffer(20)
starttime = time.time()
_last_update = time.time()
@property
def iiter(self):
return self._iiter
@iiter.setter
def iiter(self, value):
dt = time.time() - self._last_update
self._iter_buffer.put(float((value - self._iiter) / dt))
self.iter_per_second = float(self._iter_buffer.mean())
self._iiter = value
self._last_update = time.time()
@property
def runtime(self):
return timedelta(seconds=time.time() - self.starttime)
@property
def runtime_remaining(self):
if self.iter_per_second == 0.:
return timedelta()
return timedelta(seconds=(self.niter - self.iiter)
/ self.iter_per_second)
class Solver(object):
def solve(
self, problem, rundir=None, status=(), plot=None, xs_inject=None):
state = SolverState()
def solve(
self, problem, rundir=None, status=(), plot=None, xs_inject=None,
notifier=None):
raise NotImplemented()
......@@ -22,5 +77,6 @@ class SolverConfig(Object):
__all__ = '''
Solver
SolverState
SolverConfig
'''.split()
......@@ -69,7 +69,8 @@ def solve(problem,
xs_inject=None,
status=(),
plot=None,
notifier=None):
notifier=None,
state=None):
xbounds = num.array(problem.get_parameter_bounds(), dtype=num.float)
npar = problem.nparameters
......@@ -99,6 +100,12 @@ def solve(problem,
accept_hist = num.zeros(niter, dtype=num.int)
pnames = problem.parameter_names
state.problem_name = problem.name
state.column_names = ['B mean', 'B std',
'G mean', 'G std', 'G best']
state.parameter_names = problem.parameter_names + ['Misfit']
state.niter = niter
if plot:
plot.start(problem)
......@@ -300,22 +307,12 @@ def solve(problem,
accept_sum += accept
accept_hist[iiter] = num.sum(accept)
lines = []
if 'state' in status:
lines.append('%s, %i' % (problem.name, iiter))
lines.append(''.join('-X'[int(acc)] for acc in accept))
xhist[iiter, :] = x
bxs = xhist[chains_i[:, :nlinks].ravel(), :]
gxs = xhist[chains_i[0, :nlinks], :]
gms = chains_m[0, :nlinks]
col_width = 15
col_param_width = max([len(p) for p in problem.parameter_names])+2
console_output = '{:<{col_param_width}s}'
console_output += ''.join(['{:>{col_width}{type}}'] * 5)
if nlinks > (nlinks_cap-1)/2:
# mean and std of all bootstrap ensembles together
mbx = num.mean(bxs, axis=0)
......@@ -353,52 +350,25 @@ def solve(problem,
else:
assert False, 'invalid standard_deviation_estimator choice'
if 'state' in status:
lines.append(
console_output.format(
'parameter', 'B mean', 'B std', 'G mean', 'G std',
'G best',
col_param_width=col_param_width,
col_width=col_width,
type='s'))
for (pname, mbv, sbv, mgv, sgv, bgv) in zip(
pnames, mbx, sbx, mgx, sgx, bgx):
lines.append(
console_output.format(
pname, mbv, sbv, mgv, sgv, bgv,
col_param_width=col_param_width,
col_width=col_width,
type='.4g'))
lines.append(
console_output.format(
'misfit', '', '',
'%.4g' % num.mean(gms),
'%.4g' % num.std(gms),
'%.4g' % num.min(gms),
col_param_width=col_param_width,
col_width=col_width,
type='s'))
state.parameter_values = [
num.append(mbx, num.nan),
num.append(sbx, num.nan),
num.append(mgx, num.mean(gms)),
num.append(sgx, num.std(gms)),
num.append(bgx, num.min(gms))]
state.iiter = iiter + 1
state.extra_text =\
'Phase: %s (factor %d); ntries %d, ntries_preconstrain %d'\
% (phase, factor, ntries_sample, ntries_preconstrain)
if 'state' in status:
lines.append(
console_output.format(
'iteration', iiter+1, '(%s, %g)' % (phase, factor),
ntries_sample, ntries_preconstrain, '',
col_param_width=col_param_width,
col_width=col_width,
type=''))
notifier.emit('state', state)
if 'matrix' in status:
lines = []
matrix = (chains_i[:, :30] % 94 + 32).T
for row in matrix[::-1]:
lines.append(''.join(chr(xxx) for xxx in row))
if status:
lines[0:0] = ['\033[2J']
lines.append('')
print '\n'.join(lines)
if plot and plot.want_to_update(iiter):
......@@ -433,6 +403,7 @@ class HighScoreSolver(Solver):
plot=plot,
xs_inject=xs_inject,
notifier=notifier,
state=self.state,
**self._kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment