Commit 3be1558e authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

BUG: always return MultEvaluator when requesting evaluator

parent 35e90a44
Pipeline #5288 failed with stage
in 8 minutes and 12 seconds
......@@ -321,9 +321,9 @@ class BasisFunc(object):
:return: an :class:`RequestEvaluator <.evaluator.RequestEvaluator>` that has been set up to quickly evaluate the field
at the requested positions given a vector field.
"""
from .evaluator import get_single_evaluator
from .evaluator import MultEvaluator
with algorithm.set(method=method, override=True):
return get_single_evaluator(self, req)
return MultEvaluator(self, req)
def param_evaluator(self, parameters, method=None, nsim=2**12):
"""Returns a function that evaluates the field at a given location
......@@ -335,10 +335,10 @@ class BasisFunc(object):
:param nsim: number of positions to evaluate simultaneously
:return: function to map the positions to the field
"""
from .evaluator import get_single_evaluator
from .evaluator import MultEvaluator
self.precompute_evaluator()
with algorithm.set(store_matrix=False, method=method):
sim_evaluator = get_single_evaluator(self, request.PositionRequest(np.zeros((nsim, self.ndim))))
sim_evaluator = MultEvaluator(self, request.PositionRequest(np.zeros((nsim, self.ndim))))
if sim_evaluator.use_mat:
warn("Running tractography with pre-computed matrix is very slow")
......
......@@ -464,10 +464,13 @@ class MultEvaluator(object):
assert self.nparams == self.sum_basis.nparams
self.request_list = list(set(self.full_request.flatten()))
self.fixed_field_evaluators = {req: [(bf.get_evaluator(req), params) for bf, params in fixed_field]
for req in self.request_list}
if len(fixed_field) != 0:
self.fixed_field = {req: sp.sum([bf.get_evaluator(req)(params)
for bf, params in fixed_field], 0)
for req in self.request_list}
self.fixed_field = {
req: sp.sum([evaluator(params) for evaluator, params in self.fixed_field_evaluators[req]], 0)
for req in self.request_list
}
self.evaluators = sp.zeros((len(self.basis_list), len(self.request_list)), dtype='object')
for idxb, (basis, _) in enumerate(self.basis_list):
......@@ -629,3 +632,16 @@ class MultEvaluator(object):
@property
def use_mat(self, ):
return any(eval.use_mat for eval in self.evaluators.flatten())
def update_pos(self, new_positions):
if len(self.request_list) > 1:
raise ValueError("Can't update the positions of multiple requests")
for evaluator in self.evaluators.flat:
evaluator.update_pos(new_positions)
for evaluator in self.fixed_field_evaluators[self.request_list[0]]:
evaluator.update_pos(new_positions)
if self.fixed_field:
self.fixed_field = {
req: sp.sum([evaluator(params) for evaluator, params in self.fixed_field_evaluators[req]], 0)
for req in self.request_list
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment