// pvmfit.cc
/*  Copyright (C) 2009 University of Oxford  */



/*  CCOPYRIGHT  */

#include <iostream>
#include <cmath>
#include "miscmaths/miscmaths.h"
#include "miscmaths/nonlin.h"
#include "newmat.h"
#include "pvmfitOptions.h"
#include "newimage/newimageall.h"

// NOTE: despite its name this constant holds 2/pi (~0.6366), not 2*pi.
// It is used below as abs(two_pi*atan(x)) to turn an unconstrained
// parameter x into a volume fraction in [0,1).
const float two_pi=0.636619772;

using namespace std;
using namespace NEWMAT;
using namespace MISCMATHS;
using namespace PVMFIT;
using namespace NEWIMAGE;

// Pack the first six entries of Vec into a 3x3 symmetric matrix.
// Entries are consumed in the order (1,1),(2,1),(3,1),(2,2),(3,2),(3,3);
// any further entries of Vec are ignored.
inline SymmetricMatrix vec2tens(ColumnVector& Vec){
  SymmetricMatrix packed(3);
  int next = 1;
  // Walk the lower triangle one column at a time.
  for(int col = 1; col <= 3; col++){
    for(int row = col; row <= 3; row++){
      packed(row,col) = Vec(next);
      next++;
    }
  }
  return packed;
}
// Build the N x 7 design matrix of the log-linear tensor fit, where N is the
// number of columns of r. Column i of r supplies a direction (rows 1..3) and
// b(1,i) the matching weighting. Row i of the result holds the six unique
// entries of b*g*g' (off-diagonal entries doubled) followed by a constant 1.
Matrix form_Amat(const Matrix& r,const Matrix& b)
{
  const int ndirs = r.Ncols();
  Matrix design(ndirs,7);
  Matrix g(3,1);
  Matrix outer;

  for(int dir = 1; dir <= ndirs; dir++){
    g << r(1,dir) << r(2,dir) << r(3,dir);
    outer = g*g.t()*b(1,dir);

    // diagonal terms appear once in g'Dg ...
    design(dir,1) = outer(1,1);
    design(dir,4) = outer(2,2);
    design(dir,6) = outer(3,3);
    // ... off-diagonal terms appear twice
    design(dir,2) = 2*outer(1,2);
    design(dir,3) = 2*outer(1,3);
    design(dir,5) = 2*outer(2,3);
    // constant column
    design(dir,7) = 1;
  }
  return design;
}
// dtifit
class DTI {
public: 

  DTI(const ColumnVector& iY,
      const Matrix& iAmat):Y(iY),pinvAmat(iAmat){
  
    npts = Y.Nrows();
    v1.ReSize(3);
    v2.ReSize(3);
    v3.ReSize(3);

  }
  ~DTI(){}
  void fit();
  float get_fa()const{return fa;}
  float get_md()const{return md;}
  float get_s0()const{return s0;}
  float get_mo()const{return mo;}
  ColumnVector get_v1()const{return v1;}
  ColumnVector get_v2()const{return v2;}
  ColumnVector get_v3()const{return v3;}
  ColumnVector get_v(const int& i)const{if(i==1)return v1;else if(i==2)return v2;else return v3;}
  void print();

private:
  const ColumnVector& Y;
  const Matrix& pinvAmat;
  int npts;
  ColumnVector v1,v2,v3;
  float l1,l2,l3;
  float fa,s0,md,mo;

};
// Log-linear tensor fit of the signal Y.
//
// Steps:
//  1. Regress log(Y) (non-positive samples contribute 0) to get a first
//     estimate of S0.
//  2. Rebuild log(Y), flooring every sample at 1% of that S0, and regress
//     again to obtain the final tensor elements and S0.
//  3. Eigen-decompose the tensor and derive MD, FA and MO.
//
// Results are written to s0, md, fa, mo, l1..l3 and v1..v3.
void DTI::fit(){
  ColumnVector logS(npts);
  ColumnVector Dvec(7);
  SymmetricMatrix tens;
  Matrix Vd;
  DiagonalMatrix Dd(3);

  // ---- pass 1: plain log of the positive samples ----
  for (int i=1;i<=npts; i++){
    if(Y(i)>0)
      logS(i)=log(Y(i));
    else
      logS(i)=0;
  }

  Dvec = -pinvAmat*logS;

  // Dvec(7) is -log(S0); guard exp() against absurdly large estimates
  if(Dvec(7)>-23)
    s0=exp(-Dvec(7));
  else
    s0=Y.MaximumAbsoluteValue();

  // ---- pass 2: floor small samples relative to S0 ----
  // The mean of Y does not change inside the loop, so it is computed once
  // here instead of once per sample (which made the loop quadratic).
  // Once s0 has been raised to max|Y| it can no longer be below the mean,
  // so a single check up front is equivalent to the repeated check.
  const double meanY = Y.Sum()/Y.Nrows();
  if(s0<meanY){ s0=Y.MaximumAbsoluteValue();  }
  for (int i=1;i<=npts;i++){
    logS(i)=(Y(i)/s0)>0.01 ? log(Y(i)):log(0.01*s0);
  }
  Dvec = -pinvAmat*logS;
  s0=exp(-Dvec(7));

  // S0 is not allowed to fall below the mean signal
  if(s0<meanY){ s0=meanY;  }
  tens = vec2tens(Dvec);

  // eigenvalues come back in ascending order: index 3 is the largest
  EigenValues(tens,Dd,Vd);

  md = Dd.Sum()/Dd.Nrows();

  l1 = Dd(3,3);
  l2 = Dd(2,2);
  l3 = Dd(1,1);
  v1 = Vd.Column(3);
  v2 = Vd.Column(2);
  v3 = Vd.Column(1);

  // mode of the anisotropy, clamped to [-1,1]; 0 when the tensor is isotropic
  float e1=l1-md, e2=l2-md, e3=l3-md;
  float n = (e1 + e2 - 2*e3)*(2*e1 - e2 - e3)*(e1 - 2*e2 + e3);
  float d = (e1*e1 + e2*e2 + e3*e3 - e1*e2 - e2*e3 - e1*e3);
  d = sqrt(MAX(0, d));
  d = 2*d*d*d;
  mo = MIN(MAX(d ? n/d : 0.0, -1),1);

  // fractional anisotropy: sqrt(1.5 * sum (l-md)^2 / sum l^2), 0 if undefined
  float numer=1.5*((l1-md)*(l1-md)+(l2-md)*(l2-md)+(l3-md)*(l3-md));
  float denom=(l1*l1+l2*l2+l3*l3);
  if(denom>0) fa=numer/denom;
  else fa=0;
  if(fa>0){fa=sqrt(fa);}
  else{fa=0;}

}
// Dump the scalar fit results (S0, MD, FA, MO) to standard output,
// one labelled value per line.
void DTI::print(){
  cout << "DTI FIT RESULTS " << endl
       << "S0   :" << s0 << endl
       << "MD   :" << md << endl
       << "FA   :" << fa << endl
       << "MO   :" << mo << endl;
}

// nonlinear optimisation

// nonlinear cost function class for the partial volume model (PVM)
class PVMNonlinCF : public NonlinCF {
public:
  PVMNonlinCF(const ColumnVector& iY,
	      const ColumnVector& ibvals,
	      const ColumnVector& isina,const ColumnVector& icosa,const ColumnVector& ibeta,
	      const int& nfibres):Y(iY),bvals(ibvals),sinalpha(isina),cosalpha(icosa),beta(ibeta){

    npts    = Y.Nrows();
    nparams = nfibres*3 + 2;
    nfib    = nfibres;

 //    OUT(Y.t());
//     OUT(npts);
//     OUT(nparams);
//     OUT(bvals.t());
//     OUT(sinalpha.t());
//     OUT(beta.t());

  }
  ~PVMNonlinCF(){}
  
  NEWMAT::ReturnMatrix grad(const NEWMAT::ColumnVector& p)const;
  boost::shared_ptr<BFMatrix> hess(const NEWMAT::ColumnVector&p,boost::shared_ptr<BFMatrix> iptr)const;
  double cf(const NEWMAT::ColumnVector& p)const;
  NEWMAT::ReturnMatrix forwardModel(const NEWMAT::ColumnVector& p)const;

private:
  const ColumnVector& Y;
  const ColumnVector& bvals;
  const ColumnVector& sinalpha;
  const ColumnVector& cosalpha;
  const ColumnVector& beta;
  
  float npts;
  int   nparams;
  int   nfib;
};
// Predict the signal for every acquisition from the parameter vector p
// (p(1)=S0, p(2)=d, then f/th/ph triplets starting at p(3)).
// Returns a column vector with one prediction per entry of Y.
NEWMAT::ReturnMatrix PVMNonlinCF::forwardModel(const NEWMAT::ColumnVector& p)const{
  ColumnVector predicted(npts);
  ColumnVector frac(nfib);

  // map the unconstrained f parameters to fractions and add them up
  float fracTotal=0;
  for(int fib=1; 3*fib<=p.Nrows(); fib++){
    frac(fib)  = abs(two_pi*atan(p(3*fib)));
    fracTotal += frac(fib);
  }

  for(int m=1; m<=Y.Nrows(); m++){
    // anisotropic contribution: sum over fibres
    double sticks = 0.0;
    for(int fib=1; 3*fib<=p.Nrows(); fib++){
      const int k = 3*fib;
      // projection of the acquisition direction onto the fibre direction
      const float proj = cos(p(k+2)-beta(m))*sinalpha(m)*sin(p(k+1)) + cosalpha(m)*cos(p(k+1));
      sticks += frac(fib)*exp(-bvals(m)*p(2)*proj*proj);
    }
    // isotropic remainder plus fibres, scaled by S0
    predicted(m) = (p(1)*((1-fracTotal)*exp(-bvals(m)*p(2))+sticks));
  }

  predicted.Release();
  return(predicted);
}

// Cost function: sum over acquisitions of (model prediction - Y)^2.
//
// Parameter layout:
//   p(1) = S0
//   p(2) = d
//   p(3) = f1, p(4) = th1, p(5) = ph1
//   ... one (f,th,ph) triplet per further fibre
double PVMNonlinCF::cf(const NEWMAT::ColumnVector& p)const{
  ColumnVector frac(nfib);

  // map the unconstrained f parameters to fractions and add them up
  float fracTotal=0;
  for(int fib=1; 3*fib<=p.Nrows(); fib++){
    frac(fib)  = abs(two_pi*atan(p(3*fib)));
    fracTotal += frac(fib);
  }

  double sumsq = 0.0;
  for(int m=1; m<=Y.Nrows(); m++){
    // anisotropic contribution: sum over fibres
    double sticks = 0.0;
    for(int fib=1; 3*fib<=p.Nrows(); fib++){
      const int k = 3*fib;
      const float proj = cos(p(k+2)-beta(m))*sinalpha(m)*sin(p(k+1)) + cosalpha(m)*cos(p(k+1));
      sticks += frac(fib)*exp(-bvals(m)*p(2)*proj*proj);
    }
    // residual between prediction and measurement
    const double resid = (p(1)*((1-fracTotal)*exp(-bvals(m)*p(2))+sticks) - Y(m));
    sumsq += resid*resid;
  }
  return(sumsq);
}
// Analytic gradient of cf() with respect to p.
//
// With  iso     = exp(-b*d)
//       aniso_k = exp(-b*d*a_k^2),  a_k = projection of the acquisition
//                                         direction onto fibre k
//       f_k     = abs(two_pi*atan(p(3k)))
//       model   = (1-sum f)*iso + sum_k f_k*aniso_k
// the prediction is S = p(1)*model and cf = sum (S - Y)^2, so
//       dcf/dx  = sum 2*(S - Y)*dS/dx.
//
// Returns a column vector with p.Nrows() entries.
NEWMAT::ReturnMatrix PVMNonlinCF::grad(const NEWMAT::ColumnVector& p)const{
  NEWMAT::ColumnVector gradv(p.Nrows());
  gradv = 0.0;

  ColumnVector f(nfib);
  float sumf=0;
  for(int i=3,j=1;i<=p.Nrows();i+=3,j++){
    f(j) = abs(two_pi*atan(p(i)));
    sumf += f(j);
  }
      
  for(int i=1;i<=Y.Nrows();i++){
    const double iso = exp(-bvals(i)*p(2));

    // accumulate the fibre part of the model and of its derivative w.r.t. d
    double sticks  = 0.0;
    double dsticks = 0.0;
    for(int k=3;k<=p.Nrows();k+=3){
      float angtmp = cos(p(k+2)-beta(i))*sinalpha(i)*sin(p(k+1)) + cosalpha(i)*cos(p(k+1));
      angtmp *= angtmp;
      const double tmpval = f(k/3)*exp(-bvals(i)*p(2)*angtmp);
      sticks  += tmpval;
      dsticks += -bvals(i)*angtmp*tmpval;
    }

    // dS/dS0 is the bracket itself. It is computed directly rather than as
    // (S)/p(1), which divided by zero whenever p(1) was 0.
    const double model = (1-sumf)*iso + sticks;
    const double resid = p(1)*model - Y(i);
    const double dSdd  = p(1)*(-bvals(i)*(1-sumf)*iso + dsticks);

    gradv(1) += 2 * resid * model;
    gradv(2) += 2 * resid * dSdd;

    for(int k=3;k<=p.Nrows();k+=3){
      const float  angtmp = cos(p(k+2)-beta(i))*sinalpha(i)*sin(p(k+1)) + cosalpha(i)*cos(p(k+1));
      const double aniso  = exp(-bvals(i)*p(2)*angtmp*angtmp);
      const double tmpval = f(k/3)*aniso;

      // dS/dp(k) = S0 * dS/df_k * df_k/dp(k)
      //          = S0 * (aniso - iso) * two_pi*sign(p)/(1+p^2)
      // The previous expression evaluated (f*aniso - iso)/f, i.e. it divided
      // the isotropic term by f (0/0 when f==0) and left out the two_pi
      // factor of the atan mapping.
      gradv(k)   += 2 * resid * p(1) * (aniso - iso) * two_pi * sign(p(k)) / (1+p(k)*p(k));
      // derivative w.r.t. th_k
      gradv(k+1) += 2 * resid * p(1) * tmpval * (-bvals(i)*p(2)*2*angtmp*(cos(p(k+2)-beta(i))*sinalpha(i)*cos(p(k+1)) - cosalpha(i)*sin(p(k+1))));
      // derivative w.r.t. ph_k
      gradv(k+2) += 2 * resid * p(1) * tmpval * (bvals(i)*p(2)*2*angtmp*(sin(p(k+2)-beta(i))*sinalpha(i)*sin(p(k+1))));
    }
  }

  gradv.Release();
  return(gradv);
}
// this uses Gauss-Newton approximation
//
// Builds the Jacobian J of the model prediction (one row per acquisition,
// one column per parameter) and fills the returned matrix with J'*J.
// If iptr is non-null and already p.Nrows() x p.Nrows() it is reused and
// returned; otherwise a new FullBFMatrix is allocated.
//
// Notation as in the model: iso = exp(-b*d), aniso_k = exp(-b*d*a_k^2),
// f_k = abs(two_pi*atan(p(3k))), model = (1-sum f)*iso + sum_k f_k*aniso_k.
boost::shared_ptr<BFMatrix> PVMNonlinCF::hess(const NEWMAT::ColumnVector& p,boost::shared_ptr<BFMatrix> iptr)const{
  boost::shared_ptr<BFMatrix>   hessm;

  if (iptr && iptr->Nrows()==(unsigned int)p.Nrows() && iptr->Ncols()==(unsigned int)p.Nrows()) hessm = iptr;
  else hessm = boost::shared_ptr<BFMatrix>(new FullBFMatrix(p.Nrows(),p.Nrows()));

  ColumnVector f(nfib);
  float sumf=0;
  for(int i=3,j=1;i<=p.Nrows();i+=3,j++){
    f(j) = abs(two_pi*atan(p(i)));
    sumf += f(j);
  }

  Matrix J(Y.Nrows(),nparams);
  for(int i=1;i<=Y.Nrows();i++){
    const double iso = exp(-bvals(i)*p(2));

    // fibre part of the model and of its derivative w.r.t. d
    double sticks  = 0.0;
    double dsticks = 0.0;
    for(int k=3;k<=p.Nrows();k+=3){
      float angtmp = cos(p(k+2)-beta(i))*sinalpha(i)*sin(p(k+1)) + cosalpha(i)*cos(p(k+1));
      angtmp *= angtmp;
      const double tmpval = f(k/3)*exp(-bvals(i)*p(2)*angtmp);
      sticks  += tmpval;
      dsticks += -bvals(i)*angtmp*tmpval;
    }

    // dS/dS0 is the bracket itself; computing it directly avoids the
    // previous division by p(1), which failed for p(1)==0.
    J(i,1) = (1-sumf)*iso + sticks;
    J(i,2) = p(1)*(-bvals(i)*(1-sumf)*iso + dsticks);

    for(int k=3;k<=p.Nrows();k+=3){
      const float  angtmp = cos(p(k+2)-beta(i))*sinalpha(i)*sin(p(k+1)) + cosalpha(i)*cos(p(k+1));
      const double aniso  = exp(-bvals(i)*p(2)*angtmp*angtmp);
      const double tmpval = f(k/3)*aniso;

      // dS/dp(k) = S0 * (aniso - iso) * two_pi*sign(p)/(1+p^2).
      // The previous expression evaluated (f*aniso - iso)/f: it divided the
      // isotropic term by f (0/0 when f==0) and dropped the two_pi factor.
      J(i,k)   = p(1) * (aniso - iso) * two_pi * sign(p(k)) / (1+p(k)*p(k));
      // derivative w.r.t. th_k
      J(i,k+1) = p(1) * tmpval * (-bvals(i)*p(2)*2*angtmp*(cos(p(k+2)-beta(i))*sinalpha(i)*cos(p(k+1)) - cosalpha(i)*sin(p(k+1))));
      // derivative w.r.t. ph_k
      J(i,k+2) = p(1) * tmpval * (bvals(i)*p(2)*2*angtmp*(sin(p(k+2)-beta(i))*sinalpha(i)*sin(p(k+1)))); 
    }
  }
  
  // hessm = J'*J; compute the upper triangle and mirror it
  for (int i=1; i<=p.Nrows(); i++){
    for (int j=i; j<=p.Nrows(); j++){
      double dot = 0.0;
      for(int k=1;k<=J.Nrows();k++)
	dot += J(k,i)*J(k,j);
      hessm->Set(i,j,dot);
      if(j!=i) hessm->Set(j,i,dot);
    }
  }

  return(hessm);
}



int main(int argc, char** argv)
{
  //parse command line
  pvmfitOptions& opts = pvmfitOptions::getInstance();
  int success=opts.parse_command_line(argc,argv);
  if(!success) return 1;
   if(opts.verbose.value()){
    cout<<"data file "<<opts.datafile.value()<<endl;
    cout<<"mask file "<<opts.maskfile.value()<<endl;
    cout<<"bvecs     "<<opts.bvecsfile.value()<<endl;
    cout<<"bvals     "<<opts.bvalsfile.value()<<endl;
  }
  
  // Set random seed:
  Matrix r = read_ascii_matrix(opts.bvecsfile.value());
  if(r.Nrows()>3) r=r.t();
  for(int i=1;i<=r.Ncols();i++){
    float tmpsum=sqrt(r(1,i)*r(1,i)+r(2,i)*r(2,i)+r(3,i)*r(3,i));
    if(tmpsum!=0){
      r(1,i)=r(1,i)/tmpsum;
      r(2,i)=r(2,i)/tmpsum;
      r(3,i)=r(3,i)/tmpsum;
    }  
  }
  Matrix b = read_ascii_matrix(opts.bvalsfile.value());
  if(b.Nrows()>1) b=b.t();

  ColumnVector alpha,beta,sinalpha,cosalpha;
  ColumnVector bvals;

  bvals = b.Row(1).t();
  cart2sph(r,alpha,beta);
  sinalpha.ReSize(alpha.Nrows());
  cosalpha.ReSize(alpha.Nrows());
  for(int i=1;i<=alpha.Nrows();i++){
    sinalpha(i) = sin(alpha(i));
    cosalpha(i) = cos(alpha(i));
  }

  // for dti
  Matrix Amat;
  Amat = form_Amat(r,b);
  Amat = pinv(Amat);

  volume4D<float> data;
  volume<int> mask;

  if(opts.verbose.value()) cout<<"reading data"<<endl;
  read_volume4D(data,opts.datafile.value());

  if(opts.verbose.value()) cout<<"reading mask"<<endl;
  read_volume(mask,opts.maskfile.value());

  if(opts.verbose.value()) cout<<"ok"<<endl;
  int minx=0;
  int maxx=mask.xsize();
  int miny=0;
  int maxy=mask.ysize();
  int minz=0;
  int maxz=mask.zsize();
  cout<<minx<<" "<<maxx<<" "<<miny<<" "<<maxy<<" "<<minz<<" "<<maxz<<endl;

  if(opts.verbose.value()) cout<<"setting up vols"<<endl;
  volume<float> S0(maxx-minx,maxy-miny,maxz-minz);
  volume<float> dvol(maxx-minx,maxy-miny,maxz-minz);
  volume<float> tmpvol(maxx-minx,maxy-miny,maxz-minz);
  volume4D<float> tmpvol4D(maxx-minx,maxy-miny,maxz-minz,3);

  vector< volume<float> > fvol,thvol,phvol;
  vector< volume4D<float> > dyads;

  if(opts.verbose.value()) cout<<"copying input properties to output volumes"<<endl;
  copybasicproperties(data[0],S0);
  copybasicproperties(data[0],dvol);

  tmpvol = 0;
  tmpvol4D = 0;
  for(int i=0;i<opts.nfibres.value();i++){
    fvol.push_back(tmpvol);
    thvol.push_back(tmpvol);
    phvol.push_back(tmpvol);
    dyads.push_back(tmpvol4D);
  }

  if(opts.verbose.value()) cout<<"zeroing output volumes"<<endl;
  S0=0;dvol=0;

  if(opts.verbose.value()) cout<<"ok"<<endl;

  
  ColumnVector S(bvals.Nrows());
  if(opts.verbose.value()) cout<<"starting the fits"<<endl;
  for(int k = minz; k < maxz; k++){
    cout<<k<<" slices processed"<<endl;
    for(int j=miny; j < maxy; j++){
      for(int i =minx; i< maxx; i++){
	if(mask(i,j,k)==0)continue;
	
	for(int t=0;t < data.tsize();t++)
	  S(t+1)=data(i,j,k,t);
	
	// initialisation ///////////////////////////////////
	DTI dti(S,Amat);
	dti.fit();

	//dti.print();

	float th,ph;
	cart2sph(dti.get_v1(),th,ph);
	ColumnVector start(2+3*opts.nfibres.value());
	start(1) = dti.get_s0();
	start(2) = (dti.get_md()>0?dti.get_md():0.001);
	start(3) = tan(dti.get_fa());
	start(4) = th;
	start(5) = ph;
	float sumf=abs(two_pi*atan(start(3)));
	for(int ff=2;ff<=opts.nfibres.value();ff++){
	  float denom=2;
	  do{
	    start(ff*3) = tan(abs(two_pi*atan(start(3))/(denom)));
	    denom *= 2;
	  }while(sumf>1);
	  sumf += abs(two_pi*atan(start(ff*3)));
	  cart2sph(dti.get_v(ff),th,ph);
	  start(ff*3+1) = th;
	  start(ff*3+2) = ph;
	}
	//////////////////////////////////////////////////////

	PVMNonlinCF pvm_cf(S,bvals,sinalpha,cosalpha,beta,opts.nfibres.value());
	ColumnVector final_par;

	//cout << i << " " << j << " " << k << endl;
	//OUT(start.t());

	if(dti.get_fa()<1){
	  NonlinParam  lmpar(start.Nrows(),NL_LM,start); // Levenberg-Marquardt
	  lmpar.SetStartingEstimate(start);
	
	  NonlinOut status = nonlin(lmpar,pvm_cf);
	  final_par = lmpar.Par();
	}
	else{
	  final_par = start;
	}

 	//OUT(final_par.t());

	
//  	OUT(pvm_cf.forwardModel(final_par).t());

	S0(i-minx,j-miny,k-minz)=final_par(1);
	dvol(i-minx,j-miny,k-minz)=final_par(2);
	
	for(int f=0;f<opts.nfibres.value();f++){
	  fvol[f](i-minx,j-miny,k-minz) = abs(2.0/M_PI*atan(final_par(3*(1+f))));
	  thvol[f](i-minx,j-miny,k-minz) = final_par(3*(1+f)+1);
	  phvol[f](i-minx,j-miny,k-minz) = final_par(3*(1+f)+2);
	}
	
      }
    }
  }
  
  
  if(opts.verbose.value())
    cout << "saving results" << endl;

  S0.setDisplayMaximumMinimum(S0.max(),0);
  save_volume(S0,opts.ofile.value()+"_S0");

  dvol.setDisplayMaximumMinimum(dvol.max(),0);
  save_volume(dvol,opts.ofile.value()+"_D");

  for(int f=1;f<=opts.nfibres.value();f++){
    fvol[f-1].setDisplayMaximumMinimum(1,0);
    save_volume(fvol[f-1],opts.ofile.value()+"_f"+num2str(f));
    thvol[f-1].setDisplayMaximumMinimum(thvol[f-1].max(),thvol[f-1].min());
    save_volume(thvol[f-1],opts.ofile.value()+"_th"+num2str(f));
    phvol[f-1].setDisplayMaximumMinimum(phvol[f-1].max(),phvol[f-1].min());
    save_volume(phvol[f-1],opts.ofile.value()+"_ph"+num2str(f));
  }

  return 0;
}