From 807617573236e61b36994897b3d30c513b36e080 Mon Sep 17 00:00:00 2001 From: Mark Woolrich <woolrich@fmrib.ox.ac.uk> Date: Tue, 20 Dec 2005 10:58:15 +0000 Subject: [PATCH] *** empty log message *** --- miscmaths.cc | 88 ++++++++++++++++++++++++++++++++++------------------ miscmaths.h | 2 +- quick.cc | 4 +-- 3 files changed, 61 insertions(+), 33 deletions(-) diff --git a/miscmaths.cc b/miscmaths.cc index 80b9f3f..0b1a066 100644 --- a/miscmaths.cc +++ b/miscmaths.cc @@ -1977,43 +1977,65 @@ float csevl(const float x, const ColumnVector& cs, const int n) return psi; } - - void glm_vb(const Matrix& design, const ColumnVector& Y, ColumnVector& m_B, SymmetricMatrix& ilambda_B) + void glm_vb(const Matrix& X, const ColumnVector& Y, ColumnVector& B, SymmetricMatrix& ilambda_B, int niters) { + // Does Variational Bayes inference on GLM Y=XB+e with ARD priors on B + // design matrix X should be num_tpts*num_evs + ///////////////////// // setup OUT("Setup"); int ntpts=Y.Nrows(); - int nevs=design.Nrows(); - int niters=30; + int nevs=X.Ncols(); - if(ntpts!=design.Ncols()) + if(ntpts!=X.Nrows()) throw Exception("COCK"); OUT(nevs); OUT(ntpts); - m_B.ReSize(nevs); - m_B=0; - for(int i=1; i<=nevs; i++) + ColumnVector gam_m(nevs); + gam_m=1e10; + float gam_y; + + ColumnVector lambdaB(nevs); + if(nevs<ntpts-10) { - m_B(i) = normrnd().AsScalar()*0.0001; + // initialise with OLS + B=pinv(X)*Y; + ColumnVector res=Y-X*B; + gam_y=(ntpts-nevs)/(res.t()*res).AsScalar(); + + ilambda_B << (X.t()*X*gam_y).i(); + lambdaB=0; + for(int l=1; l <= nevs; l++) + { + lambdaB(l)=ilambda_B(l,l); + } } + else + { + OUT("no ols"); + B.ReSize(nevs); + B=0; + lambdaB=1; - ColumnVector gam_m(nevs); - gam_m=1e10; +// ColumnVector res=Y-X*B; +// gam_y=ntpts/(res.t()*res).AsScalar(); - float gam_y=100; + gam_y=10; + } + +// OUT(B(1)); +// OUT(lambdaB(1)); - ColumnVector lambdaB(nevs); - lambdaB=1; float trace_ilambdaZZ=1; SymmetricMatrix ZZ; - ZZ << design*design.t(); + ZZ << X.t()*X; - Matrix ZY = design*Y; + Matrix ZY = X.t()*Y; float YY=0; for(int t=1; t <= ntpts; t++) @@ -2026,19 +2048,21 @@ float csevl(const float x, const ColumnVector& cs, const int n) int i = 1;; for(; i<=niters; i++) { - OUT(i); + cout<<i<<","; //////////////////// // update phim for(int l=1; l <= nevs; l++) { float b_m0 = 1e10; - float c_m0 = 1; + float c_m0 = 2; float c_m = 1.0/2.0 + c_m0; - float b_m = 1.0/(0.5*(Sqr(m_B(l))+lambdaB(l))+1.0/b_m0); + float b_m = 1.0/(0.5*(Sqr(B(l))+lambdaB(l))+1.0/b_m0); gam_m(l) = b_m*c_m; } +// OUT(gam_m(1)); + //////////////////// // update B ColumnVector beta(nevs); @@ -2055,8 +2079,7 @@ float csevl(const float x, const ColumnVector& cs, const int n) beta += gam_y*ZY; ilambda_B << lambda_B.i(); - m_B=ilambda_B*beta; - OUT(m_B(1)); + B=ilambda_B*beta; lambdaB.ReSize(nevs); lambdaB=0; @@ -2075,22 +2098,27 @@ float csevl(const float x, const ColumnVector& cs, const int n) tmp2 << tmp3*ZZ; trace_ilambdaZZ=tmp2.Trace(); - } +// OUT(trace_ilambdaZZ); - ///////////////////// - // update phiy - float b_y0 = 1e10; - float c_y0 = 0; + ///////////////////// + // update phiy + float b_y0 = 1e10; + float c_y0 = 1; - float c_y = (ntpts-1)/2.0 + c_y0; + float c_y = (ntpts-1)/2.0 + c_y0; - float sum = YY + (m_B.t()*ZZ*m_B).AsScalar() - 2*(m_B.t()*ZY).AsScalar(); + float sum = YY + (B.t()*ZZ*B).AsScalar() - 2*(B.t()*ZY).AsScalar(); - float b_y = 1.0/(0.5*(sum + trace_ilambdaZZ)+1/b_y0); + float b_y = 1.0/(0.5*(sum + trace_ilambdaZZ)+1/b_y0); - gam_y = b_y*c_y; + gam_y = b_y*c_y; + +// OUT(gam_y); + + } + cout << endl; } vector<float> ColumnVector2vector(const ColumnVector& col) diff --git a/miscmaths.h b/miscmaths.h index 9a9c9a1..27d36ff 100644 --- a/miscmaths.h +++ b/miscmaths.h @@ -224,7 +224,7 @@ namespace MISCMATHS { float csevl(const float x, const ColumnVector& cs, const int n); float digamma(const float x); - void glm_vb(const Matrix& design, const ColumnVector& Y, ColumnVector& m_B, SymmetricMatrix& ilambda_B); + void glm_vb(const Matrix& X, const ColumnVector& Y, ColumnVector& B, SymmetricMatrix& ilambda_B, int niters=20); vector<float> ColumnVector2vector(const ColumnVector& col); diff --git a/quick.cc b/quick.cc index a6d2e9b..847118f 100644 --- a/quick.cc +++ b/quick.cc @@ -23,14 +23,14 @@ int main(int argc, char *argv[]) { try{ - Matrix design = read_vest("/usr/people/woolrich/matlab/vbbabe/data/design2.mat"); + Matrix X = read_vest("/usr/people/woolrich/matlab/vbbabe/data/design2.mat").t(); ColumnVector Y = read_vest("/usr/people/woolrich/matlab/vbbabe/data/sdf2.mat").t(); ColumnVector m_B; SymmetricMatrix ilambda_B; - glm_vb(design, Y, m_B, ilambda_B); + glm_vb(X, Y, m_B, ilambda_B, 30); write_ascii_matrix(m_B,"/usr/people/woolrich/matlab/vbbabe/data/m_B"); -- GitLab