Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Sebastian Heimann
grond
Commits
ca0355fb
Commit
ca0355fb
authored
Apr 12, 2017
by
Sebastian Heimann
Browse files
can now select standard_deviation_estimator
parent
570c9fa8
Changes
1
Hide whitespace changes
Inline
Side-by-side
src/core.py
View file @
ca0355fb
...
@@ -1092,6 +1092,13 @@ class SamplerDistributionChoice(StringChoice):
...
@@ -1092,6 +1092,13 @@ class SamplerDistributionChoice(StringChoice):
choices
=
[
'multivariate_normal'
,
'normal'
]
choices
=
[
'multivariate_normal'
,
'normal'
]
class
StandardDeviationEstimatorChoice
(
StringChoice
):
choices
=
[
'median_density_single_chain'
,
'standard_deviation_all_chains'
,
'standard_deviation_single_chain'
]
class
SolverConfig
(
Object
):
class
SolverConfig
(
Object
):
niter_uniform
=
Int
.
T
(
default
=
1000
)
niter_uniform
=
Int
.
T
(
default
=
1000
)
niter_transition
=
Int
.
T
(
default
=
0
)
niter_transition
=
Int
.
T
(
default
=
0
)
...
@@ -1099,6 +1106,8 @@ class SolverConfig(Object):
...
@@ -1099,6 +1106,8 @@ class SolverConfig(Object):
niter_non_explorative
=
Int
.
T
(
default
=
0
)
niter_non_explorative
=
Int
.
T
(
default
=
0
)
sampler_distribution
=
SamplerDistributionChoice
.
T
(
sampler_distribution
=
SamplerDistributionChoice
.
T
(
default
=
'multivariate_normal'
)
default
=
'multivariate_normal'
)
standard_deviation_estimator
=
StandardDeviationEstimatorChoice
.
T
(
default
=
'median_density_single_chain'
)
scatter_scale_transition
=
Float
.
T
(
default
=
2.0
)
scatter_scale_transition
=
Float
.
T
(
default
=
2.0
)
scatter_scale
=
Float
.
T
(
default
=
1.0
)
scatter_scale
=
Float
.
T
(
default
=
1.0
)
chain_length_factor
=
Float
.
T
(
default
=
8.0
)
chain_length_factor
=
Float
.
T
(
default
=
8.0
)
...
@@ -1111,6 +1120,7 @@ class SolverConfig(Object):
...
@@ -1111,6 +1120,7 @@ class SolverConfig(Object):
niter_explorative
=
self
.
niter_explorative
,
niter_explorative
=
self
.
niter_explorative
,
niter_non_explorative
=
self
.
niter_non_explorative
,
niter_non_explorative
=
self
.
niter_non_explorative
,
sampler_distribution
=
self
.
sampler_distribution
,
sampler_distribution
=
self
.
sampler_distribution
,
standard_deviation_estimator
=
self
.
standard_deviation_estimator
,
scatter_scale_transition
=
self
.
scatter_scale_transition
,
scatter_scale_transition
=
self
.
scatter_scale_transition
,
scatter_scale
=
self
.
scatter_scale
,
scatter_scale
=
self
.
scatter_scale
,
chain_length_factor
=
self
.
chain_length_factor
,
chain_length_factor
=
self
.
chain_length_factor
,
...
@@ -1435,6 +1445,7 @@ def solve(problem,
...
@@ -1435,6 +1445,7 @@ def solve(problem,
chain_length_factor
=
8.0
,
chain_length_factor
=
8.0
,
xs_inject
=
None
,
xs_inject
=
None
,
sampler_distribution
=
'multivariate_normal'
,
sampler_distribution
=
'multivariate_normal'
,
standard_deviation_estimator
=
'median_density_single_chain'
,
compensate_excentricity
=
True
,
compensate_excentricity
=
True
,
status
=
(),
status
=
(),
plot
=
None
):
plot
=
None
):
...
@@ -1662,7 +1673,7 @@ def solve(problem,
...
@@ -1662,7 +1673,7 @@ def solve(problem,
if
rundir
:
if
rundir
:
problem
.
dump_problem_data
(
problem
.
dump_problem_data
(
rundir
,
x
,
ms
,
ns
,
accept
,
rundir
,
x
,
ms
,
ns
,
accept
,
ibootstrap_choice
if
ibootstrap_choice
is
not
None
else
-
1
,
ibootstrap_choice
if
ibootstrap_choice
is
not
None
else
-
1
,
ibase
if
ibase
is
not
None
else
-
1
)
ibase
if
ibase
is
not
None
else
-
1
)
accept_sum
+=
accept
accept_sum
+=
accept
...
@@ -1699,11 +1710,22 @@ def solve(problem,
...
@@ -1699,11 +1710,22 @@ def solve(problem,
xs
=
xhist
[
chains_i
[
i
,
:
nlinks
],
:]
xs
=
xhist
[
chains_i
[
i
,
:
nlinks
],
:]
mx
=
num
.
mean
(
xs
,
axis
=
0
)
mx
=
num
.
mean
(
xs
,
axis
=
0
)
cov
=
num
.
cov
(
xs
.
T
)
cov
=
num
.
cov
(
xs
.
T
)
local_sx
=
local_std
(
xs
)
mxs
.
append
(
mx
)
mxs
.
append
(
mx
)
covs
.
append
(
cov
)
covs
.
append
(
cov
)
local_sxs
.
append
(
local_sx
)
if
standard_deviation_estimator
==
\
'median_density_single_chain'
:
local_sx
=
local_std
(
xs
)
local_sxs
.
append
(
local_sx
)
elif
standard_deviation_estimator
==
\
'standard_deviation_all_chains'
:
local_sxs
.
append
(
sbx
)
elif
standard_deviation_estimator
==
\
'standard_deviation_single_chain'
:
sx
=
num
.
std
(
xs
,
axis
=
0
)
local_sxs
.
append
(
sx
)
else
:
assert
False
,
'invalid standard_deviation_estimator choice'
if
'state'
in
status
:
if
'state'
in
status
:
lines
.
append
(
lines
.
append
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment