Commit 42b38735 authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

BUG: fix extent of param_evaluator and run on smaller npos

parent cdccc8b9
Pipeline #5298 failed with stage
in 10 minutes and 17 seconds
......@@ -325,18 +325,19 @@ class BasisFunc(object):
with algorithm.set(method=method, override=True):
return MultEvaluator(self, req)
def param_evaluator(self, parameters, method=None, nsim=2**12):
def param_evaluator(self, parameters, method=None, extent=0, nsim=2**12):
"""Returns a function that evaluates the field at a given location
Used for streamline evaluation.
:param parameters: set of parameters for which the field should be evaluated at different positions
:param method: algorithm to use when computing field
:param nsim: number of positions to evaluate simultaneously
:param extent: maximum extent of request to accept (keep at zero for tractography)
:param nsim: maximum number of positions to evaluate simultaneously
:return: function to map the positions to the field
"""
from .evaluator import MultEvaluator
self.precompute_evaluator()
self.precompute_evaluator(extent)
with algorithm.set(store_matrix=False, method=method):
sim_evaluator = MultEvaluator(self, request.PositionRequest(np.zeros((nsim, self.ndim))))
if sim_evaluator.use_mat:
......@@ -362,7 +363,8 @@ class BasisFunc(object):
for idx in range(0, flat_all_pos.shape[0], nsim):
set_pos = flat_all_pos[idx:idx+nsim, :]
sim_evaluator.update_pos(set_pos)
part_res = sim_evaluator(parameters, inverse=False)[:set_pos.shape[0], :]
part_res = sim_evaluator(parameters, inverse=False)
assert part_res.shape == set_pos.shape
if sim_evaluator.use_cuda and hasattr(part_res, 'get'):
part_res = part_res.get()
flat_res[idx:idx+nsim] = part_res
......@@ -521,22 +523,24 @@ class SumBase(UserList):
"""
self._fixed = {}
def param_evaluator(self, parameters, method=None):
def param_evaluator(self, parameters, method=None, extent=0):
"""Returns a function that evaluates the basis function for arbitrary positions given a fixed parameter array
Used to __call__ streamlines
:param parameters: (nparams, ) array defining the parameters for which the field will be evaluated
:param method: Algorithm used to __call__ the basis functions
:param extent: maximum extent of request to accept (keep at zero for tractography)
:return: function that maps positions to vector field
"""
funcs = []
idx_param = 0
for idx_elem, elem in enumerate(self):
if idx_elem in self._fixed:
funcs.append(elem.param_evaluator(self._fixed[idx_elem], method=method))
funcs.append(elem.param_evaluator(self._fixed[idx_elem], extent=extent, method=method))
else:
funcs.append(elem.param_evaluator(parameters[idx_param:idx_param + elem.nparams], method=method))
funcs.append(elem.param_evaluator(parameters[idx_param:idx_param + elem.nparams],
extent=extent, method=method))
idx_param += elem.nparams
def evaluate(positions):
......
......@@ -163,7 +163,7 @@ class RequestEvaluator(object):
"""If True evaluates the field on the GPU rather than CPU"""
return self.method in (Algorithm.cuda, Algorithm.matrix_cuda)
def update_pos(self, new_positions):
def update_pos(self, new_request):
raise ValueError(f"Updating positions not implemented for {type(self)}")
......@@ -199,8 +199,8 @@ class IdentityEvaluator(RequestEvaluator):
"""
return params
def update_pos(self, new_positions):
self.request = request.PositionRequest(new_positions)
def update_pos(self, new_request):
self.request = new_request
class FuncRequestEvaluator(RequestEvaluator):
......@@ -327,13 +327,13 @@ class FuncRequestEvaluator(RequestEvaluator):
"""
self.results = {}
def update_pos(self, new_positions):
self.request = request.PositionRequest(new_positions)
def update_pos(self, new_request):
self.request = new_request
for _, partial_func in self.partial_func:
if not hasattr(partial_func, 'update_pos'):
self.partial_func = [(1, self.basis.get_func(new_positions, self.method))]
self.partial_func = [(1, self.basis.get_func(new_request.positions, self.method))]
break
partial_func.update_pos(new_positions)
partial_func.update_pos(new_request.positions)
class MatRequestEvaluator(RequestEvaluator):
......@@ -417,8 +417,8 @@ class MatRequestEvaluator(RequestEvaluator):
def wrap_qp(self, qp):
return self.request.wrap_qp(qp, {self.request: self.mat})
def update_pos(self, new_positions):
self.request = request.PositionRequest(new_positions)
def update_pos(self, new_request):
self.request = new_request
self.mat = self.basis.get_full_mat(self.request)
if self.use_cuda:
......@@ -636,10 +636,15 @@ class MultEvaluator(object):
def update_pos(self, new_positions):
if len(self.request_list) > 1:
raise ValueError("Can't update the positions of multiple requests")
new_request = request.PositionRequest(new_positions)
for evaluator in self.evaluators.flat:
evaluator.update_pos(new_positions)
evaluator.update_pos(new_request)
for evaluator in self.fixed_field_evaluators[self.request_list[0]]:
evaluator.update_pos(new_positions)
evaluator.update_pos(new_request)
self.fixed_field_evaluators[new_request] = self.fixed_field_evaluators[self.request_list[0]]
del self.fixed_field_evaluators[self.request_list[0]]
self.request_list[0] = new_request
self.full_request = new_request
if self.fixed_field:
self.fixed_field = {
req: sp.sum([evaluator(params) for evaluator, params in self.fixed_field_evaluators[req]], 0)
......
......@@ -149,16 +149,16 @@ class RadialBasis(BasisFunc):
:return: tuple with the request and centroid indices in compressed format
"""
if self._precomputed_grids is not None:
print('check grid', req.radius(), self._precomputed_grids[0])
print(req)
if self._precomputed_grids is not None and req.radius() <= self._precomputed_grids[0]:
if not hasattr(self, '_ref_list_of_lists'):
self._ref_list_of_lists = np.zeros(req.npos, dtype='object')
empty_arr = np.zeros(0, dtype='i4')
for idx in range(req.npos):
self._ref_list_of_lists[idx] = empty_arr
list_of_lists = self._ref_list_of_lists.copy()
max_size, affine, intersects = self._precomputed_grids
if (req.radius() > max_size).any():
raise ValueError("Precomputed results only deal with maximum request radius of {}, ".format(max_size) +
"but request of {} was found".format(req.radius().max()))
list_of_lists = self._ref_list_of_lists[:req.npos].copy()
_, affine, intersects = self._precomputed_grids
voxels = np.floor(affine[:3, :3].dot(req.center().T).T + affine[:-1, -1]).astype('i4')
use = (voxels >= 0).all(-1) & (voxels < intersects.shape).all(-1)
list_of_lists[use] = intersects[tuple(voxels[use].T)]
......
......@@ -30,7 +30,7 @@ def test_sumbase():
pfull = sp.concatenate(params)
field = sp.randn(*pos.shape)
values = [bf.get_evaluator(req)(pf) for bf, pf in zip(basis_list, params)]
print(req, sp.sum(values, 0) - sb1.get_evaluator(req)(pfull))
print('request', req)
assert (sp.sum(values, 0) == sb1.get_evaluator(req)(pfull)).all()
values_rev = sp.concatenate([bf.get_evaluator(req)(field, inverse=True) for bf in basis_list])
assert values_rev.shape == pfull.shape
......@@ -38,9 +38,10 @@ def test_sumbase():
sb1.fix(1, params[1])
ppart = sp.concatenate((params[0], params[2]))
print(sp.sum(values, 0) - sb1.get_evaluator(req)(ppart))
assert (sp.sum(values, 0) - sb1.get_evaluator(req)(ppart)).max() < 1e-5
assert (sp.sum(values, 0) - sb1.param_evaluator(ppart)(req)).max() < 1e-5
assert (sp.sum(values, 0) - sb1.param_evaluator(
ppart, extent=req.radius() if isinstance(req, request.FieldRequest) else 0
)(req)).max() < 1e-5
for b in basis_list:
if hasattr(b, '_precomputed_grid'):
b._precomputed_grid = None
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment