Commit 0deac625 authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

OPT: avoid argsort by only storing forward indices during tractography

parent 8a8e5e3c
Pipeline #5290 failed with stage
in 8 minutes and 50 seconds
......@@ -604,7 +604,7 @@ __global__ void matrix_mult_nd_invert({dtype} *derparam, {dtype} *derfield, int
RadialGPUArrays.register(self.basis)
RadialGPUArrays.register(self.request)
def update_indices(self, ):
def update_indices(self, only_forward=False):
del self.forward_idx
del self.backward_idx
idx_req, idx_centroids = self.basis.within_range(self.request)
......@@ -612,6 +612,8 @@ __global__ void matrix_mult_nd_invert({dtype} *derparam, {dtype} *derfield, int
forward_idx = idx_centroids# [sp.argsort(idx_req)]; idx_req is always sorted
self.forward_idx = (cuda.to_gpu_correct(forward_compressed),
cuda.to_gpu_correct(forward_idx))
if only_forward:
return
backward_compressed = sp.append(0, sp.cumsum(sp.bincount(idx_centroids, minlength=self.request.npos)))
backward_idx = idx_req[sp.argsort(idx_centroids)]
......@@ -660,5 +662,5 @@ __global__ void matrix_mult_nd_invert({dtype} *derparam, {dtype} *derfield, int
idx = self.req_params_cuda_names.index('all_pos')
self.global_cuda_params[self.request][idx][:new_positions.size] = new_positions.astype(cuda.dtype).flatten()
self.request.positions[:new_positions.shape[0]] = new_positions
self.update_indices()
self.update_indices(only_forward=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment