From a302757fa14d5c6a06ca381170c5f09e30e2896a Mon Sep 17 00:00:00 2001
From: Saad Jbabdi <saad@fmrib.ox.ac.uk>
Date: Mon, 1 Mar 2010 17:19:19 +0000
Subject: [PATCH] fixes to model2 nonlinear fit

---
 diffmodels.cc | 60 ++++++++++++++++++++++++---------------------------
 diffmodels.h  |  4 ++++
 2 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/diffmodels.cc b/diffmodels.cc
index 670c888..cb9988a 100644
--- a/diffmodels.cc
+++ b/diffmodels.cc
@@ -721,16 +721,12 @@ boost::shared_ptr<BFMatrix> PVM_single::hess(const NEWMAT::ColumnVector& p,boost
 
 void PVM_multi::fit(){
 
-  // initialise with a tensor
-  DTI dti(Y,bvecs,bvals);
-  dti.linfit();
-
   // initialise with simple pvm
   PVM_single pvm1(Y,bvecs,bvals,nfib);
   pvm1.fit();
 
   float _a,_b;
-  _a = 1; // start with d=d_std
+  _a = 1.0; // start with d=d_std
   _b = pvm1.get_d();
 
   ColumnVector start(nparams);
@@ -753,11 +749,10 @@ void PVM_multi::fit(){
   ColumnVector final_par(nparams);
   final_par = lmpar.Par();
 
-
   // finalise parameters
   m_s0     = final_par(1);
   m_d      = std::abs(final_par(2)*final_par(3));
-  m_d_std  = std::sqrt(std::abs(final_par(2)))*std::abs(final_par(3));
+  m_d_std  = std::sqrt(std::abs(final_par(2)*final_par(3)*final_par(3)));
   for(int i=4,k=1;k<=nfib;i+=3,k++){
     m_f(k)  = x2f(final_par(i));
     m_th(k) = final_par(i+1);
@@ -765,6 +760,7 @@ void PVM_multi::fit(){
   }
   sort();
   fix_fsum();
+
 }
 void PVM_multi::sort(){
   vector< pair<float,int> > fvals(nfib);
@@ -793,7 +789,7 @@ ReturnMatrix PVM_multi::get_prediction()const{
   ColumnVector p(nparams);
   p(1) = m_s0;
   p(2) = m_d*m_d/m_d_std/m_d_std;
-  p(3) = m_d_std*m_d_std/m_d;
+  p(3) = m_d_std*m_d_std/m_d; // =1/beta
   for(int k=1;k<=nfib;k++){
     int kk = 4+3*(k-1);
     p(kk)   = f2x(m_f(k));
@@ -858,7 +854,7 @@ double PVM_multi::cf(const NEWMAT::ColumnVector& p)const{
     for(int k=1;k<=nfib;k++){
       err += fs(k)*anisoterm(i,_a,_b,x.Row(k).t());
     }
-    err = (p(1)*((1-sumf)*isoterm(i,_a,_b)+err) - Y(i)); 
+    err = (std::abs(p(1))*((1-sumf)*isoterm(i,_a,_b)+err) - Y(i)); 
     cfv += err*err; 
   }  
   //OUT(cfv);
@@ -898,22 +894,22 @@ NEWMAT::ReturnMatrix PVM_multi::grad(const NEWMAT::ColumnVector& p)const{
       sig += fs(k)*anisoterm(i,_a,_b,xx);
       // other stuff for derivatives
       // alpha
-      J(i,2) += (p(2)>0?1.0:-1.0)*p(1)*fs(k)*anisoterm_a(i,_a,_b,xx);
+      J(i,2) += (p(2)>0?1.0:-1.0)*std::abs(p(1))*fs(k)*anisoterm_a(i,_a,_b,xx);
       // beta
-      J(i,3) += (p(3)>0?1.0:-1.0)*p(1)*fs(k)*anisoterm_b(i,_a,_b,xx) * (-p(3)*p(3)); // change of variable beta=1/beta
+      J(i,3) += (p(3)>0?1.0:-1.0)*std::abs(p(1))*fs(k)*anisoterm_b(i,_a,_b,xx);
       // f
-      J(i,kk) = p(1)*(anisoterm(i,_a,_b,xx)-isoterm(i,_a,_b)) * two_pi*sign(p(kk))*1/(1+p(kk)*p(kk));
+      J(i,kk) = std::abs(p(1))*(anisoterm(i,_a,_b,xx)-isoterm(i,_a,_b)) * two_pi*sign(p(kk))*1/(1+p(kk)*p(kk));
       // th
-      J(i,kk+1) = p(1)*fs(k)*anisoterm_th(i,_a,_b,xx,p(kk+1),p(kk+2));
+      J(i,kk+1) = std::abs(p(1))*fs(k)*anisoterm_th(i,_a,_b,xx,p(kk+1),p(kk+2));
       // ph
-      J(i,kk+2) = p(1)*fs(k)*anisoterm_ph(i,_a,_b,xx,p(kk+1),p(kk+2));
+      J(i,kk+2) = std::abs(p(1))*fs(k)*anisoterm_ph(i,_a,_b,xx,p(kk+1),p(kk+2));
     }
-    sig = p(1)*((1-sumf)*isoterm(i,_a,_b)+sig);
+    sig = std::abs(p(1))*((1-sumf)*isoterm(i,_a,_b)+sig);
     diff(i) = sig - Y(i);
     // other stuff for derivatives
-    J(i,1) = sig/p(1);
-    J(i,2) += (p(2)>0?1.0:-1.0)*p(1)*(1-sumf)*isoterm_a(i,_a,_b);
-    J(i,3) += (p(3)>0?1.0:-1.0)*p(1)*(1-sumf)*isoterm_b(i,_a,_b) * (-p(3)*p(3));
+    J(i,1) = (p(1)>0?1.0:-1.0)*sig/p(1);
+    J(i,2) += (p(2)>0?1.0:-1.0)*std::abs(p(1))*(1-sumf)*isoterm_a(i,_a,_b);
+    J(i,3) += (p(3)>0?1.0:-1.0)*std::abs(p(1))*(1-sumf)*isoterm_b(i,_a,_b);
   }
   
   gradv = 2*J.t()*diff;
@@ -961,22 +957,22 @@ boost::shared_ptr<BFMatrix> PVM_multi::hess(const NEWMAT::ColumnVector& p,boost:
       // change of variable
       float cov = two_pi*sign(p(kk))*1/(1+p(kk)*p(kk));
       // alpha
-      J(i,2) += (p(2)>0?1.0:-1.0)*p(1)*fs(k)*anisoterm_a(i,_a,_b,xx);
+      J(i,2) += (p(2)>0?1.0:-1.0)*std::abs(p(1))*fs(k)*anisoterm_a(i,_a,_b,xx);
       // beta
-      J(i,3) += (p(3)>0?1.0:-1.0)*p(1)*fs(k)*anisoterm_b(i,_a,_b,xx) * (-p(3)*p(3));
+      J(i,3) += (p(3)>0?1.0:-1.0)*std::abs(p(1))*fs(k)*anisoterm_b(i,_a,_b,xx);
       // f
-      J(i,kk) = p(1)*(anisoterm(i,_a,_b,xx)-isoterm(i,_a,_b)) * cov;
+      J(i,kk) = std::abs(p(1))*(anisoterm(i,_a,_b,xx)-isoterm(i,_a,_b)) * cov;
       // th
-      J(i,kk+1) = p(1)*fs(k)*anisoterm_th(i,_a,_b,xx,p(kk+1),p(kk+2));
+      J(i,kk+1) = std::abs(p(1))*fs(k)*anisoterm_th(i,_a,_b,xx,p(kk+1),p(kk+2));
       // ph
-      J(i,kk+2) = p(1)*fs(k)*anisoterm_ph(i,_a,_b,xx,p(kk+1),p(kk+2));
+      J(i,kk+2) = std::abs(p(1))*fs(k)*anisoterm_ph(i,_a,_b,xx,p(kk+1),p(kk+2));
     }
-    sig = p(1)*((1-sumf)*isoterm(i,_a,_b)+sig);
+    sig = std::abs(p(1))*((1-sumf)*isoterm(i,_a,_b)+sig);
     diff(i) = sig - Y(i);
     // other stuff for derivatives
-    J(i,1) = sig/p(1);
-    J(i,2) += (p(2)>0?1.0:-1.0)*p(1)*(1-sumf)*isoterm_a(i,_a,_b);
-    J(i,3) += (p(3)>0?1.0:-1.0)*p(1)*(1-sumf)*isoterm_b(i,_a,_b) * (-p(3)*p(3));
+    J(i,1) = (p(1)>0?1.0:-1.0)*sig/p(1);
+    J(i,2) += (p(2)>0?1.0:-1.0)*std::abs(p(1))*(1-sumf)*isoterm_a(i,_a,_b);
+    J(i,3) += (p(3)>0?1.0:-1.0)*std::abs(p(1))*(1-sumf)*isoterm_b(i,_a,_b);
 
   }
   
@@ -1084,7 +1080,7 @@ float PVM_multi::isoterm(const int& pt,const float& _a,const float& _b)const{
 }
 float PVM_multi::anisoterm(const int& pt,const float& _a,const float& _b,const ColumnVector& x)const{
   float dp = bvecs(1,pt)*x(1)+bvecs(2,pt)*x(2)+bvecs(3,pt)*x(3);
-  return(std::exp(-_a*std::log(1+bvals(1,pt)*_b*dp*dp)));
+  return(std::exp(-_a*std::log(1+bvals(1,pt)*_b*(dp*dp))));
 }
 // 1st order derivatives
 float PVM_multi::isoterm_a(const int& pt,const float& _a,const float& _b)const{
@@ -1095,20 +1091,20 @@ float PVM_multi::isoterm_b(const int& pt,const float& _a,const float& _b)const{
 }
 float PVM_multi::anisoterm_a(const int& pt,const float& _a,const float& _b,const ColumnVector& x)const{
   float dp = bvecs(1,pt)*x(1)+bvecs(2,pt)*x(2)+bvecs(3,pt)*x(3);
-  return(-std::log(1+bvals(1,pt)*dp*dp*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*dp*dp*_b)));
+  return(-std::log(1+bvals(1,pt)*(dp*dp)*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*(dp*dp)*_b)));
 }
 float PVM_multi::anisoterm_b(const int& pt,const float& _a,const float& _b,const ColumnVector& x)const{
   float dp = bvecs(1,pt)*x(1)+bvecs(2,pt)*x(2)+bvecs(3,pt)*x(3);
-  return(-_a*bvals(1,pt)*dp*dp/(1+bvals(1,pt)*dp*dp*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*dp*dp*_b)));
+  return(-_a*bvals(1,pt)*(dp*dp)/(1+bvals(1,pt)*(dp*dp)*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*(dp*dp)*_b)));
 }
 float PVM_multi::anisoterm_th(const int& pt,const float& _a,const float& _b,const ColumnVector& x,const float& _th,const float& _ph)const{
   float dp = bvecs(1,pt)*x(1)+bvecs(2,pt)*x(2)+bvecs(3,pt)*x(3);
   float dp1 = cos(_th)*(bvecs(1,pt)*cos(_ph) + bvecs(2,pt)*sin(_ph)) - bvecs(3,pt)*sin(_th);
-  return(-_a*_b*bvals(1,pt)/(1+bvals(1,pt)*dp*dp*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*dp*dp*_b))*2*dp*dp1);
+  return(-_a*_b*bvals(1,pt)/(1+bvals(1,pt)*(dp*dp)*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*(dp*dp)*_b))*2*dp*dp1);
 }
 float PVM_multi::anisoterm_ph(const int& pt,const float& _a,const float& _b,const ColumnVector& x,const float& _th,const float& _ph)const{
   float dp = bvecs(1,pt)*x(1)+bvecs(2,pt)*x(2)+bvecs(3,pt)*x(3);
   float dp1 = sin(_th)*(-bvecs(1,pt)*sin(_ph) + bvecs(2,pt)*cos(_ph));
-  return(-_a*_b*bvals(1,pt)/(1+bvals(1,pt)*dp*dp*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*dp*dp*_b))*2*dp*dp1);
+  return(-_a*_b*bvals(1,pt)/(1+bvals(1,pt)*(dp*dp)*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*(dp*dp)*_b))*2*dp*dp1);
 }
 
diff --git a/diffmodels.h b/diffmodels.h
index 203966f..5e8ef86 100644
--- a/diffmodels.h
+++ b/diffmodels.h
@@ -359,8 +359,12 @@ public:
     cout << "D_STD :" << m_d_std << endl;
     for(int i=1;i<=nfib;i++){
       cout << "F" << i << "    :" << m_f(i) << endl;
+      ColumnVector x(3);
+      x << sin(m_th(i))*cos(m_ph(i)) << sin(m_th(i))*sin(m_ph(i)) << cos(m_th(i));
+      if(x(3)<0)x=-x;
       cout << "TH" << i << "   :" << m_th(i) << endl; 
       cout << "PH" << i << "   :" << m_ph(i) << endl; 
+      cout << "DIR" << i << "   : " << x(1) << " " << x(2) << " " << x(3) << endl;
     }
   }
   void print(const ColumnVector& p)const{
-- 
GitLab