Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Michiel Cottaar
gyral_structure
Commits
3be1558e
Commit
3be1558e
authored
May 03, 2020
by
Michiel Cottaar
Browse files
BUG: always return MultEvaluator when requesting evaluator
parent
35e90a44
Pipeline
#5288
failed with stage
in 8 minutes and 12 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
gyral_structure/basis/core.py
View file @
3be1558e
...
...
@@ -321,9 +321,9 @@ class BasisFunc(object):
:return: an :class:`RequestEvaluator <.evaluator.RequestEvaluator>` that has been set up to quickly evaluate the field
at the requested positions given a vector field.
"""
from
.evaluator
import
get_single_e
valuator
from
.evaluator
import
MultE
valuator
with
algorithm
.
set
(
method
=
method
,
override
=
True
):
return
get_single_e
valuator
(
self
,
req
)
return
MultE
valuator
(
self
,
req
)
def
param_evaluator
(
self
,
parameters
,
method
=
None
,
nsim
=
2
**
12
):
"""Returns a function that evaluates the field at a given location
...
...
@@ -335,10 +335,10 @@ class BasisFunc(object):
:param nsim: number of positions to evaluate simultaneously
:return: function to map the positions to the field
"""
from
.evaluator
import
get_single_e
valuator
from
.evaluator
import
MultE
valuator
self
.
precompute_evaluator
()
with
algorithm
.
set
(
store_matrix
=
False
,
method
=
method
):
sim_evaluator
=
get_single_e
valuator
(
self
,
request
.
PositionRequest
(
np
.
zeros
((
nsim
,
self
.
ndim
))))
sim_evaluator
=
MultE
valuator
(
self
,
request
.
PositionRequest
(
np
.
zeros
((
nsim
,
self
.
ndim
))))
if
sim_evaluator
.
use_mat
:
warn
(
"Running tractography with pre-computed matrix is very slow"
)
...
...
gyral_structure/basis/evaluator.py
View file @
3be1558e
...
...
@@ -464,10 +464,13 @@ class MultEvaluator(object):
assert
self
.
nparams
==
self
.
sum_basis
.
nparams
self
.
request_list
=
list
(
set
(
self
.
full_request
.
flatten
()))
self
.
fixed_field_evaluators
=
{
req
:
[(
bf
.
get_evaluator
(
req
),
params
)
for
bf
,
params
in
fixed_field
]
for
req
in
self
.
request_list
}
if
len
(
fixed_field
)
!=
0
:
self
.
fixed_field
=
{
req
:
sp
.
sum
([
bf
.
get_evaluator
(
req
)(
params
)
for
bf
,
params
in
fixed_field
],
0
)
for
req
in
self
.
request_list
}
self
.
fixed_field
=
{
req
:
sp
.
sum
([
evaluator
(
params
)
for
evaluator
,
params
in
self
.
fixed_field_evaluators
[
req
]],
0
)
for
req
in
self
.
request_list
}
self
.
evaluators
=
sp
.
zeros
((
len
(
self
.
basis_list
),
len
(
self
.
request_list
)),
dtype
=
'object'
)
for
idxb
,
(
basis
,
_
)
in
enumerate
(
self
.
basis_list
):
...
...
@@ -629,3 +632,16 @@ class MultEvaluator(object):
@
property
def
use_mat
(
self
,
):
return
any
(
eval
.
use_mat
for
eval
in
self
.
evaluators
.
flatten
())
def
update_pos
(
self
,
new_positions
):
if
len
(
self
.
request_list
)
>
1
:
raise
ValueError
(
"Can't update the positions of multiple requests"
)
for
evaluator
in
self
.
evaluators
.
flat
:
evaluator
.
update_pos
(
new_positions
)
for
evaluator
in
self
.
fixed_field_evaluators
[
self
.
request_list
[
0
]]:
evaluator
.
update_pos
(
new_positions
)
if
self
.
fixed_field
:
self
.
fixed_field
=
{
req
:
sp
.
sum
([
evaluator
(
params
)
for
evaluator
,
params
in
self
.
fixed_field_evaluators
[
req
]],
0
)
for
req
in
self
.
request_list
}
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment