Commit 7661c6b7 authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

OPT: don't recompile cuda code for every tractography step

Achieved by calling update_pos to update the positions
in the GPU arrays (as well as updating the indices in the
forward and backwards arrays)
parent 9b606f3b
Pipeline #5286 failed with stage
in 8 minutes and 18 seconds
......@@ -14,6 +14,7 @@ from ..algorithm import Algorithm
from .. import request, algorithm
from ..utils import get_filetype
import h5py
from warnings import warn
def wrap_numba(npos, ndim, nparams, scalar=True, sparse=None):
......@@ -320,27 +321,54 @@ class BasisFunc(object):
:return: an :class:`RequestEvaluator <.evaluator.RequestEvaluator>` that has been set up to quickly evaluate the field
at the requested positions given a vector field.
"""
from .evaluator import MultEvaluator
from .evaluator import get_single_evaluator
with algorithm.set(method=method, override=True):
return MultEvaluator(self, req)
return get_single_evaluator(self, req)
def param_evaluator(self, parameters, method=None):
def param_evaluator(self, parameters, method=None, nsim=2**12):
"""Returns a function that evaluates the field at a given location
Used for streamline evaluation.
:param parameters: set of parameters for which the field should be evaluated at different positions
:param method: algorithm to use when computing field
:param nsim: number of positions to evaluate simultaneously
:return: function to map the positions to the field
"""
from .evaluator import get_single_evaluator
self.precompute_evaluator()
with algorithm.set(store_matrix=False, method=method):
sim_evaluator = get_single_evaluator(self, request.PositionRequest(np.zeros((nsim, self.ndim))))
if sim_evaluator.use_mat:
warn("Running tractography with pre-computed matrix is very slow")
def evaluate(req):
"""Evaluate the field at the requested positions
:param req: (npos, ndim) array or object defining where the field should be evaluated
:type req: request.FieldRequest
"""
return self.get_evaluator(req, method=method)(parameters, inverse=False)
self.precompute_evaluator()
if isinstance(req, request.FieldRequest):
all_pos = []
all_weights = []
for pos, weights in req.split():
all_pos.append(pos)
all_weights.append(weights)
else:
all_pos = req
all_pos = np.asanyarray(all_pos)
flat_all_pos = all_pos.reshape((-1, self.ndim))
flat_res = np.zeros(flat_all_pos.shape)
for idx in range(0, flat_all_pos.shape[0], nsim):
set_pos = flat_all_pos[idx:idx+nsim, :]
sim_evaluator.update_pos(set_pos)
flat_res[idx:idx+nsim] = sim_evaluator(parameters, inverse=False)[:set_pos.shape[0], :]
res = flat_res.reshape(all_pos.shape)
if not isinstance(req, request.FieldRequest):
return res
else:
total_res = np.sum([weight * r for weight, r in zip(all_weights, res)], 0)
return total_res
return evaluate
def validate(self, req):
......
......@@ -163,6 +163,9 @@ class RequestEvaluator(object):
"""If True evaluates the field on the GPU rather than CPU"""
return self.method in (Algorithm.cuda, Algorithm.matrix_cuda)
def update_pos(self, new_positions):
raise ValueError(f"Updating positions not implemented for {type(self)}")
class IdentityEvaluator(RequestEvaluator):
"""Returns the parameters irrespective of the basis function
......@@ -196,6 +199,9 @@ class IdentityEvaluator(RequestEvaluator):
"""
return params
def update_pos(self, new_positions):
self.request = request.PositionRequest(new_positions)
class FuncRequestEvaluator(RequestEvaluator):
"""Evaluates the basis function using on-the-fly evaluation of the matrix on the CPU or GPU
......@@ -321,6 +327,14 @@ class FuncRequestEvaluator(RequestEvaluator):
"""
self.results = {}
def update_pos(self, new_positions):
self.request = request.PositionRequest(new_positions)
for _, partial_func in self.partial_func:
if not hasattr(partial_func, 'update_pos'):
self.partial_func = [(1, self.basis.get_func(new_positions, self.method))]
break
partial_func.update_pos(new_positions)
class MatRequestEvaluator(RequestEvaluator):
"""Evaluates the basis function with a pre-defined matrix on CPU or GPU
......@@ -403,6 +417,18 @@ class MatRequestEvaluator(RequestEvaluator):
def wrap_qp(self, qp):
return self.request.wrap_qp(qp, {self.request: self.mat})
def update_pos(self, new_positions):
self.request = request.PositionRequest(new_positions)
self.mat = self.basis.get_full_mat(self.request)
if self.use_cuda:
if sparse.issparse(self.mat):
self.forward_f = cuda.sparse_mult_prepare(self.mat)
self.inverse_f = cuda.sparse_mult_prepare(self.mat.T)
else:
self.forward_f = cuda.dense_mult_prepare(self.mat)
self.inverse_f = cuda.dense_mult_prepare(self.mat.T)
class MultEvaluator(object):
"""
......@@ -599,3 +625,7 @@ class MultEvaluator(object):
@property
def use_cuda(self, ):
return any(eval.use_cuda for eval in self.evaluators.flatten())
@property
def use_mat(self, ):
return any(eval.use_mat for eval in self.evaluators.flatten())
......@@ -584,6 +584,7 @@ __global__ void matrix_mult_nd_invert({dtype} *derparam, {dtype} *derfield, int
self._main_comp(0, 1) + self._main_comp(0, 2) + self._main_comp(1, 2))
self.code = (code_comp + req_code + self._main_code).format(**params)
req_params_cuda = tuple(cuda.to_gpu_correct(req_params[name][1].astype('f8')) for name in draw_names)
self.req_params_cuda_names = draw_names
self.global_cuda_params[self.request] = req_params_cuda
basis_params_cuda = (cuda.to_gpu_correct(basis.centroids.flatten().astype('f8')),
cuda.to_gpu_correct(1. / basis.size))
......@@ -592,21 +593,28 @@ __global__ void matrix_mult_nd_invert({dtype} *derparam, {dtype} *derfield, int
self.forward = cuda.cached_func(self.code, 'matrix_mult_nd')
self.backward = cuda.cached_func(self.code, 'matrix_mult_nd_invert')
self.forward_idx = ()
self.backward_idx = ()
self.update_indices()
RadialGPUArrays.register(self.basis)
RadialGPUArrays.register(self.request)
def update_indices(self, ):
del self.forward_idx
del self.backward_idx
idx_req, idx_centroids = self.basis.within_range(self.request)
forward_compressed = sp.append(0, sp.cumsum(sp.bincount(idx_req, minlength=self.request.npos)))
forward_idx = idx_centroids[sp.argsort(idx_req)]
self.forward_idx = (cuda.to_gpu_correct(forward_compressed),
cuda.to_gpu_correct(forward_idx))
idx_req, idx_centroids = self.basis.within_range(self.request)
backward_compressed = sp.append(0, sp.cumsum(sp.bincount(idx_centroids, minlength=self.request.npos)))
backward_idx = idx_req[sp.argsort(idx_centroids)]
self.backward_idx = (cuda.to_gpu_correct(backward_compressed),
cuda.to_gpu_correct(backward_idx))
RadialGPUArrays.register(self.basis)
RadialGPUArrays.register(self.request)
def evaluate_results(self, params, inverse=False):
"""Runs the basis function for all positions in the request without accessing the results
......@@ -644,3 +652,10 @@ __global__ void matrix_mult_nd_invert({dtype} *derparam, {dtype} *derfield, int
def clean_results(self, ):
RadialGPUArrays.clean()
def update_pos(self, new_positions):
idx = self.req_params_cuda_names.index('all_pos')
self.global_cuda_params[self.request][idx][:new_positions.size] = new_positions.flatten()
self.request.positions[:new_positions.shape[0]] = new_positions
self.update_indices()
......@@ -91,9 +91,12 @@ class Fourier(BasisFunc):
}}
}}
""")
def update_pos(params, new_positions):
params['pos'][:new_positions.shape[0] * self.ndim] = new_positions.flatten()
return cuda.CudaMatrixMult(code, positions.shape[0], positions.shape[1], self.scalar_nparams,
params={'freq': (cuda.dtype + ' *freq', self.frequencies.flatten()),
'pos': (cuda.dtype + ' *pos', positions.flatten())}, scalar=True)
'pos': (cuda.dtype + ' *pos', positions.flatten())}, scalar=True,
update_pos=update_pos)
elif method == Algorithm.numba:
frequencies = self.frequencies
from numba import jit
......@@ -104,7 +107,7 @@ class Fourier(BasisFunc):
Note that this function has to be preceded by @numba.jit.
:param position: (ndim, ) array selecting a point of interest
:param idx_pos: (ndim, ) array selecting a point of interest
:param idx_param: index of the parameter to consider
:param derivative: whether to compute the derivative in a spatial dimension (set to 0, 1, or 2)
"""
......@@ -121,7 +124,9 @@ class Fourier(BasisFunc):
return derphase * np.cos(phase)
else:
return -derphase * np.sin(phase)
return func
super(Fourier, self).get_func(positions, method)
def to_hdf5(self, group: h5py.Group):
......@@ -223,6 +228,7 @@ class ChargeDistribution(BasisFunc):
norm = 1 / (4 * sp.pi * distsq)
tmp_ndim[0] = x * norm
tmp_ndim[1] = y * norm
return func
if method == Algorithm.cuda:
from .. import cuda
......@@ -254,8 +260,10 @@ class ChargeDistribution(BasisFunc):
params = {'charge_pos': (cuda.dtype + ' *charge_pos', charge_pos.flatten()),
'pos': (cuda.dtype + ' *pos', positions.flatten())}
values = {'size_squared': size_squared, 'norm_internal': norm_internal, 'inv_4pi': 1 / (4 * sp.pi)}
def update_pos(params, new_positions):
params['pos'][:new_positions.shape[0] * self.ndim] = new_positions.flatten()
return cuda.CudaMatrixMult(code, positions.shape[0], self.ndim, self.nparams, params=params,
values=values, scalar=False)
values=values, scalar=False, update_pos=update_pos)
super(ChargeDistribution, self).get_func(positions, method)
......
......@@ -280,7 +280,8 @@ def wrap_get_element(code):
class CudaMatrixMult(object):
def __init__(self, code, npos, ndim, nparams, values=None, constants=None, params=None, scalar=True):
def __init__(self, code, npos, ndim, nparams, values=None, constants=None, params=None, scalar=True,
update_pos=None):
"""
Compiles the matrix multiplication GPU kernel
......@@ -292,6 +293,7 @@ class CudaMatrixMult(object):
:param constants: dict of read-only arrays that should be uploaded to the GPU
:param params: dict of read/write arrays that should be uploaded to the GPU
:param scalar: if True each dimension of the vector field is computed independently
:param update_pos: function to update the positions. Gets passed in all the parameter GPU Arrays and the new positions
"""
import pycuda.autoinit
self.part_code = code
......@@ -309,6 +311,9 @@ class CudaMatrixMult(object):
self.param_arrs = []
for name, (identifier, arr) in self.params.items():
self.param_arrs.append(to_gpu_correct(arr))
if update_pos is not None:
self.update_pos = lambda new_positions: update_pos(
{name: arr for name, arr in zip(self.params, self.param_arrs)}, new_positions)
self.func = {}
......
......@@ -55,33 +55,39 @@ def check_computation(basis_func, npos=13, noder=False, nodiv=True):
with algorithm.set(method=method):
map1 = basis_func.get_evaluator(req)
if hasattr(map1, 'method'):
print(map1.method, method)
#print(map1.method, method)
assert map1.method == method
res1 = map1(parameters)
resalt = basis_func.param_evaluator(parameters, method=method)(req)
resalt = basis_func.param_evaluator(parameters, method=method, nsim=npos//2)(req)
if hasattr(basis_func, '_precomputed_grid'):
basis_func._precomputed_grid = None
assert res1.shape == (req.npos, ndim)
print(method, req, res1[0], resalt[0])
#print(method, req, res1[0], resalt[0])
assert sp.median(abs((res1 - resalt) / (res1 + resalt + 1e-4))) < limit
for method2 in basis_func.compute:
if method != method2:
print(f'testing {method2} for {req}')
res2 = basis_func.get_evaluator(req, method=method2)(parameters)
assert res2.shape == (req.npos, ndim)
print(method, res1[0], method2, res2[0])
#print(method, res1[0], method2, res2[0])
assert sp.median(abs((res1 - res2) / (res1 + res2 + 1e-4))) < limit, "%s and %s do not give consistent results (for %s)" % (method, method2, req)
res_param = basis_func.param_evaluator(parameters, method=method2, nsim=npos//2)(req)
print(res2, 'param', npos // 2, res_param)
assert res_param.shape == (req.npos, ndim)
assert sp.median(abs((res1 - res_param) / (res1 + res_param + 1e-4))) < limit
# checking propogation of derivatives
derf = sp.randn(npos, ndim)
derp1 = map1(derf, inverse=True)
print(method, derp1)
#print(method, derp1)
assert derp1.size == basis_func.nparams
for method2 in options:
if method != method2:
derp2 = basis_func.get_evaluator(req, method=method2)(derf, inverse=True)
assert derp1.shape == derp2.shape
print(method, derp1[0], method2, derp2[0])
#print(method, derp1[0], method2, derp2[0])
assert sp.median(abs((derp1 - derp2) / (derp1 + derp2 + 1e-3))) < 1e-3, "%s and %s do not give consistent derivatives" % (method, method2)
if nodiv:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment