Commit ca0355fb authored by Sebastian Heimann's avatar Sebastian Heimann
Browse files

can now select standard_deviation_estimator

parent 570c9fa8
......@@ -1092,6 +1092,13 @@ class SamplerDistributionChoice(StringChoice):
choices = ['multivariate_normal', 'normal']
class StandardDeviationEstimatorChoice(StringChoice):
choices = [
'median_density_single_chain',
'standard_deviation_all_chains',
'standard_deviation_single_chain']
class SolverConfig(Object):
niter_uniform = Int.T(default=1000)
niter_transition = Int.T(default=0)
......@@ -1099,6 +1106,8 @@ class SolverConfig(Object):
niter_non_explorative = Int.T(default=0)
sampler_distribution = SamplerDistributionChoice.T(
default='multivariate_normal')
standard_deviation_estimator = StandardDeviationEstimatorChoice.T(
default='median_density_single_chain')
scatter_scale_transition = Float.T(default=2.0)
scatter_scale = Float.T(default=1.0)
chain_length_factor = Float.T(default=8.0)
......@@ -1111,6 +1120,7 @@ class SolverConfig(Object):
niter_explorative=self.niter_explorative,
niter_non_explorative=self.niter_non_explorative,
sampler_distribution=self.sampler_distribution,
standard_deviation_estimator=self.standard_deviation_estimator,
scatter_scale_transition=self.scatter_scale_transition,
scatter_scale=self.scatter_scale,
chain_length_factor=self.chain_length_factor,
......@@ -1435,6 +1445,7 @@ def solve(problem,
chain_length_factor=8.0,
xs_inject=None,
sampler_distribution='multivariate_normal',
standard_deviation_estimator='median_density_single_chain',
compensate_excentricity=True,
status=(),
plot=None):
......@@ -1662,7 +1673,7 @@ def solve(problem,
if rundir:
problem.dump_problem_data(
rundir, x, ms, ns, accept,
ibootstrap_choice if ibootstrap_choice is not None else -1,
ibootstrap_choice if ibootstrap_choice is not None else -1,
ibase if ibase is not None else -1)
accept_sum += accept
......@@ -1699,11 +1710,22 @@ def solve(problem,
xs = xhist[chains_i[i, :nlinks], :]
mx = num.mean(xs, axis=0)
cov = num.cov(xs.T)
local_sx = local_std(xs)
mxs.append(mx)
covs.append(cov)
local_sxs.append(local_sx)
if standard_deviation_estimator == \
'median_density_single_chain':
local_sx = local_std(xs)
local_sxs.append(local_sx)
elif standard_deviation_estimator == \
'standard_deviation_all_chains':
local_sxs.append(sbx)
elif standard_deviation_estimator == \
'standard_deviation_single_chain':
sx = num.std(xs, axis=0)
local_sxs.append(sx)
else:
assert False, 'invalid standard_deviation_estimator choice'
if 'state' in status:
lines.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment