From 97d1f4ff2c47ce7e5395a02a28b0d90285671a09 Mon Sep 17 00:00:00 2001
From: Mark Woolrich <woolrich@fmrib.ox.ac.uk>
Date: Mon, 19 Dec 2005 15:54:42 +0000
Subject: [PATCH] *** empty log message ***

---
 miscmaths.cc | 118 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 miscmaths.h  |   1 +
 quick.cc     |  17 ++++----
 3 files changed, 127 insertions(+), 9 deletions(-)

diff --git a/miscmaths.cc b/miscmaths.cc
index b63a5b3..80b9f3f 100644
--- a/miscmaths.cc
+++ b/miscmaths.cc
@@ -9,6 +9,7 @@
 // Miscellaneous maths functions
 
 #include "miscmaths.h"
+#include "miscprob.h"
 #include "stdlib.h"
 #include "newmatio.h"
 
@@ -1977,6 +1978,121 @@ float csevl(const float x, const ColumnVector& cs, const int n)
   }
 
 
+  void glm_vb(const Matrix& design, const ColumnVector& Y, ColumnVector& m_B, SymmetricMatrix& ilambda_B)
+  {
+    /////////////////////
+    // setup
+    OUT("Setup");
+
+    int ntpts=Y.Nrows();
+    int nevs=design.Nrows();
+    int niters=30;
+
+    if(ntpts!=design.Ncols())
+      throw Exception("COCK");
+
+    OUT(nevs);
+    OUT(ntpts);
+
+    m_B.ReSize(nevs);
+    m_B=0;
+    for(int i=1; i<=nevs; i++)
+      {
+	m_B(i) = normrnd().AsScalar()*0.0001;
+      }
+
+    ColumnVector gam_m(nevs);
+    gam_m=1e10;
+
+    float gam_y=100;
+
+    ColumnVector lambdaB(nevs);
+    lambdaB=1;
+    float trace_ilambdaZZ=1;
+
+    SymmetricMatrix ZZ;
+    ZZ << design*design.t();
+
+    Matrix ZY = design*Y;
+
+    float YY=0;
+    for(int t=1; t <= ntpts; t++)
+      YY += Sqr(Y(t));
+
+    /////////////////////
+    // iterate
+    OUT("Iterate");
+
+    int i = 1;;
+    for(; i<=niters; i++)
+      {
+	OUT(i);
+	////////////////////
+	// update phim
+	for(int l=1; l <= nevs; l++)
+	  {
+	    float b_m0 = 1e10;
+	    float c_m0 = 1;
+
+	    float c_m = 1.0/2.0 + c_m0;	    
+	    float b_m = 1.0/(0.5*(Sqr(m_B(l))+lambdaB(l))+1.0/b_m0);
+	    gam_m(l) = b_m*c_m;	    
+	  }
+
+	////////////////////
+	// update B
+	ColumnVector beta(nevs);
+	beta = 0;
+	SymmetricMatrix lambda_B(nevs);
+	lambda_B = 0;
+
+	for(int l=1; l <= nevs; l++)
+	  lambda_B(l,l)=gam_m(l);
+
+	SymmetricMatrix tmp = lambda_B + gam_y*ZZ;
+	lambda_B << tmp;
+
+	beta += gam_y*ZY;
+
+	ilambda_B << lambda_B.i();
+	m_B=ilambda_B*beta;
+	OUT(m_B(1));
+
+	lambdaB.ReSize(nevs);
+	lambdaB=0;
+	for(int l=1; l <= nevs; l++)
+	  {
+	    lambdaB(l)=ilambda_B(l,l);
+	  }
+	
+	////////////////////
+	// compute trace for noise precision phiy update
+	
+	SymmetricMatrix tmp3;
+	tmp3 << ilambda_B;
+	
+	SymmetricMatrix tmp2;
+	tmp2 << tmp3*ZZ;
+	
+	trace_ilambdaZZ=tmp2.Trace();	
+      }
+
+
+    /////////////////////
+    // update phiy
+    float b_y0 = 1e10;
+    float c_y0 = 0;
+
+    float c_y = (ntpts-1)/2.0 + c_y0;
+
+    float sum = YY + (m_B.t()*ZZ*m_B).AsScalar() - 2*(m_B.t()*ZY).AsScalar();
+	
+    float b_y = 1.0/(0.5*(sum + trace_ilambdaZZ)+1/b_y0);
+	
+    gam_y = b_y*c_y;	     
+
+  }
+
 vector<float> ColumnVector2vector(const ColumnVector& col)
 {
   vector<float> vec(col.Nrows());
@@ -1995,7 +2111,7 @@ typedef struct { unsigned char a,b ; } TWObytes ;
 
 void Swap_2bytes( int n , void *ar )    /* 2 bytes at a time */
 {
-   register int ii ;
+  register int ii ;
    register TWObytes *tb = (TWObytes *)ar ;
    register unsigned char tt ;
 
diff --git a/miscmaths.h b/miscmaths.h
index 8ad0458..9a9c9a1 100644
--- a/miscmaths.h
+++ b/miscmaths.h
@@ -224,6 +224,7 @@ namespace MISCMATHS {
 
   float csevl(const float x, const ColumnVector& cs, const int n);
   float digamma(const float x);
+  void glm_vb(const Matrix& design, const ColumnVector& Y, ColumnVector& m_B, SymmetricMatrix& ilambda_B);
 
   vector<float> ColumnVector2vector(const ColumnVector& col);
   
diff --git a/quick.cc b/quick.cc
index 7601550..a6d2e9b 100644
--- a/quick.cc
+++ b/quick.cc
@@ -22,17 +22,18 @@ using namespace MISCMATHS;
 int main(int argc, char *argv[])
 {
   try{
-   
-    float tmp = atof(argv[1]);
-    int tmp2 = atoi(argv[2]);
 
-    OUT(tmp);
-    OUT(tmp2);
+    Matrix design = read_vest("/usr/people/woolrich/matlab/vbbabe/data/design2.mat");
+ 
+    ColumnVector Y = read_vest("/usr/people/woolrich/matlab/vbbabe/data/sdf2.mat").t();
+ 
+    ColumnVector m_B;
+    SymmetricMatrix ilambda_B;
 
-    OUT(T2z::getInstance().converttologp(tmp,tmp2));
-    OUT(std::exp(T2z::getInstance().converttologp(tmp,tmp2)));
-    OUT(T2z::getInstance().convert(tmp,tmp2));
+    glm_vb(design, Y, m_B, ilambda_B);
 
+    write_ascii_matrix(m_B,"/usr/people/woolrich/matlab/vbbabe/data/m_B");
+    
   }
   catch(Exception p_excp) 
     {
-- 
GitLab