Skip to content
Snippets Groups Projects
Commit a302757f authored by Saad Jbabdi's avatar Saad Jbabdi
Browse files

fixes to model2 nonlinear fit

parent c9f4866c
No related branches found
No related tags found
No related merge requests found
......@@ -721,16 +721,12 @@ boost::shared_ptr<BFMatrix> PVM_single::hess(const NEWMAT::ColumnVector& p,boost
void PVM_multi::fit(){
// initialise with a tensor
DTI dti(Y,bvecs,bvals);
dti.linfit();
// initialise with simple pvm
PVM_single pvm1(Y,bvecs,bvals,nfib);
pvm1.fit();
float _a,_b;
_a = 1; // start with d=d_std
_a = 1.0; // start with d=d_std
_b = pvm1.get_d();
ColumnVector start(nparams);
......@@ -753,11 +749,10 @@ void PVM_multi::fit(){
ColumnVector final_par(nparams);
final_par = lmpar.Par();
// finalise parameters
m_s0 = final_par(1);
m_d = std::abs(final_par(2)*final_par(3));
m_d_std = std::sqrt(std::abs(final_par(2)))*std::abs(final_par(3));
m_d_std = std::sqrt(std::abs(final_par(2)*final_par(3)*final_par(3)));
for(int i=4,k=1;k<=nfib;i+=3,k++){
m_f(k) = x2f(final_par(i));
m_th(k) = final_par(i+1);
......@@ -765,6 +760,7 @@ void PVM_multi::fit(){
}
sort();
fix_fsum();
}
void PVM_multi::sort(){
vector< pair<float,int> > fvals(nfib);
......@@ -793,7 +789,7 @@ ReturnMatrix PVM_multi::get_prediction()const{
ColumnVector p(nparams);
p(1) = m_s0;
p(2) = m_d*m_d/m_d_std/m_d_std;
p(3) = m_d_std*m_d_std/m_d;
p(3) = m_d_std*m_d_std/m_d; // =1/beta
for(int k=1;k<=nfib;k++){
int kk = 4+3*(k-1);
p(kk) = f2x(m_f(k));
......@@ -858,7 +854,7 @@ double PVM_multi::cf(const NEWMAT::ColumnVector& p)const{
for(int k=1;k<=nfib;k++){
err += fs(k)*anisoterm(i,_a,_b,x.Row(k).t());
}
err = (p(1)*((1-sumf)*isoterm(i,_a,_b)+err) - Y(i));
err = (std::abs(p(1))*((1-sumf)*isoterm(i,_a,_b)+err) - Y(i));
cfv += err*err;
}
//OUT(cfv);
......@@ -898,22 +894,22 @@ NEWMAT::ReturnMatrix PVM_multi::grad(const NEWMAT::ColumnVector& p)const{
sig += fs(k)*anisoterm(i,_a,_b,xx);
// other stuff for derivatives
// alpha
J(i,2) += (p(2)>0?1.0:-1.0)*p(1)*fs(k)*anisoterm_a(i,_a,_b,xx);
J(i,2) += (p(2)>0?1.0:-1.0)*std::abs(p(1))*fs(k)*anisoterm_a(i,_a,_b,xx);
// beta
J(i,3) += (p(3)>0?1.0:-1.0)*p(1)*fs(k)*anisoterm_b(i,_a,_b,xx) * (-p(3)*p(3)); // change of variable beta=1/beta
J(i,3) += (p(3)>0?1.0:-1.0)*std::abs(p(1))*fs(k)*anisoterm_b(i,_a,_b,xx);
// f
J(i,kk) = p(1)*(anisoterm(i,_a,_b,xx)-isoterm(i,_a,_b)) * two_pi*sign(p(kk))*1/(1+p(kk)*p(kk));
J(i,kk) = std::abs(p(1))*(anisoterm(i,_a,_b,xx)-isoterm(i,_a,_b)) * two_pi*sign(p(kk))*1/(1+p(kk)*p(kk));
// th
J(i,kk+1) = p(1)*fs(k)*anisoterm_th(i,_a,_b,xx,p(kk+1),p(kk+2));
J(i,kk+1) = std::abs(p(1))*fs(k)*anisoterm_th(i,_a,_b,xx,p(kk+1),p(kk+2));
// ph
J(i,kk+2) = p(1)*fs(k)*anisoterm_ph(i,_a,_b,xx,p(kk+1),p(kk+2));
J(i,kk+2) = std::abs(p(1))*fs(k)*anisoterm_ph(i,_a,_b,xx,p(kk+1),p(kk+2));
}
sig = p(1)*((1-sumf)*isoterm(i,_a,_b)+sig);
sig = std::abs(p(1))*((1-sumf)*isoterm(i,_a,_b)+sig);
diff(i) = sig - Y(i);
// other stuff for derivatives
J(i,1) = sig/p(1);
J(i,2) += (p(2)>0?1.0:-1.0)*p(1)*(1-sumf)*isoterm_a(i,_a,_b);
J(i,3) += (p(3)>0?1.0:-1.0)*p(1)*(1-sumf)*isoterm_b(i,_a,_b) * (-p(3)*p(3));
J(i,1) = (p(1)>0?1.0:-1.0)*sig/p(1);
J(i,2) += (p(2)>0?1.0:-1.0)*std::abs(p(1))*(1-sumf)*isoterm_a(i,_a,_b);
J(i,3) += (p(3)>0?1.0:-1.0)*std::abs(p(1))*(1-sumf)*isoterm_b(i,_a,_b);
}
gradv = 2*J.t()*diff;
......@@ -961,22 +957,22 @@ boost::shared_ptr<BFMatrix> PVM_multi::hess(const NEWMAT::ColumnVector& p,boost:
// change of variable
float cov = two_pi*sign(p(kk))*1/(1+p(kk)*p(kk));
// alpha
J(i,2) += (p(2)>0?1.0:-1.0)*p(1)*fs(k)*anisoterm_a(i,_a,_b,xx);
J(i,2) += (p(2)>0?1.0:-1.0)*std::abs(p(1))*fs(k)*anisoterm_a(i,_a,_b,xx);
// beta
J(i,3) += (p(3)>0?1.0:-1.0)*p(1)*fs(k)*anisoterm_b(i,_a,_b,xx) * (-p(3)*p(3));
J(i,3) += (p(3)>0?1.0:-1.0)*std::abs(p(1))*fs(k)*anisoterm_b(i,_a,_b,xx);
// f
J(i,kk) = p(1)*(anisoterm(i,_a,_b,xx)-isoterm(i,_a,_b)) * cov;
J(i,kk) = std::abs(p(1))*(anisoterm(i,_a,_b,xx)-isoterm(i,_a,_b)) * cov;
// th
J(i,kk+1) = p(1)*fs(k)*anisoterm_th(i,_a,_b,xx,p(kk+1),p(kk+2));
J(i,kk+1) = std::abs(p(1))*fs(k)*anisoterm_th(i,_a,_b,xx,p(kk+1),p(kk+2));
// ph
J(i,kk+2) = p(1)*fs(k)*anisoterm_ph(i,_a,_b,xx,p(kk+1),p(kk+2));
J(i,kk+2) = std::abs(p(1))*fs(k)*anisoterm_ph(i,_a,_b,xx,p(kk+1),p(kk+2));
}
sig = p(1)*((1-sumf)*isoterm(i,_a,_b)+sig);
sig = std::abs(p(1))*((1-sumf)*isoterm(i,_a,_b)+sig);
diff(i) = sig - Y(i);
// other stuff for derivatives
J(i,1) = sig/p(1);
J(i,2) += (p(2)>0?1.0:-1.0)*p(1)*(1-sumf)*isoterm_a(i,_a,_b);
J(i,3) += (p(3)>0?1.0:-1.0)*p(1)*(1-sumf)*isoterm_b(i,_a,_b) * (-p(3)*p(3));
J(i,1) = (p(1)>0?1.0:-1.0)*sig/p(1);
J(i,2) += (p(2)>0?1.0:-1.0)*std::abs(p(1))*(1-sumf)*isoterm_a(i,_a,_b);
J(i,3) += (p(3)>0?1.0:-1.0)*std::abs(p(1))*(1-sumf)*isoterm_b(i,_a,_b);
}
......@@ -1084,7 +1080,7 @@ float PVM_multi::isoterm(const int& pt,const float& _a,const float& _b)const{
}
float PVM_multi::anisoterm(const int& pt,const float& _a,const float& _b,const ColumnVector& x)const{
float dp = bvecs(1,pt)*x(1)+bvecs(2,pt)*x(2)+bvecs(3,pt)*x(3);
return(std::exp(-_a*std::log(1+bvals(1,pt)*_b*dp*dp)));
return(std::exp(-_a*std::log(1+bvals(1,pt)*_b*(dp*dp))));
}
// 1st order derivatives
float PVM_multi::isoterm_a(const int& pt,const float& _a,const float& _b)const{
......@@ -1095,20 +1091,20 @@ float PVM_multi::isoterm_b(const int& pt,const float& _a,const float& _b)const{
}
float PVM_multi::anisoterm_a(const int& pt,const float& _a,const float& _b,const ColumnVector& x)const{
float dp = bvecs(1,pt)*x(1)+bvecs(2,pt)*x(2)+bvecs(3,pt)*x(3);
return(-std::log(1+bvals(1,pt)*dp*dp*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*dp*dp*_b)));
return(-std::log(1+bvals(1,pt)*(dp*dp)*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*(dp*dp)*_b)));
}
float PVM_multi::anisoterm_b(const int& pt,const float& _a,const float& _b,const ColumnVector& x)const{
float dp = bvecs(1,pt)*x(1)+bvecs(2,pt)*x(2)+bvecs(3,pt)*x(3);
return(-_a*bvals(1,pt)*dp*dp/(1+bvals(1,pt)*dp*dp*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*dp*dp*_b)));
return(-_a*bvals(1,pt)*(dp*dp)/(1+bvals(1,pt)*(dp*dp)*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*(dp*dp)*_b)));
}
float PVM_multi::anisoterm_th(const int& pt,const float& _a,const float& _b,const ColumnVector& x,const float& _th,const float& _ph)const{
float dp = bvecs(1,pt)*x(1)+bvecs(2,pt)*x(2)+bvecs(3,pt)*x(3);
float dp1 = cos(_th)*(bvecs(1,pt)*cos(_ph) + bvecs(2,pt)*sin(_ph)) - bvecs(3,pt)*sin(_th);
return(-_a*_b*bvals(1,pt)/(1+bvals(1,pt)*dp*dp*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*dp*dp*_b))*2*dp*dp1);
return(-_a*_b*bvals(1,pt)/(1+bvals(1,pt)*(dp*dp)*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*(dp*dp)*_b))*2*dp*dp1);
}
float PVM_multi::anisoterm_ph(const int& pt,const float& _a,const float& _b,const ColumnVector& x,const float& _th,const float& _ph)const{
float dp = bvecs(1,pt)*x(1)+bvecs(2,pt)*x(2)+bvecs(3,pt)*x(3);
float dp1 = sin(_th)*(-bvecs(1,pt)*sin(_ph) + bvecs(2,pt)*cos(_ph));
return(-_a*_b*bvals(1,pt)/(1+bvals(1,pt)*dp*dp*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*dp*dp*_b))*2*dp*dp1);
return(-_a*_b*bvals(1,pt)/(1+bvals(1,pt)*(dp*dp)*_b)*std::exp(-_a*std::log(1+bvals(1,pt)*(dp*dp)*_b))*2*dp*dp1);
}
......@@ -359,8 +359,12 @@ public:
cout << "D_STD :" << m_d_std << endl;
for(int i=1;i<=nfib;i++){
cout << "F" << i << " :" << m_f(i) << endl;
ColumnVector x(3);
x << sin(m_th(i))*cos(m_ph(i)) << sin(m_th(i))*sin(m_ph(i)) << cos(m_th(i));
if(x(3)<0)x=-x;
cout << "TH" << i << " :" << m_th(i) << endl;
cout << "PH" << i << " :" << m_ph(i) << endl;
cout << "DIR" << i << " : " << x(1) << " " << x(2) << " " << x(3) << endl;
}
}
void print(const ColumnVector& p)const{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment