diff --git a/xfibres.cc b/xfibres.cc index 812aad0bff581c5fca2cc44781f255ca91276623..3ee03f81dc817a5cb7e54d7454a4822d6b18c45e 100644 --- a/xfibres.cc +++ b/xfibres.cc @@ -1,6 +1,6 @@ /* Xfibres Diffusion Partial Volume Model - Tim Behrens - FMRIB Image Analysis Group + Tim Behrens, Saad Jbabdi - FMRIB Image Analysis Group Copyright (C) 2005 University of Oxford */ @@ -17,10 +17,12 @@ #include "utils/tracer_plus.h" #include "miscmaths/miscprob.h" #include "miscmaths/miscmaths.h" +#include "miscmaths/nonlin.h" #include "newimage/newimageall.h" #include "stdlib.h" #include "fibre.h" #include "xfibresoptions.h" +#include "diffmodels.h" using namespace FIBRE; using namespace Xfibres; @@ -30,53 +32,9 @@ using namespace NEWIMAGE; using namespace MISCMATHS; - - -inline float min(float a,float b){ - return a<b ? a:b;} -inline float max(float a,float b){ - return a>b ? a:b;} -inline Matrix Anis() -{ - Matrix A(3,3); - A << 1 << 0 << 0 - << 0 << 0 << 0 - << 0 << 0 << 0; - return A; -} - -inline Matrix Is() -{ - Matrix I(3,3); - I << 1 << 0 << 0 - << 0 << 1 << 0 - << 0 << 0 << 1; - return I; -} - -inline ColumnVector Cross(const ColumnVector& A,const ColumnVector& B) -{ - ColumnVector res(3); - res << A(2)*B(3)-A(3)*B(2) - << A(3)*B(1)-A(1)*B(3) - << A(1)*B(2)-B(1)*A(2); - return res; -} - -inline Matrix Cross(const Matrix& A,const Matrix& B) -{ - Matrix res(3,1); - res << A(2,1)*B(3,1)-A(3,1)*B(2,1) - << A(3,1)*B(1,1)-A(1,1)*B(3,1) - << A(1,1)*B(2,1)-B(1,1)*A(2,1); - return res; -} - -float mod(float a, float b){ - while(a>b){a=a-b;} - while(a<0){a=a+b;} - return a; -} +//////////////////////////////////////////////// +// Some USEFUL FUNCTIONS +//////////////////////////////////////////////// Matrix form_Amat(const Matrix& r,const Matrix& b) @@ -111,9 +69,15 @@ inline SymmetricMatrix vec2tens(ColumnVector& Vec){ +//////////////////////////////////////////// +// MCMC SAMPLE STORAGE +//////////////////////////////////////////// + + class Samples{ xfibresOptions& opts; Matrix m_dsamples; + Matrix m_d_stdsamples; Matrix m_S0samples; Matrix m_lik_energy; @@ -129,12 +93,14 @@ class Samples{ //for storing means RowVector m_mean_dsamples; + RowVector m_mean_d_stdsamples; RowVector m_mean_S0samples; vector<Matrix> m_dyadic_vectors; vector<RowVector> m_mean_fsamples; vector<RowVector> m_mean_lamsamples; float m_sum_d; + float m_sum_d_std; float m_sum_S0; vector<SymmetricMatrix> m_dyad; vector<float> m_sum_f; @@ -185,6 +151,16 @@ public: tmpvecs=0; m_sum_d=0; m_sum_S0=0; + + if(opts.modelnum.value()==2){ + m_d_stdsamples.ReSize(nsamples,nvoxels); + m_d_stdsamples=0; + m_mean_d_stdsamples.ReSize(nvoxels); + m_mean_d_stdsamples=0; + m_sum_d_std=0; + } + + SymmetricMatrix tmpdyad(3); tmpdyad=0; m_nsamps=nsamples; @@ -211,6 +187,10 @@ public: void record(Multifibre& mfib, int vox, int samp){ m_dsamples(samp,vox)=mfib.get_d(); m_sum_d+=mfib.get_d(); + if(opts.modelnum.value()==2){ + m_d_stdsamples(samp,vox)=mfib.get_d_std(); + m_sum_d_std+=mfib.get_d_std(); + } m_S0samples(samp,vox)=mfib.get_S0(); m_sum_S0+=mfib.get_S0(); m_lik_energy(samp,vox)=mfib.get_likelihood_energy(); @@ -237,10 +217,14 @@ public: void finish_voxel(int vox){ m_mean_dsamples(vox)=m_sum_d/m_nsamps; + if(opts.modelnum.value()==2) + m_mean_d_stdsamples(vox)=m_sum_d_std/m_nsamps; m_mean_S0samples(vox)=m_sum_S0/m_nsamps; m_sum_d=0; m_sum_S0=0; + if(opts.modelnum.value()==2) + m_sum_d_std=0; DiagonalMatrix dyad_D; //eigenvalues Matrix dyad_V; //eigenvectors @@ -353,8 +337,24 @@ public: mean_fsamples_out.push_back(m_mean_fsamples[f]); Log& logger = LogSingleton::getInstance(); - tmp.setmatrix(m_mean_dsamples,mask); - save_volume4D(tmp,logger.appendDir("mean_dsamples")); + if(opts.modelnum.value()==1){ + tmp.setmatrix(m_mean_dsamples,mask); + save_volume4D(tmp,logger.appendDir("mean_dsamples")); + } + else if(opts.modelnum.value()==2){ + tmp.setmatrix(m_mean_dsamples,mask); + save_volume4D(tmp,logger.appendDir("mean_dsamples")); + tmp.setmatrix(m_mean_d_stdsamples,mask); + save_volume4D(tmp,logger.appendDir("mean_d_stdsamples")); + + tmp.setmatrix(m_dsamples,mask); + save_volume4D(tmp,logger.appendDir("dsamples")); + tmp.setmatrix(m_d_stdsamples,mask); + save_volume4D(tmp,logger.appendDir("d_stdsamples")); + + + +} tmp.setmatrix(m_mean_S0samples,mask); save_volume4D(tmp,logger.appendDir("mean_S0samples")); //tmp.setmatrix(m_lik_energy,mask); @@ -423,10 +423,9 @@ public: }; - - - - +//////////////////////////////////////////// +// MCMC HANDLING +//////////////////////////////////////////// @@ -439,89 +438,54 @@ class xfibresVoxelManager{ const ColumnVector m_data; const ColumnVector& m_alpha; const ColumnVector& m_beta; + const Matrix& m_bvecs; const Matrix& m_bvals; Multifibre m_multifibre; public: xfibresVoxelManager(const ColumnVector& data,const ColumnVector& alpha, - const ColumnVector& beta, const Matrix& b, + const ColumnVector& beta, const Matrix& r,const Matrix& b, Samples& samples,int voxelnumber): opts(xfibresOptions::getInstance()), m_samples(samples),m_voxelnumber(voxelnumber),m_data(data), - m_alpha(alpha), m_beta(beta), m_bvals(b), - m_multifibre(m_data,m_alpha,m_beta,m_bvals,opts.nfibres.value(),opts.fudge.value()){ } + m_alpha(alpha), m_beta(beta), m_bvecs(r), m_bvals(b), + m_multifibre(m_data,m_alpha,m_beta,m_bvals,opts.nfibres.value(),opts.fudge.value(),opts.modelnum.value()){ } void initialise(const Matrix& Amat){ - if(!opts.localinit.value()){ - if(!m_samples.neighbour_initialise(m_voxelnumber,m_multifibre)){ - initialise_tensor(Amat); - } - }else{ - initialise_tensor(Amat); + if(opts.nonlin.value()) + initialise_nonlin(); + else{ + if(!opts.localinit.value()) + if(!m_samples.neighbour_initialise(m_voxelnumber,m_multifibre)) + initialise_tensor(Amat); + else + initialise_tensor(Amat); } m_multifibre.initialise_energies(); m_multifibre.initialise_props(); } + void initialise_tensor(const Matrix& Amat){ - //initialising - ColumnVector logS(m_data.Nrows()),tmp(m_data.Nrows()),Dvec(7),dir(3); - SymmetricMatrix tens; - DiagonalMatrix Dd; - Matrix Vd; - float mDd,fsquared; - float th,ph,f,D,S0; - for ( int i = 1; i <= logS.Nrows(); i++) - { - if(m_data(i)>0){ - logS(i)=log(m_data(i)); - } - else{ - logS(i)=0; - } - } - - Dvec = -pinv(Amat)*logS; - - if( Dvec(7) > -maxlogfloat ){ - S0=exp(-Dvec(7)); + DTI dti(m_data,Amat); + dti.fit(); + + float D = dti.get_md(); + if(opts.modelnum.value()==1){ + if(D<=0) D=2e-3; + m_multifibre.set_d(D); } - else{ - S0=m_data.MaximumAbsoluteValue(); + if(opts.modelnum.value()==2){ + D=D*2; //Will significantly underestimate D using mono-exponential tensor model, so initialise with 2*D; + if(D<=0) D=2e-3; + m_multifibre.set_d_std(D);//initialise with assumption that std=mean. + m_multifibre.set_d(D); } - - for ( int i = 1; i <= logS.Nrows(); i++) - { - if(S0<m_data.Sum()/m_data.Nrows()){ S0=m_data.MaximumAbsoluteValue(); } - logS(i)=(m_data(i)/S0)>0.01 ? log(m_data(i)):log(0.01*S0); - } + m_multifibre.set_S0(dti.get_s0()); - Dvec = -pinv(Amat)*logS; - S0=exp(-Dvec(7)); - - if(S0<m_data.Sum()/m_data.Nrows()){ S0=m_data.Sum()/m_data.Nrows(); } - tens = vec2tens(Dvec); - EigenValues(tens,Dd,Vd); - mDd = Dd.Sum()/Dd.Nrows(); - int maxind = Dd(1) > Dd(2) ? 1:2; //finding maximum eigenvalue - maxind = Dd(maxind) > Dd(3) ? maxind:3; - dir << Vd(1,maxind) << Vd(2,maxind) << Vd(3,maxind); - cart2sph(dir,th,ph); - th= mod(th,M_PI); - ph= mod(ph,2*M_PI); - D = Dd(maxind); - - float numer=1.5*((Dd(1)-mDd)*(Dd(1)-mDd)+(Dd(2)-mDd)*(Dd(2)-mDd)+(Dd(3)-mDd)*(Dd(3)-mDd)); - float denom=(Dd(1)*Dd(1)+Dd(2)*Dd(2)+Dd(3)*Dd(3)); - if(denom>0) fsquared=numer/denom; - else fsquared=0; - if(fsquared>0){f=sqrt(fsquared);} - else{f=0;} - if(f>=0.95) f=0.95; - if(f<=0.001) f=0.001; - if(D<=0) D=2e-3; - m_multifibre.set_d(D); - m_multifibre.set_S0(S0); + float th,ph,f; + cart2sph(dti.get_v1(),th,ph); + f = dti.get_fa(); if(opts.nfibres.value()>0){ // m_multifibre.addfibre(th,ph,f,1,false);//no a.r.d. on first fibre m_multifibre.addfibre(th,ph,f,1,opts.all_ard.value());//if all_ard, then turn ard on here (SJ) @@ -535,6 +499,77 @@ class xfibresVoxelManager{ } + void initialise_nonlin(){ + ////////////////////////////////////////////////////// + // where using mono-exponential model + if(opts.modelnum.value()==1){ + PVM_single pvm(m_data,m_bvecs,m_bvals,opts.nfibres.value()); + pvm.fit(); // this will give th,ph,f in the correct order + + m_multifibre.set_S0(pvm.get_s0()); + if(pvm.get_d()>=0) + m_multifibre.set_d(pvm.get_d()); + else + m_multifibre.set_d(2e-3); + + ColumnVector pvmf,pvmth,pvmph; + pvmf = pvm.get_f(); + pvmth = pvm.get_th(); + pvmph = pvm.get_ph(); + + if(opts.nfibres.value()>0){ + m_multifibre.addfibre(pvmth(1), + pvmph(1), + pvmf(1), + 1,opts.all_ard.value());//if all_ard, then turn ard on here (SJ) + for(int i=2; i<=opts.nfibres.value();i++){ + m_multifibre.addfibre(pvmth(i), + pvmph(i), + pvmf(i), + 1,!opts.no_ard.value()); + } + } + } + else{ + ////////////////////////////////////////////////////// + // model 2 : non-mono-exponential + PVM_multi pvm(m_data,m_bvecs,m_bvals,opts.nfibres.value()); + pvm.fit(); + + m_multifibre.set_S0(pvm.get_s0()); + if(pvm.get_d()>=0) + m_multifibre.set_d(pvm.get_d()); + else + m_multifibre.set_d(2e-3); + if(pvm.get_d_std()>=0) + m_multifibre.set_d_std(pvm.get_d()); + else + m_multifibre.set_d(2e-3); + + ColumnVector pvmf,pvmth,pvmph; + pvmf = pvm.get_f(); + pvmth = pvm.get_th(); + pvmph = pvm.get_ph(); + + if(opts.nfibres.value()>0){ + m_multifibre.addfibre(pvmth(1), + pvmph(1), + pvmf(1), + 1,opts.all_ard.value());//if all_ard, then turn ard on here (SJ) + for(int i=2; i<=opts.nfibres.value();i++){ + m_multifibre.addfibre(pvmth(i), + pvmph(i), + pvmf(i), + 1,!opts.no_ard.value()); + } + + } + } + + } + + + void runmcmc(){ int count=0, recordcount=0,sample=1;//sample will index a newmat matrix for( int i =0;i<opts.nburn.value();i++){ @@ -575,6 +610,11 @@ class xfibresVoxelManager{ }; + + +//////////////////////////////////////////// +// MAIN +//////////////////////////////////////////// int main(int argc, char *argv[]) { @@ -619,7 +659,7 @@ int main(int argc, char *argv[]) for(int vox=1;vox<=datam.Ncols();vox++){ cout <<vox<<"/"<<datam.Ncols()<<endl; - xfibresVoxelManager vm(datam.Column(vox),alpha,beta,bvals,samples,vox); + xfibresVoxelManager vm(datam.Column(vox),alpha,beta,bvecs,bvals,samples,vox); vm.initialise(Amat); vm.runmcmc(); }