From 807617573236e61b36994897b3d30c513b36e080 Mon Sep 17 00:00:00 2001
From: Mark Woolrich <woolrich@fmrib.ox.ac.uk>
Date: Tue, 20 Dec 2005 10:58:15 +0000
Subject: [PATCH] *** empty log message ***

---
 miscmaths.cc | 88 ++++++++++++++++++++++++++++++++++------------------
 miscmaths.h  |  2 +-
 quick.cc     |  4 +--
 3 files changed, 61 insertions(+), 33 deletions(-)

diff --git a/miscmaths.cc b/miscmaths.cc
index 80b9f3f..0b1a066 100644
--- a/miscmaths.cc
+++ b/miscmaths.cc
@@ -1977,43 +1977,65 @@ float csevl(const float x, const ColumnVector& cs, const int n)
     return psi;
   }
 
-
-  void glm_vb(const Matrix& design, const ColumnVector& Y, ColumnVector& m_B, SymmetricMatrix& ilambda_B)
+  void glm_vb(const Matrix& X, const ColumnVector& Y, ColumnVector& B, SymmetricMatrix& ilambda_B, int niters)
   {
+    // Does Variational Bayes inference on GLM Y=XB+e with ARD priors on B
+    // design matrix X should be num_tpts*num_evs
+
     /////////////////////
     // setup
     OUT("Setup");
 
     int ntpts=Y.Nrows();
-    int nevs=design.Nrows();
-    int niters=30;
+    int nevs=X.Ncols();
 
-    if(ntpts!=design.Ncols())
+    if(ntpts!=X.Nrows())
       throw Exception("COCK");
 
     OUT(nevs);
     OUT(ntpts);
 
-    m_B.ReSize(nevs);
-    m_B=0;
-    for(int i=1; i<=nevs; i++)
+    ColumnVector gam_m(nevs);
+    gam_m=1e10;
+    float gam_y;
+
+    ColumnVector lambdaB(nevs);
+    if(nevs<ntpts-10)
       {
-	m_B(i) = normrnd().AsScalar()*0.0001;
+	// initialise with OLS
+	B=pinv(X)*Y;
+	ColumnVector res=Y-X*B;
+	gam_y=(ntpts-nevs)/(res.t()*res).AsScalar();
+
+	ilambda_B << (X.t()*X*gam_y).i();
+	lambdaB=0;
+	for(int l=1; l <= nevs; l++)
+	  {
+	    lambdaB(l)=ilambda_B(l,l);
+	  }
       }
+    else
+      {
+	OUT("no ols");
+	B.ReSize(nevs);
+	B=0;
+	lambdaB=1;
 
-    ColumnVector gam_m(nevs);
-    gam_m=1e10;
+// 	ColumnVector res=Y-X*B;
+// 	gam_y=ntpts/(res.t()*res).AsScalar();
 
-    float gam_y=100;
+	gam_y=10;
+      }
+
+//     OUT(B(1));
+//     OUT(lambdaB(1));
 
-    ColumnVector lambdaB(nevs);
-    lambdaB=1;
     float trace_ilambdaZZ=1;
 
     SymmetricMatrix ZZ;
-    ZZ << design*design.t();
+    ZZ << X.t()*X;
 
-    Matrix ZY = design*Y;
+    Matrix ZY = X.t()*Y;
 
     float YY=0;
     for(int t=1; t <= ntpts; t++)
@@ -2026,19 +2048,21 @@ float csevl(const float x, const ColumnVector& cs, const int n)
     int i = 1;;
     for(; i<=niters; i++)
       {
-	OUT(i);
+	cout<<i<<",";
 	////////////////////
 	// update phim
 	for(int l=1; l <= nevs; l++)
 	  {
 	    float b_m0 = 1e10;
-	    float c_m0 = 1;
+	    float c_m0 = 2;
 
 	    float c_m = 1.0/2.0 + c_m0;	    
-	    float b_m = 1.0/(0.5*(Sqr(m_B(l))+lambdaB(l))+1.0/b_m0);
+	    float b_m = 1.0/(0.5*(Sqr(B(l))+lambdaB(l))+1.0/b_m0);
 	    gam_m(l) = b_m*c_m;	    
 	  }
 
+// 	OUT(gam_m(1));
+
 	////////////////////
 	// update B
 	ColumnVector beta(nevs);
@@ -2055,8 +2079,7 @@ float csevl(const float x, const ColumnVector& cs, const int n)
 	beta += gam_y*ZY;
 
 	ilambda_B << lambda_B.i();
-	m_B=ilambda_B*beta;
-	OUT(m_B(1));
+	B=ilambda_B*beta;
 
 	lambdaB.ReSize(nevs);
 	lambdaB=0;
@@ -2075,22 +2098,27 @@ float csevl(const float x, const ColumnVector& cs, const int n)
 	tmp2 << tmp3*ZZ;
 	
 	trace_ilambdaZZ=tmp2.Trace();	
-      }
+// 	OUT(trace_ilambdaZZ);
 
 
-    /////////////////////
-    // update phiy
-    float b_y0 = 1e10;
-    float c_y0 = 0;
+	/////////////////////
+	// update phiy
+	float b_y0 = 1e10;
+	float c_y0 = 1;
 
-    float c_y = (ntpts-1)/2.0 + c_y0;
+	float c_y = (ntpts-1)/2.0 + c_y0;
 
-    float sum = YY + (m_B.t()*ZZ*m_B).AsScalar() - 2*(m_B.t()*ZY).AsScalar();
+	float sum = YY + (B.t()*ZZ*B).AsScalar() - 2*(B.t()*ZY).AsScalar();
 	
-    float b_y = 1.0/(0.5*(sum + trace_ilambdaZZ)+1/b_y0);
+	float b_y = 1.0/(0.5*(sum + trace_ilambdaZZ)+1/b_y0);
 	
-    gam_y = b_y*c_y;	     
+	gam_y = b_y*c_y;
+
+// 	OUT(gam_y);	     
+
+      }
 
+    cout << endl;
   }
 
 vector<float> ColumnVector2vector(const ColumnVector& col)
diff --git a/miscmaths.h b/miscmaths.h
index 9a9c9a1..27d36ff 100644
--- a/miscmaths.h
+++ b/miscmaths.h
@@ -224,7 +224,7 @@ namespace MISCMATHS {
 
   float csevl(const float x, const ColumnVector& cs, const int n);
   float digamma(const float x);
-  void glm_vb(const Matrix& design, const ColumnVector& Y, ColumnVector& m_B, SymmetricMatrix& ilambda_B);
+  void glm_vb(const Matrix& X, const ColumnVector& Y, ColumnVector& B, SymmetricMatrix& ilambda_B, int niters=20);
 
   vector<float> ColumnVector2vector(const ColumnVector& col);
   
diff --git a/quick.cc b/quick.cc
index a6d2e9b..847118f 100644
--- a/quick.cc
+++ b/quick.cc
@@ -23,14 +23,14 @@ int main(int argc, char *argv[])
 {
   try{
 
-    Matrix design = read_vest("/usr/people/woolrich/matlab/vbbabe/data/design2.mat");
+    Matrix X = read_vest("/usr/people/woolrich/matlab/vbbabe/data/design2.mat").t();
  
     ColumnVector Y = read_vest("/usr/people/woolrich/matlab/vbbabe/data/sdf2.mat").t();
  
     ColumnVector m_B;
     SymmetricMatrix ilambda_B;
 
-    glm_vb(design, Y, m_B, ilambda_B);
+    glm_vb(X, Y, m_B, ilambda_B, 30);
 
     write_ascii_matrix(m_B,"/usr/people/woolrich/matlab/vbbabe/data/m_B");
     
-- 
GitLab