Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Sebastian Heimann
grond
Commits
a413f383
Commit
a413f383
authored
Jun 09, 2016
by
Sebastian Heimann
Browse files
implement inter-group weighting
parent
44964da4
Changes
2
Hide whitespace changes
Inline
Side-by-side
src/cmt.py
View file @
a413f383
...
...
@@ -220,8 +220,31 @@ class CMTProblem(core.Problem):
return
self
.
_target_weights
def
inter_group_weights
(
self
,
ns
):
group
,
ngroups
=
self
.
get_group_mask
()
ws
=
num
.
zeros
(
self
.
ntargets
)
for
igroup
in
xrange
(
ngroups
):
mask
=
group
==
igroup
ws
[
mask
]
=
1.0
/
num
.
sqrt
(
num
.
nansum
(
ns
[
mask
]
**
2
))
return
ws
def
inter_group_weights2
(
self
,
ns
):
group
,
ngroups
=
self
.
get_group_mask
()
ws
=
num
.
zeros
(
ns
.
shape
)
for
igroup
in
xrange
(
ngroups
):
mask
=
group
==
igroup
ws
[:,
mask
]
=
(
1.0
/
num
.
sqrt
(
num
.
nansum
(
ns
[:,
mask
]
**
2
,
axis
=
1
)))[:,
num
.
newaxis
]
return
ws
def
bootstrap_misfit
(
self
,
ms
,
ns
,
ibootstrap
=
None
):
w
=
self
.
get_bootstrap_weights
(
ibootstrap
)
*
self
.
get_target_weights
()
w
=
self
.
get_bootstrap_weights
(
ibootstrap
)
*
\
self
.
get_target_weights
()
*
self
.
inter_group_weights
(
ns
)
if
ibootstrap
is
None
:
return
num
.
sqrt
(
num
.
nansum
((
w
*
ms
[
num
.
newaxis
,
:])
**
2
,
axis
=
1
)
/
...
...
@@ -231,25 +254,30 @@ class CMTProblem(core.Problem):
def
bootstrap_misfits
(
self
,
misfits
,
ibootstrap
):
w
=
self
.
get_bootstrap_weights
(
ibootstrap
)[
num
.
newaxis
,
:]
*
\
self
.
get_target_weights
()[
num
.
newaxis
,
:]
self
.
get_target_weights
()[
num
.
newaxis
,
:]
*
\
self
.
inter_group_weights2
(
misfits
[:,
:,
1
])
bms
=
num
.
sqrt
(
num
.
nansum
((
w
*
misfits
[:,
:,
0
])
**
2
,
axis
=
1
)
/
num
.
nansum
((
w
*
misfits
[:,
:,
1
])
**
2
,
axis
=
1
))
return
bms
def
global_misfit
(
self
,
ms
,
ns
):
ws
=
self
.
get_target_weights
()
ws
=
self
.
get_target_weights
()
*
self
.
inter_group_weights
(
ns
)
m
=
num
.
sqrt
(
num
.
nansum
((
ws
*
ms
)
**
2
)
/
num
.
nansum
((
ws
*
ns
)
**
2
))
return
m
def
global_misfits
(
self
,
misfits
):
ws
=
self
.
get_target_weights
()[
num
.
newaxis
,
:]
ws
=
self
.
get_target_weights
()[
num
.
newaxis
,
:]
*
\
self
.
inter_group_weights2
(
misfits
[:,
:,
1
])
gms
=
num
.
sqrt
(
num
.
nansum
((
ws
*
misfits
[:,
:,
0
])
**
2
,
axis
=
1
)
/
num
.
nansum
((
ws
*
misfits
[:,
:,
1
])
**
2
,
axis
=
1
))
return
gms
def
global_contributions
(
self
,
misfits
):
ws
=
self
.
get_target_weights
()[
num
.
newaxis
,
:]
ws
=
self
.
get_target_weights
()[
num
.
newaxis
,
:]
*
\
self
.
inter_group_weights2
(
misfits
[:,
:,
1
])
gcms
=
(
ws
*
misfits
[:,
:,
0
])
**
2
/
\
num
.
nansum
((
ws
*
misfits
[:,
:,
1
])
**
2
,
axis
=
1
)[:,
num
.
newaxis
]
...
...
src/core.py
View file @
a413f383
...
...
@@ -128,6 +128,7 @@ class Problem(Object):
self
.
_bootstrap_weights
=
None
self
.
_target_weights
=
None
self
.
_engine
=
None
self
.
_group_mask
=
None
def
get_engine
(
self
):
return
self
.
_engine
...
...
@@ -211,6 +212,26 @@ class Problem(Object):
def
set_engine
(
self
,
engine
):
self
.
_engine
=
engine
def
make_group_mask
(
self
):
super_group_names
=
set
()
groups
=
num
.
zeros
(
len
(
self
.
targets
),
dtype
=
num
.
int
)
ngroups
=
0
for
itarget
,
target
in
enumerate
(
self
.
targets
):
if
target
.
super_group
not
in
super_group_names
:
super_group_names
.
add
(
target
.
super_group
)
ngroups
+=
1
groups
[
itarget
]
=
ngroups
-
1
ngroups
+=
1
return
groups
,
ngroups
def
get_group_mask
(
self
):
if
self
.
_group_mask
is
None
:
self
.
_group_mask
=
self
.
make_group_mask
()
return
self
.
_group_mask
class
ProblemConfig
(
Object
):
name_template
=
String
.
T
()
...
...
@@ -1074,17 +1095,7 @@ def analyse(problem, niter=1000, show_progress=False):
wtarget
.
weight
=
1.0
wtargets
.
append
(
wtarget
)
super_group_names
=
set
()
groups
=
num
.
zeros
(
len
(
problem
.
targets
),
dtype
=
num
.
int
)
ngroups
=
0
for
itarget
,
target
in
enumerate
(
problem
.
targets
):
if
target
.
super_group
not
in
super_group_names
:
super_group_names
.
add
(
target
.
super_group
)
ngroups
+=
1
groups
[
itarget
]
=
ngroups
-
1
ngroups
+=
1
groups
,
ngroups
=
problem
.
get_group_mask
()
wproblem
=
problem
.
copy
()
wproblem
.
targets
=
wtargets
...
...
@@ -1094,7 +1105,6 @@ def analyse(problem, niter=1000, show_progress=False):
mss
=
num
.
zeros
((
niter
,
problem
.
ntargets
))
rstate
=
num
.
random
.
RandomState
(
123
)
print
groups
if
show_progress
:
pbar
=
util
.
progressbar
(
'analysing problem'
,
niter
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment