Skip to content
Snippets Groups Projects
Commit 97d1f4ff authored by Mark Woolrich's avatar Mark Woolrich
Browse files

*** empty log message ***

parent 4f95007d
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
// Miscellaneous maths functions // Miscellaneous maths functions
#include "miscmaths.h" #include "miscmaths.h"
#include "miscprob.h"
#include "stdlib.h" #include "stdlib.h"
#include "newmatio.h" #include "newmatio.h"
...@@ -1977,6 +1978,121 @@ float csevl(const float x, const ColumnVector& cs, const int n) ...@@ -1977,6 +1978,121 @@ float csevl(const float x, const ColumnVector& cs, const int n)
} }
void glm_vb(const Matrix& design, const ColumnVector& Y, ColumnVector& m_B, SymmetricMatrix& ilambda_B)
{
/////////////////////
// setup
OUT("Setup");
int ntpts=Y.Nrows();
int nevs=design.Nrows();
int niters=30;
if(ntpts!=design.Ncols())
throw Exception("COCK");
OUT(nevs);
OUT(ntpts);
m_B.ReSize(nevs);
m_B=0;
for(int i=1; i<=nevs; i++)
{
m_B(i) = normrnd().AsScalar()*0.0001;
}
ColumnVector gam_m(nevs);
gam_m=1e10;
float gam_y=100;
ColumnVector lambdaB(nevs);
lambdaB=1;
float trace_ilambdaZZ=1;
SymmetricMatrix ZZ;
ZZ << design*design.t();
Matrix ZY = design*Y;
float YY=0;
for(int t=1; t <= ntpts; t++)
YY += Sqr(Y(t));
/////////////////////
// iterate
OUT("Iterate");
int i = 1;;
for(; i<=niters; i++)
{
OUT(i);
////////////////////
// update phim
for(int l=1; l <= nevs; l++)
{
float b_m0 = 1e10;
float c_m0 = 1;
float c_m = 1.0/2.0 + c_m0;
float b_m = 1.0/(0.5*(Sqr(m_B(l))+lambdaB(l))+1.0/b_m0);
gam_m(l) = b_m*c_m;
}
////////////////////
// update B
ColumnVector beta(nevs);
beta = 0;
SymmetricMatrix lambda_B(nevs);
lambda_B = 0;
for(int l=1; l <= nevs; l++)
lambda_B(l,l)=gam_m(l);
SymmetricMatrix tmp = lambda_B + gam_y*ZZ;
lambda_B << tmp;
beta += gam_y*ZY;
ilambda_B << lambda_B.i();
m_B=ilambda_B*beta;
OUT(m_B(1));
lambdaB.ReSize(nevs);
lambdaB=0;
for(int l=1; l <= nevs; l++)
{
lambdaB(l)=ilambda_B(l,l);
}
////////////////////
// compute trace for noise precision phiy update
SymmetricMatrix tmp3;
tmp3 << ilambda_B;
SymmetricMatrix tmp2;
tmp2 << tmp3*ZZ;
trace_ilambdaZZ=tmp2.Trace();
}
/////////////////////
// update phiy
float b_y0 = 1e10;
float c_y0 = 0;
float c_y = (ntpts-1)/2.0 + c_y0;
float sum = YY + (m_B.t()*ZZ*m_B).AsScalar() - 2*(m_B.t()*ZY).AsScalar();
float b_y = 1.0/(0.5*(sum + trace_ilambdaZZ)+1/b_y0);
gam_y = b_y*c_y;
}
vector<float> ColumnVector2vector(const ColumnVector& col) vector<float> ColumnVector2vector(const ColumnVector& col)
{ {
vector<float> vec(col.Nrows()); vector<float> vec(col.Nrows());
...@@ -1995,7 +2111,7 @@ typedef struct { unsigned char a,b ; } TWObytes ; ...@@ -1995,7 +2111,7 @@ typedef struct { unsigned char a,b ; } TWObytes ;
void Swap_2bytes( int n , void *ar ) /* 2 bytes at a time */ void Swap_2bytes( int n , void *ar ) /* 2 bytes at a time */
{ {
register int ii ; register int ii ;
register TWObytes *tb = (TWObytes *)ar ; register TWObytes *tb = (TWObytes *)ar ;
register unsigned char tt ; register unsigned char tt ;
......
...@@ -224,6 +224,7 @@ namespace MISCMATHS { ...@@ -224,6 +224,7 @@ namespace MISCMATHS {
float csevl(const float x, const ColumnVector& cs, const int n); float csevl(const float x, const ColumnVector& cs, const int n);
float digamma(const float x); float digamma(const float x);
void glm_vb(const Matrix& design, const ColumnVector& Y, ColumnVector& m_B, SymmetricMatrix& ilambda_B);
vector<float> ColumnVector2vector(const ColumnVector& col); vector<float> ColumnVector2vector(const ColumnVector& col);
......
...@@ -22,17 +22,18 @@ using namespace MISCMATHS; ...@@ -22,17 +22,18 @@ using namespace MISCMATHS;
int main(int argc, char *argv[]) int main(int argc, char *argv[])
{ {
try{ try{
float tmp = atof(argv[1]);
int tmp2 = atoi(argv[2]);
OUT(tmp); Matrix design = read_vest("/usr/people/woolrich/matlab/vbbabe/data/design2.mat");
OUT(tmp2);
ColumnVector Y = read_vest("/usr/people/woolrich/matlab/vbbabe/data/sdf2.mat").t();
ColumnVector m_B;
SymmetricMatrix ilambda_B;
OUT(T2z::getInstance().converttologp(tmp,tmp2)); glm_vb(design, Y, m_B, ilambda_B);
OUT(std::exp(T2z::getInstance().converttologp(tmp,tmp2)));
OUT(T2z::getInstance().convert(tmp,tmp2));
write_ascii_matrix(m_B,"/usr/people/woolrich/matlab/vbbabe/data/m_B");
} }
catch(Exception p_excp) catch(Exception p_excp)
{ {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment