Skip to content
Snippets Groups Projects
Commit c06bbe5d authored by Moises Fernandez
Browse files

Join MCMC kernels (simplify code), avoid Static Declaration of arrays (moved...

Join the MCMC kernels (simplifying the code), avoid static declaration of arrays (moved to shared or global memory), and reduce precision in some structures (double to float)
parent 356d0912
No related branches found
No related tags found
No related merge requests found
...@@ -25,10 +25,10 @@ using namespace Xfibres; ...@@ -25,10 +25,10 @@ using namespace Xfibres;
////////////////////////////////////////////////////// //////////////////////////////////////////////////////
void init_Fibres_Multifibres( //INPUT void init_Fibres_Multifibres( //INPUT
thrust::device_vector<double> datam_gpu, thrust::device_vector<float> datam_gpu,
thrust::device_vector<double> params_gpu, thrust::device_vector<float> params_gpu,
thrust::device_vector<float> tau_gpu, thrust::device_vector<float> tau_gpu,
thrust::device_vector<double> bvals_gpu, thrust::device_vector<float> bvals_gpu,
thrust::device_vector<double> alpha_gpu, thrust::device_vector<double> alpha_gpu,
thrust::device_vector<double> beta_gpu, thrust::device_vector<double> beta_gpu,
const int ndirections, const int ndirections,
...@@ -55,28 +55,33 @@ void init_Fibres_Multifibres( //INPUT ...@@ -55,28 +55,33 @@ void init_Fibres_Multifibres( //INPUT
if(opts.modelnum.value()==2) nparams_fit++; if(opts.modelnum.value()==2) nparams_fit++;
if(opts.f0.value()) nparams_fit++; if(opts.f0.value()) nparams_fit++;
thrust::device_vector<double> angtmp_gpu;
angtmp_gpu.resize(nvox*ndirections*nfib);
bool gradnonlin = opts.grad_file.set(); bool gradnonlin = opts.grad_file.set();
int blocks = nvox; int blocks = nvox;
dim3 Dim_Grid_MCMC(blocks, 1); dim3 Dim_Grid_MCMC(blocks, 1);
dim3 Dim_Block_MCMC(THREADS_BLOCK_MCMC ,1); ///dimensions for MCMC dim3 Dim_Block_MCMC(THREADS_BLOCK_MCMC ,1); ///dimensions for MCMC
double *datam_ptr = thrust::raw_pointer_cast(datam_gpu.data()); float *datam_ptr = thrust::raw_pointer_cast(datam_gpu.data());
double *params_ptr = thrust::raw_pointer_cast(params_gpu.data()); float *params_ptr = thrust::raw_pointer_cast(params_gpu.data());
float *tau_ptr = thrust::raw_pointer_cast(tau_gpu.data()); float *tau_ptr = thrust::raw_pointer_cast(tau_gpu.data());
double *bvals_ptr = thrust::raw_pointer_cast(bvals_gpu.data()); float *bvals_ptr = thrust::raw_pointer_cast(bvals_gpu.data());
double *alpha_ptr = thrust::raw_pointer_cast(alpha_gpu.data()); double *alpha_ptr = thrust::raw_pointer_cast(alpha_gpu.data());
double *beta_ptr = thrust::raw_pointer_cast(beta_gpu.data()); double *beta_ptr = thrust::raw_pointer_cast(beta_gpu.data());
FibreGPU *fibres_ptr = thrust::raw_pointer_cast(fibres_gpu.data()); FibreGPU *fibres_ptr = thrust::raw_pointer_cast(fibres_gpu.data());
MultifibreGPU *multifibres_ptr = thrust::raw_pointer_cast(multifibres_gpu.data()); MultifibreGPU *multifibres_ptr = thrust::raw_pointer_cast(multifibres_gpu.data());
double *signals_ptr = thrust::raw_pointer_cast(signals_gpu.data()); double *signals_ptr = thrust::raw_pointer_cast(signals_gpu.data());
double *isosignals_ptr = thrust::raw_pointer_cast(isosignals_gpu.data()); double *isosignals_ptr = thrust::raw_pointer_cast(isosignals_gpu.data());
double *angtmp_ptr = thrust::raw_pointer_cast(angtmp_gpu.data());
int amount_shared = (THREADS_BLOCK_MCMC+1)*sizeof(double) + (3*nfib + 8)*sizeof(float); int amount_shared = (THREADS_BLOCK_MCMC)*sizeof(double) + (3*nfib + 8)*sizeof(float) + sizeof(int);
myfile << "Shared Memory Used in init_Fibres_Multifibres: " << amount_shared << "\n"; myfile << "Shared Memory Used in init_Fibres_Multifibres: " << amount_shared << "\n";
init_Fibres_Multifibres_kernel<<< Dim_Grid_MCMC, Dim_Block_MCMC, amount_shared>>>(datam_ptr, params_ptr, tau_ptr, bvals_ptr, alpha_ptr, beta_ptr, ndirections, nfib, nparams_fit, opts.modelnum.value(), opts.fudge.value(), opts.f0.value(), opts.rician.value(), opts.ardf0.value(), opts.all_ard.value(), opts.no_ard.value(), gradnonlin, fibres_ptr, multifibres_ptr, signals_ptr, isosignals_ptr); init_Fibres_Multifibres_kernel<<< Dim_Grid_MCMC, Dim_Block_MCMC, amount_shared>>>(datam_ptr, params_ptr, tau_ptr, bvals_ptr, alpha_ptr, beta_ptr, ndirections, nfib, nparams_fit, opts.modelnum.value(), opts.fudge.value(), opts.f0.value(), opts.rician.value(), opts.ardf0.value(), opts.all_ard.value(), opts.no_ard.value(), gradnonlin, angtmp_ptr, fibres_ptr, multifibres_ptr, signals_ptr, isosignals_ptr);
sync_check("init_Fibres_Multifibres_kernel"); sync_check("init_Fibres_Multifibres_kernel");
gettimeofday(&t2,NULL); gettimeofday(&t2,NULL);
...@@ -87,8 +92,8 @@ void init_Fibres_Multifibres( //INPUT ...@@ -87,8 +92,8 @@ void init_Fibres_Multifibres( //INPUT
} }
void runmcmc_burnin( //INPUT void runmcmc_burnin( //INPUT
thrust::device_vector<double> datam_gpu, thrust::device_vector<float> datam_gpu,
thrust::device_vector<double> bvals_gpu, thrust::device_vector<float> bvals_gpu,
thrust::device_vector<double> alpha_gpu, thrust::device_vector<double> alpha_gpu,
thrust::device_vector<double> beta_gpu, thrust::device_vector<double> beta_gpu,
const int ndirections, const int ndirections,
...@@ -126,6 +131,19 @@ void runmcmc_burnin( //INPUT ...@@ -126,6 +131,19 @@ void runmcmc_burnin( //INPUT
else nparams=2+nfib*3; else nparams=2+nfib*3;
if(opts.modelnum.value()==2) nparams++; if(opts.modelnum.value()==2) nparams++;
if(opts.rician.value()) nparams++; if(opts.rician.value()) nparams++;
thrust::device_vector<float> recors_null_gpu;
recors_null_gpu.resize(1);
thrust::device_vector<double> angtmp_gpu;
thrust::device_vector<double> oldangtmp_gpu;
thrust::device_vector<double> oldsignals_gpu;
thrust::device_vector<double> oldisosignals_gpu;
angtmp_gpu.resize(nvox*ndirections*nfib);
oldangtmp_gpu.resize(nvox*ndirections);
oldsignals_gpu.resize(nvox*ndirections*nfib);
oldisosignals_gpu.resize(nvox*ndirections);
unsigned int totalrandoms=(opts.nburn.value() * nvox * nparams); unsigned int totalrandoms=(opts.nburn.value() * nvox * nparams);
...@@ -187,8 +205,8 @@ void runmcmc_burnin( //INPUT ...@@ -187,8 +205,8 @@ void runmcmc_burnin( //INPUT
curandSetPseudoRandomGeneratorSeed(gen,seed); curandSetPseudoRandomGeneratorSeed(gen,seed);
//get pointers //get pointers
double *datam_ptr = thrust::raw_pointer_cast(datam_gpu.data()); float *datam_ptr = thrust::raw_pointer_cast(datam_gpu.data());
double *bvals_ptr = thrust::raw_pointer_cast(bvals_gpu.data()); float *bvals_ptr = thrust::raw_pointer_cast(bvals_gpu.data());
double *alpha_ptr = thrust::raw_pointer_cast(alpha_gpu.data()); double *alpha_ptr = thrust::raw_pointer_cast(alpha_gpu.data());
double *beta_ptr = thrust::raw_pointer_cast(beta_gpu.data()); double *beta_ptr = thrust::raw_pointer_cast(beta_gpu.data());
float *randomsN_ptr = thrust::raw_pointer_cast(randomsN_gpu.data()); float *randomsN_ptr = thrust::raw_pointer_cast(randomsN_gpu.data());
...@@ -198,7 +216,14 @@ void runmcmc_burnin( //INPUT ...@@ -198,7 +216,14 @@ void runmcmc_burnin( //INPUT
double *signals_ptr = thrust::raw_pointer_cast(signals_gpu.data()); double *signals_ptr = thrust::raw_pointer_cast(signals_gpu.data());
double *isosignals_ptr = thrust::raw_pointer_cast(isosignals_gpu.data()); double *isosignals_ptr = thrust::raw_pointer_cast(isosignals_gpu.data());
int amount_shared = (THREADS_BLOCK_MCMC+1)*sizeof(double) + (10*nfib + 2*nparams + 24)*sizeof(float) + (7*nfib + 16)*sizeof(int); double *angtmp_ptr = thrust::raw_pointer_cast(angtmp_gpu.data());
double *oldangtmp_ptr = thrust::raw_pointer_cast(oldangtmp_gpu.data());
double *oldsignals_ptr = thrust::raw_pointer_cast(oldsignals_gpu.data());
double *oldisosignals_ptr = thrust::raw_pointer_cast(oldisosignals_gpu.data());
float *records_null = thrust::raw_pointer_cast(recors_null_gpu.data());
int amount_shared = (THREADS_BLOCK_MCMC)*sizeof(double) + (10*nfib + 2*nparams + 24)*sizeof(float) + (7*nfib + 19)*sizeof(int);
myfile << "Shared Memory Used in runmcmc_burnin: " << amount_shared << "\n"; myfile << "Shared Memory Used in runmcmc_burnin: " << amount_shared << "\n";
...@@ -224,7 +249,7 @@ void runmcmc_burnin( //INPUT ...@@ -224,7 +249,7 @@ void runmcmc_burnin( //INPUT
gettimeofday(&t1,NULL); gettimeofday(&t1,NULL);
runmcmc_burnin_kernel<<< Dim_Grid, Dim_Block, amount_shared >>>(datam_ptr, bvals_ptr, alpha_ptr, beta_ptr, randomsN_ptr, randomsU_ptr, ndirections, nfib, nparams, opts.modelnum.value(), opts.fudge.value(), opts.f0.value(), opts.ardf0.value(), !opts.no_ard.value(), opts.rician.value(), gradnonlin, opts.updateproposalevery.value(), iters_step, (i*iters_step), fibres_ptr, multifibres_ptr, signals_ptr, isosignals_ptr); runmcmc_kernel<<< Dim_Grid, Dim_Block, amount_shared >>>(datam_ptr, bvals_ptr, alpha_ptr, beta_ptr, randomsN_ptr, randomsU_ptr, ndirections, nfib, nparams, opts.modelnum.value(), opts.fudge.value(), opts.f0.value(), opts.ardf0.value(), !opts.no_ard.value(), opts.rician.value(), gradnonlin, opts.updateproposalevery.value(), iters_step, (i*iters_step), 0, 0, 0, oldsignals_ptr, oldisosignals_ptr, angtmp_ptr, oldangtmp_ptr, fibres_ptr, multifibres_ptr, signals_ptr, isosignals_ptr,records_null,records_null,records_null,records_null,records_null,records_null,records_null, records_null);
sync_check("runmcmc_burnin_kernel"); sync_check("runmcmc_burnin_kernel");
gettimeofday(&t2,NULL); gettimeofday(&t2,NULL);
...@@ -254,7 +279,7 @@ void runmcmc_burnin( //INPUT ...@@ -254,7 +279,7 @@ void runmcmc_burnin( //INPUT
gettimeofday(&t1,NULL); gettimeofday(&t1,NULL);
if(nvox!=0){ if(nvox!=0){
runmcmc_burnin_kernel<<< Dim_Grid, Dim_Block, amount_shared >>>(datam_ptr, bvals_ptr, alpha_ptr, beta_ptr, randomsN_ptr, randomsU_ptr, ndirections, nfib, nparams, opts.modelnum.value(), opts.fudge.value(), opts.f0.value(), opts.ardf0.value(), !opts.no_ard.value(), opts.rician.value(), gradnonlin, opts.updateproposalevery.value(), last_step, (steps*iters_step), fibres_ptr, multifibres_ptr, signals_ptr, isosignals_ptr); runmcmc_kernel<<< Dim_Grid, Dim_Block, amount_shared >>>(datam_ptr, bvals_ptr, alpha_ptr, beta_ptr, randomsN_ptr, randomsU_ptr, ndirections, nfib, nparams, opts.modelnum.value(), opts.fudge.value(), opts.f0.value(), opts.ardf0.value(), !opts.no_ard.value(), opts.rician.value(), gradnonlin, opts.updateproposalevery.value(), last_step, (steps*iters_step), 0, 0, 0, oldsignals_ptr, oldisosignals_ptr, angtmp_ptr, oldangtmp_ptr, fibres_ptr, multifibres_ptr, signals_ptr, isosignals_ptr,records_null,records_null,records_null,records_null,records_null,records_null, records_null,records_null);
sync_check("runmcmc_burnin_kernel"); sync_check("runmcmc_burnin_kernel");
} }
...@@ -277,8 +302,8 @@ void runmcmc_burnin( //INPUT ...@@ -277,8 +302,8 @@ void runmcmc_burnin( //INPUT
void runmcmc_record( //INPUT void runmcmc_record( //INPUT
thrust::device_vector<double> datam_gpu, thrust::device_vector<float> datam_gpu,
thrust::device_vector<double> bvals_gpu, thrust::device_vector<float> bvals_gpu,
thrust::device_vector<double> alpha_gpu, thrust::device_vector<double> alpha_gpu,
thrust::device_vector<double> beta_gpu, thrust::device_vector<double> beta_gpu,
thrust::device_vector<FibreGPU> fibres_gpu, thrust::device_vector<FibreGPU> fibres_gpu,
...@@ -326,6 +351,16 @@ void runmcmc_record( //INPUT ...@@ -326,6 +351,16 @@ void runmcmc_record( //INPUT
else nparams=2+nfib*3; else nparams=2+nfib*3;
if(opts.modelnum.value()==2) nparams++; if(opts.modelnum.value()==2) nparams++;
if(opts.rician.value()) nparams++; if(opts.rician.value()) nparams++;
thrust::device_vector<double> angtmp_gpu;
thrust::device_vector<double> oldangtmp_gpu;
thrust::device_vector<double> oldsignals_gpu;
thrust::device_vector<double> oldisosignals_gpu;
angtmp_gpu.resize(nvox*ndirections*nfib);
oldangtmp_gpu.resize(nvox*ndirections);
oldsignals_gpu.resize(nvox*ndirections*nfib);
oldisosignals_gpu.resize(nvox*ndirections);
unsigned int totalrandoms=(opts.njumps.value() * nvox * nparams); unsigned int totalrandoms=(opts.njumps.value() * nvox * nparams);
...@@ -387,8 +422,8 @@ void runmcmc_record( //INPUT ...@@ -387,8 +422,8 @@ void runmcmc_record( //INPUT
curandSetPseudoRandomGeneratorSeed(gen,seed); curandSetPseudoRandomGeneratorSeed(gen,seed);
//get pointers //get pointers
double *datam_ptr = thrust::raw_pointer_cast(datam_gpu.data()); float *datam_ptr = thrust::raw_pointer_cast(datam_gpu.data());
double *bvals_ptr = thrust::raw_pointer_cast(bvals_gpu.data()); float *bvals_ptr = thrust::raw_pointer_cast(bvals_gpu.data());
double *alpha_ptr = thrust::raw_pointer_cast(alpha_gpu.data()); double *alpha_ptr = thrust::raw_pointer_cast(alpha_gpu.data());
double *beta_ptr = thrust::raw_pointer_cast(beta_gpu.data()); double *beta_ptr = thrust::raw_pointer_cast(beta_gpu.data());
float *randomsN_ptr = thrust::raw_pointer_cast(randomsN_gpu.data()); float *randomsN_ptr = thrust::raw_pointer_cast(randomsN_gpu.data());
...@@ -397,6 +432,11 @@ void runmcmc_record( //INPUT ...@@ -397,6 +432,11 @@ void runmcmc_record( //INPUT
MultifibreGPU *multifibres_ptr = thrust::raw_pointer_cast(multifibres_gpu.data()); MultifibreGPU *multifibres_ptr = thrust::raw_pointer_cast(multifibres_gpu.data());
double *signals_ptr = thrust::raw_pointer_cast(signals_gpu.data()); double *signals_ptr = thrust::raw_pointer_cast(signals_gpu.data());
double *isosignals_ptr = thrust::raw_pointer_cast(isosignals_gpu.data()); double *isosignals_ptr = thrust::raw_pointer_cast(isosignals_gpu.data());
double *angtmp_ptr = thrust::raw_pointer_cast(angtmp_gpu.data());
double *oldangtmp_ptr = thrust::raw_pointer_cast(oldangtmp_gpu.data());
double *oldsignals_ptr = thrust::raw_pointer_cast(oldsignals_gpu.data());
double *oldisosignals_ptr = thrust::raw_pointer_cast(oldisosignals_gpu.data());
float *rf0_ptr = thrust::raw_pointer_cast(rf0_gpu.data()); float *rf0_ptr = thrust::raw_pointer_cast(rf0_gpu.data());
float *rtau_ptr = thrust::raw_pointer_cast(rtau_gpu.data()); float *rtau_ptr = thrust::raw_pointer_cast(rtau_gpu.data());
...@@ -407,7 +447,7 @@ void runmcmc_record( //INPUT ...@@ -407,7 +447,7 @@ void runmcmc_record( //INPUT
float *rph_ptr = thrust::raw_pointer_cast(rph_gpu.data()); float *rph_ptr = thrust::raw_pointer_cast(rph_gpu.data());
float *rf_ptr = thrust::raw_pointer_cast(rf_gpu.data()); float *rf_ptr = thrust::raw_pointer_cast(rf_gpu.data());
int amount_shared = (THREADS_BLOCK_MCMC+1)*sizeof(double) + (10*nfib + 2*nparams + 24)*sizeof(float) + (7*nfib + 18)*sizeof(int); int amount_shared = (THREADS_BLOCK_MCMC)*sizeof(double) + (10*nfib + 2*nparams + 24)*sizeof(float) + (7*nfib + 19)*sizeof(int);
myfile << "Shared Memory Used in runmcmc_record: " << amount_shared << "\n"; myfile << "Shared Memory Used in runmcmc_record: " << amount_shared << "\n";
...@@ -433,7 +473,7 @@ void runmcmc_record( //INPUT ...@@ -433,7 +473,7 @@ void runmcmc_record( //INPUT
gettimeofday(&t1,NULL); gettimeofday(&t1,NULL);
runmcmc_record_kernel<<< Dim_Grid, Dim_Block, amount_shared >>>(datam_ptr, bvals_ptr, alpha_ptr, beta_ptr, fibres_ptr, multifibres_ptr, signals_ptr, isosignals_ptr, randomsN_ptr, randomsU_ptr, ndirections, nfib, nparams, opts.modelnum.value(), opts.fudge.value(), opts.f0.value(), opts.ardf0.value(), !opts.no_ard.value(), opts.rician.value(), gradnonlin, opts.updateproposalevery.value(), iters_step, (i*iters_step), opts.nburn.value(), opts.sampleevery.value(), totalrecords, rf0_ptr, rtau_ptr, rs0_ptr, rd_ptr, rdstd_ptr, rth_ptr, rph_ptr, rf_ptr); runmcmc_kernel<<< Dim_Grid, Dim_Block, amount_shared >>>(datam_ptr, bvals_ptr, alpha_ptr, beta_ptr, randomsN_ptr, randomsU_ptr, ndirections, nfib, nparams, opts.modelnum.value(), opts.fudge.value(), opts.f0.value(), opts.ardf0.value(), !opts.no_ard.value(), opts.rician.value(), gradnonlin, opts.updateproposalevery.value(), iters_step, (i*iters_step), opts.nburn.value(), opts.sampleevery.value(), totalrecords, oldsignals_ptr, oldisosignals_ptr, angtmp_ptr, oldangtmp_ptr, fibres_ptr, multifibres_ptr, signals_ptr, isosignals_ptr, rf0_ptr, rtau_ptr, rs0_ptr, rd_ptr, rdstd_ptr, rth_ptr, rph_ptr, rf_ptr);
sync_check("runmcmc_record_kernel"); sync_check("runmcmc_record_kernel");
gettimeofday(&t2,NULL); gettimeofday(&t2,NULL);
...@@ -463,7 +503,7 @@ void runmcmc_record( //INPUT ...@@ -463,7 +503,7 @@ void runmcmc_record( //INPUT
gettimeofday(&t1,NULL); gettimeofday(&t1,NULL);
if(nvox!=0){ if(nvox!=0){
runmcmc_record_kernel<<< Dim_Grid, Dim_Block, amount_shared >>>(datam_ptr, bvals_ptr, alpha_ptr, beta_ptr, fibres_ptr, multifibres_ptr, signals_ptr, isosignals_ptr, randomsN_ptr, randomsU_ptr, ndirections, nfib, nparams, opts.modelnum.value(), opts.fudge.value(), opts.f0.value(), opts.ardf0.value(), !opts.no_ard.value(), opts.rician.value(), gradnonlin, opts.updateproposalevery.value(), last_step, (steps*iters_step), opts.nburn.value(), opts.sampleevery.value(), totalrecords,rf0_ptr, rtau_ptr, rs0_ptr, rd_ptr, rdstd_ptr, rth_ptr, rph_ptr, rf_ptr); runmcmc_kernel<<< Dim_Grid, Dim_Block, amount_shared >>>(datam_ptr, bvals_ptr, alpha_ptr, beta_ptr,randomsN_ptr, randomsU_ptr, ndirections, nfib, nparams, opts.modelnum.value(), opts.fudge.value(), opts.f0.value(), opts.ardf0.value(), !opts.no_ard.value(), opts.rician.value(), gradnonlin, opts.updateproposalevery.value(), last_step, (steps*iters_step), opts.nburn.value(), opts.sampleevery.value(), totalrecords, oldsignals_ptr, oldisosignals_ptr, angtmp_ptr, oldangtmp_ptr, fibres_ptr, multifibres_ptr, signals_ptr, isosignals_ptr, rf0_ptr, rtau_ptr, rs0_ptr, rd_ptr, rdstd_ptr, rth_ptr, rph_ptr, rf_ptr);
sync_check("runmcmc_record_kernel"); sync_check("runmcmc_record_kernel");
} }
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment