From 0ffe8a68be202ace5739acfc3d0d620a99ad34d9 Mon Sep 17 00:00:00 2001 From: Saad Jbabdi <saad@fmrib.ox.ac.uk> Date: Sun, 29 Jul 2007 16:15:44 +0000 Subject: [PATCH] remove dpm --- dpm.cc | 40 ----- dpmOptions.cc | 28 ---- dpmOptions.h | 107 ------------- dpm_gibbs.cc | 430 -------------------------------------------------- dpm_gibbs.h | 318 ------------------------------------- gibbs.h | 74 --------- 6 files changed, 997 deletions(-) delete mode 100644 dpm.cc delete mode 100644 dpmOptions.cc delete mode 100644 dpmOptions.h delete mode 100644 dpm_gibbs.cc delete mode 100644 dpm_gibbs.h delete mode 100644 gibbs.h diff --git a/dpm.cc b/dpm.cc deleted file mode 100644 index 815b971..0000000 --- a/dpm.cc +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright (C) 2007 University of Oxford */ - -/* S. Jbabdi */ - -/* CCOPYRIGHT */ - -#include <stdio.h> -#include "dpm_gibbs.h" - -using namespace DPM; -using namespace Utilities; - - -int main (int argc, char *argv[]){ - - Log& logger = LogSingleton::getInstance(); - - dpmOptions& opts = dpmOptions::getInstance(); - opts.parse_command_line(argc,argv,logger); - - // read input files - Matrix data; // data is nxd - data=read_ascii_matrix(opts.datafile.value()); - - - // create Gibb's sampler instance - //cout <<"instanciate DPM_GibbsSampler"<<endl; - DPM_GibbsSampler gs(data, - opts.numiter.value(),opts.burnin.value(),opts.sampleevery.value()); - - //cout<<"initialisation"<<endl; - gs.init(); - //cout<<"running..."<<endl; - gs.run(); - - - // save output files - gs.save(); -} - diff --git a/dpmOptions.cc b/dpmOptions.cc deleted file mode 100644 index bf60965..0000000 --- a/dpmOptions.cc +++ /dev/null @@ -1,28 +0,0 @@ -#define WANT_STREAM -#define WANT_MATH - -#include "dpmOptions.h" - -using namespace Utilities; - -namespace DPM { - -dpmOptions* dpmOptions::gopt = NULL; - - void dpmOptions::parse_command_line(int argc, char** argv, Log& logger) - { - // do once to establish log directory name - for(int a = options.parse_command_line(argc, argv); a < argc; a++); - - - if(help.value() || ! options.check_compulsory_arguments()) - { - options.usage(); - //throw Exception("Not all of the compulsory arguments have been provided"); - exit(2); - } - - - } - -} diff --git a/dpmOptions.h b/dpmOptions.h deleted file mode 100644 index d069c66..0000000 --- a/dpmOptions.h +++ /dev/null @@ -1,107 +0,0 @@ -#if !defined(dpmOptions_h) -#define dpmOptions_h - -#include <string> -#include <iostream> -#include <fstream> -#include <stdlib.h> -#include <stdio.h> -#include "utils/options.h" -#include "utils/log.h" - -using namespace Utilities; - -namespace DPM { - -class dpmOptions { - public: - static dpmOptions& getInstance(); - ~dpmOptions() { delete gopt; } - - Option<bool> help; - Option<bool> verbose; - - Option<string> datafile; - Option<string> logfile; - Option<string> init_class; - Option<int> numclass; - - Option<int> numiter; - Option<int> burnin; - Option<int> sampleevery; - void parse_command_line(int argc,char** argv,Log& logger); - - private: - dpmOptions(); - const dpmOptions& operator=(dpmOptions&); - dpmOptions(dpmOptions&); - - OptionParser options; - - static dpmOptions* gopt; - -}; - - - inline dpmOptions& dpmOptions::getInstance(){ - if(gopt == NULL) - gopt = new dpmOptions(); - - return *gopt; - } - - inline dpmOptions::dpmOptions() : - help(string("-h,--help"), false, - string("display this message"), - false,no_argument), - verbose(string("-V,--verbose"), false, - string("display program outputs"), - false,no_argument), - datafile(string("-d,--data"), string(""), - string("data file"), - true,requires_argument), - logfile(string("-o,--out"), string(""), - string("output file"), - true, requires_argument), - init_class(string("--ic,--initclass"), "random", - string("data labelling initialisation"), - false, requires_argument), - numclass(string("-k,--numclass"),-1, - string("fix number of classes - default=infinite"), - false,requires_argument), - numiter(string("--ni,--numiter"),2000, - string("number of iterations - default=2000"), - false,requires_argument), - burnin(string("--bi,--burnin"),1000, - string("number of iterations before sampling - default=1000"), - false,requires_argument), - sampleevery(string("--se,--sampleevery"),1, - string("sampling frequency - default=1"), - false,requires_argument), - options("dpm","dpm -d data -o logfile") - { - - - try { - options.add(help); - options.add(verbose); - options.add(datafile); - options.add(logfile); - options.add(init_class); - options.add(numclass); - options.add(numiter); - options.add(burnin); - options.add(sampleevery); - } - catch(X_OptionError& e) { - options.usage(); - cerr << endl << e.what() << endl; - } - catch(std::exception &e) { - cerr << e.what() << endl; - } - - } -} - -#endif diff --git a/dpm_gibbs.cc b/dpm_gibbs.cc deleted file mode 100644 index 30f205e..0000000 --- a/dpm_gibbs.cc +++ /dev/null @@ -1,430 +0,0 @@ -#include "dpm_gibbs.h" - -bool compare(const pair<float,int> &r1,const pair<float,int> &r2){ - return (r1.first<r2.first); -} - -void randomise(vector< pair<float,int> >& r){ - for(unsigned int i=0;i<r.size();i++){ - pair<float,int> p(rand()/float(RAND_MAX),i); - r[i]=p; - } - sort(r.begin(),r.end(),compare); - -} -std::ostream& operator << (ostream& o,DPM_GibbsSampler& g){ - g.print(o); - return o; -} -std::ostream& operator << (ostream& o,GaussianWishart& g){ - g.print(o); - return o; -} - - -void DPM_GibbsSampler::init(){ - // fix fixed parameters - m_a0 = 1.0; - m_b0 = 1.0E8; - m_S0 << 1000.0*Identity(m_d);//cov(m_data); - m_N0 << Identity(m_d);///(m_nu0-m_d-1); - m_m0 = mean(m_data,1).t(); - m_n0 = 1; - - // initialise all other parameters - m_alpha = 1.0; - m_k = opts.numclass.value(); - - // class hyper parameters - float kappa0 = 1.0; - int nu0 = m_d; - SymmetricMatrix Nu0(m_d); - Nu0 << m_n0*m_N0;//.01*m_d*Identity(m_d);//cov(m_data);//*(m_nu0-m_d-1); - ColumnVector mu0(m_d); - mu0 = m_m0; - - m_gw0 = GaussianWishart(mu0,Nu0,nu0,kappa0); - - // class parameters - if(opts.numclass.value() < 0){ // infinite mixture case - if(opts.init_class.value() == "oneperdata"){ - if(opts.verbose.value()) - cout << "Initialise with one class per data"<<endl; - init_oneperdata(); - } - else if (opts.init_class.value() == "one"){ - if(opts.verbose.value()) - cout << "initialise with one big class"<<endl; - init_onebigclass(); - } - else if (opts.init_class.value() == "kmeans"){ - if(opts.verbose.value()) - cout << "Initialise using kmeans" << endl; - init_kmeans(); - } - else{ // random - cout << "Random initialisation using 10 classes" << endl; - init_random(); - } - } - else{ // finite mixture case - init_kmeans(opts.numclass.value()); - } - - // calculate part of the marginalisation over class mean/variance - // this part doesn't change through the iterations - m_margintbase = m_d/2*(log(m_gw0.get_kappa()/(1+m_gw0.get_kappa()))-log(M_PI)) - + lgam(float(nu0+1)/2.0) -lgam(float(nu0+1-m_d)/2.0); - - // randomised index for loop over data items - randindex.resize(m_n); - - //cout << *this; -} -// different initialisation schemes -void DPM_GibbsSampler::init_oneperdata(){ - m_k = m_n; - // set parameters - for(int i=1;i<=m_n;i++){ - GaussianWishart gw(m_d); - gw.postupdate(m_data.SubMatrix(i,i,1,m_d),m_gw0); - m_gw.push_back(gw); - m_z.push_back(i-1); - m_classnd.push_back(1); - } -} -void DPM_GibbsSampler::init_onebigclass(){ - m_k = 1; - GaussianWishart gw(m_d); - gw.postupdate(m_data,m_gw0); - m_gw.push_back(gw); - for(int i=0;i<m_data.Nrows();i++)m_z.push_back(0); - m_classnd.push_back(m_data.Nrows()); -} -void DPM_GibbsSampler::init_kmeans(const int k){ - m_k=k; - m_z.resize(m_n); - do_kmeans(); - for(int k=1;k<=m_k;k++){ - GaussianWishart gw(m_d); - vector<ColumnVector> dat; - for(int i=1;i<=m_n;i++) - if(m_z[i-1] == k){ - dat.push_back(m_data.Row(i).t()); - m_z[i-1] -- ; - } - gw.postupdate(dat,m_gw0); - m_gw.push_back(gw); - m_classnd.push_back((int)dat.size()); - } -} -void DPM_GibbsSampler::init_random(const int k){ - m_k=k; - m_z.resize(m_n); - vector< pair<float,int> > rindex(m_n); - randomise(rindex); - vector<pair<float,int> >::iterator riter; - int nn=0,cl=1,nperclass=(int)(float(m_n)/float(m_k)); - for(riter=rindex.begin();riter!=rindex.end();++riter){ - m_z[(*riter).second]=cl; - nn++; - if(nn>=nperclass && cl<m_k){ - nn=0; - cl++; - } - } - for(int k=1;k<=m_k;k++){ - GaussianWishart gw(m_d); - vector<ColumnVector> dat; - for(int i=1;i<=m_n;i++) - if(m_z[i-1] == k){ - dat.push_back(m_data.Row(i).t()); - m_z[i-1] -- ; - } - gw.postupdate(dat,m_gw0); - m_gw.push_back(gw); - m_classnd.push_back((int)dat.size()); - } -} - - -void DPM_GibbsSampler::sample_parameters(){ - cout << *this; - - // sample indicators - //cout<<"sample z"<<endl; - sample_z(); - // sample mean and variance of each class - //cout<<"sample gw"<<endl; - sample_gw(); -} -void DPM_GibbsSampler::sample_hyperparameters(){ - // sample hyperpriors - //cout<<"sample gw0"<<endl; - sample_gw0(); - // sample alpha - //cout<<"sample alpha"<<endl; - sample_alpha(); -} -// sample indicator variables -void DPM_GibbsSampler::sample_z(){ - ColumnVector datapoint(m_d); - randomise(randindex); - - // if finite gaussian mixture, do not add new classes - float extra_finite1 = opts.numclass.value() < 0 ? 0.0 : m_alpha/float(m_k); - float extra_finite2 = opts.numclass.value() < 0 ? 1.0 : 0.0; - - vector< pair<float,int> >::iterator iter; - for(iter=randindex.begin(); iter!=randindex.end(); ++iter){ - ColumnVector cumsum(m_k+1); - ColumnVector w(m_k+1); - - datapoint=m_data.Row((*iter).second+1).t(); - int oldz=m_z[(*iter).second],newz=oldz; - m_classnd[oldz] -= 1; - - // compute class weights - double sum=0.0; - for(int k=0;k<m_k;k++){ - w(k+1) = exp(log(m_classnd[k]+extra_finite1)+marglik(datapoint,k)); - sum += exp(log(m_classnd[k]+extra_finite1)+marglik(datapoint,k)); - cumsum(k+1) = sum; - } - w(m_k+1) = m_alpha*exp(margint(datapoint)) * extra_finite2; - sum += m_alpha*exp(margint(datapoint)) * extra_finite2; - cumsum(m_k+1) = sum; - // sample z using the weights - float U=rand()/float(RAND_MAX); - U *= sum; - for(int k=1;k<=m_k+1;k++){ - if(U<cumsum(k)){ - newz=k-1; - break; - } - } - m_z[(*iter).second] = newz; - - if( newz >= m_k ){ // add a new class - m_k++; - m_classnd.push_back(1); - GaussianWishart gw(m_d); - gw.postupdate(datapoint.t(),m_gw0); - m_gw.push_back(gw); - } - else{ - m_classnd[newz] += 1; - } - //cout << " chosen cluster: "<<(*iter).second<<",oldz="<<oldz<<",newz="<<newz; - //cout << ",w="<<w(newz+1)<<",nold="<<m_classnd[oldz]<<"n="<<m_classnd[newz]<<endl; - - }// end loop over data points - //cout<<"end data"<<endl<<endl; - - // delete empty classes if in infinite mode - if(opts.numclass.value()<0){ - for(int k=m_k-1;k>=0;k--){ - if(m_classnd[k] == 0){ - for(int i=0;i<m_n;i++) - if(m_z[i]>k)m_z[i]--; - for(int kk=k;kk<m_k-1;kk++){ - m_classnd[kk]=m_classnd[kk+1]; - m_gw[kk]=m_gw[kk+1]; - } - m_classnd.pop_back(); - m_gw.pop_back(); - m_k--; - } - } - } -} - -void DPM_GibbsSampler::sample_gw(){ - // update classes posteriors - vector< vector<ColumnVector> > data; - data.resize(m_k); - - // calculate likelihood - m_likelihood = 0; - for(int i=0;i<m_n;i++){ - data[ m_z[i] ].push_back(m_data.Row(i+1).t()); - m_likelihood += -marglik(m_data.Row(i+1).t(),m_z[i]); - } - - for(int k=0;k<m_k;k++){ - if(data[k].size()>0) - m_gw[k].postupdate(data[k],m_gw0); - } -} - - -void DPM_GibbsSampler::sample_gw0(){ - SymmetricMatrix Nu0(m_d),A(m_d),S(m_d); - ColumnVector a(m_d),mu0(m_d); - float B=0; - - A=0;a=0; - for(int k=0;k<m_k;k++){ - S = m_gw[k].get_ssigma().i(); - a += S*m_gw[k].get_smu(); - A << A+S; - B += ((m_gw[k].get_smu()-m_gw0.get_mu()).t()*S*(m_gw[k].get_smu()-m_gw0.get_mu())).AsScalar(); - } - S << A+m_N0.i(); - A << (A+m_S0.i()).i(); - a = A*(a+m_S0.i()*m_m0); - - Nu0 = wishrnd(S.i(),(m_k+1)*m_gw0.get_dof()); - mu0 = mvnrnd(a.t(),A).t(); - - m_gw0.set_Nu(Nu0); - m_gw0.set_mu(mu0); - - Gamma G(1+m_k*m_d/2); //G.Set(rand()/float(RAND_MAX)); - m_gw0.set_kappa(G.Next()*2/(1+B)); - //m_gw0.set_kappa(1.0); - -} - -// sample from alpha using additional variable eta -void DPM_GibbsSampler::sample_alpha(){ - float eta,prop; - float ak=m_a0+m_k-1,bn; - - Gamma G1(ak+1); //G1.Set(rand()/float(RAND_MAX)); - Gamma G2(ak); //G2.Set(rand()/float(RAND_MAX)); - Gamma B1(m_alpha+1); //B1.Set(rand()/float(RAND_MAX)); - Gamma B2(m_n); //B2.Set(rand()/float(RAND_MAX)); - - eta = B1.Next(); - eta /= (eta+B2.Next()); - bn = m_b0-std::log(eta); - - prop=ak/(ak+m_n*bn); - m_alpha=(prop*G1.Next()+(1-prop)*G2.Next())/bn; - //m_alpha=.00000001; -} - -double DPM_GibbsSampler::marglik(const ColumnVector& data,const int k){ - double res=0.0; - LogAndSign ld=(2*M_PI*m_gw[k].get_ssigma()).LogDeterminant(); - - res -= 0.5*(ld.LogValue() - +((data-m_gw[k].get_smu()).t() - *m_gw[k].get_ssigma().i() - *(data-m_gw[k].get_smu())).AsScalar()); - - return res; -} -double DPM_GibbsSampler::margint(const ColumnVector& data){ - LogAndSign ld; - double res=m_margintbase; - - ld = m_gw0.get_Nu().LogDeterminant(); - res += ld.LogValue()*m_gw0.get_dof()/2; - - SymmetricMatrix A(m_d); - A << m_gw0.get_Nu()+m_gw0.get_kappa()/(1+m_gw0.get_kappa())*(data-m_gw0.get_mu())*(data-m_gw0.get_mu()).t(); - ld = A.LogDeterminant(); - res -= ld.LogValue()*(m_gw0.get_dof()+1)/2; - - return res; -} - - -// utils -void DPM_GibbsSampler::do_kmeans(){ - int numiter = 100; - - Matrix means(m_d,m_k),newmeans(m_d,m_k); - ColumnVector nmeans(m_k); - - means=0; - nmeans=0; - - // cout<<"inside kmeans"<<endl; - // initialise random - vector< pair<float,int> > rindex(m_n); - randomise(rindex); - vector<pair<float,int> >::iterator riter; - int nn=0,cl=1,nperclass=(int)(float(m_n)/float(m_k)); - for(riter=rindex.begin();riter!=rindex.end();++riter){ - means.Column(cl) += m_data.Row((*riter).second+1).t(); - nmeans(cl) += 1; - m_z[(*riter).second]=cl; - nn++; - if(nn>=nperclass && cl<m_k){ - nn=0; - cl++; - } - } - for(int m=1;m<=m_k;m++) - means.Column(m) /= nmeans(m); - - //cout<<"kmeans init"<<endl; - //for(int i=0;i<n;i++) - //cout<<z[i]<<" "; - //cout<<endl; - - // iterate - for(int iter=0;iter<numiter;iter++){ - // loop over datapoints and attribute z for closest mean - newmeans=0; - nmeans=0; - for(int i=1;i<=m_n;i++){ - float mindist=1E20,dist=0; - int mm=1; - for(int m=1;m<=m_k;m++){ - dist = (means.Column(m)-m_data.Row(i).t()).SumSquare(); - if( dist<mindist){ - mindist=dist; - mm = m; - } - } - m_z[i] = mm; - newmeans.Column(mm) += m_data.Row(i).t(); - nmeans(mm) += 1; - } - - // compute means - for(int m=1;m<=m_k;m++){ - if(nmeans(m)==0){ - if(opts.numclass.value()<0) m_k--; - do_kmeans(); - return; - } - newmeans.Column(m) /= nmeans(m); - } - means = newmeans; - } - - - //cout<<"kmeans end"<<endl; - //for(int i=0;i<n;i++) - //cout<<z[i]<<" "; - //cout<<endl; -} - -ReturnMatrix DPM_GibbsSampler::get_dataindex(){ - ColumnVector index(m_n); - for(unsigned int i=0;i<m_z.size();i++) - index(i+1) = m_z[i]; - index.Release(); - return index; -} -ReturnMatrix DPM_GibbsSampler::get_mldataindex(){ - ColumnVector index(m_n); - double lik,tmplik; - for(int i=1;i<=m_n;i++){ - lik=0.0;tmplik=0;index(i) = 0; - for(int k=0;k<m_k;k++){ - tmplik = m_classnd[k]*marglik(m_data.Row(i).t(),k); - if(tmplik>lik && m_classnd[k]>3){ - lik = tmplik; - index(i) = k+1; - } - } - } - index.Release(); - return index; -} diff --git a/dpm_gibbs.h b/dpm_gibbs.h deleted file mode 100644 index 3cc04d1..0000000 --- a/dpm_gibbs.h +++ /dev/null @@ -1,318 +0,0 @@ -#if !defined(_DPM_GIBBS_H) -#define _DPM_GIBBS_H - -#include "gibbs.h" -#include "dpmOptions.h" -#include "newran/newran.h" -#include "miscmaths/miscprob.h" -#include <stdlib.h> -#include <stdio.h> -#include <cmath> - - -using namespace NEWMAT; -using namespace NEWRAN; -using namespace MISCMATHS; -using namespace DPM; -using namespace std; - - -// Gaussian-InverWishart distribution -// p(mu,sigma)=det(sigma)^(-(nu+d)/2-1)exp(-trace(Nu*inv(sigma))/2 -kappa/2*(mu-m_mu)'inv(sigma)(mu-m_mu)) -class GaussianWishart{ - private: - friend std::ostream& operator << (ostream& o,GaussianWishart& g); - protected: - ColumnVector m_mu; - SymmetricMatrix m_Nu; - float m_kappa; - int m_dof; - int m_dim; - - ColumnVector m_smu; // sample mean - SymmetricMatrix m_ssigma; // sample covariance - - - public: - GaussianWishart(){} - GaussianWishart(const int dim):m_dim(dim){ - m_mu.ReSize(m_dim); - m_Nu.ReSize(m_dim); - } - GaussianWishart(const ColumnVector& mu,const SymmetricMatrix& Nu,const int dof,const float& kappa): - m_mu(mu),m_Nu(Nu),m_kappa(kappa),m_dof(dof){ - m_dim=m_mu.Nrows(); - sample(); - } - ~GaussianWishart(){} - inline ColumnVector get_mu()const{return m_mu;} - void set_mu(const ColumnVector& mu){m_mu=mu;} - inline SymmetricMatrix get_Nu()const{return m_Nu;} - void set_Nu(const SymmetricMatrix& Nu){m_Nu=Nu;} - void set_kappa(const float& kappa){m_kappa=kappa;} - inline float get_kappa()const{return m_kappa;} - inline int get_dof()const{return m_dof;} - - void postupdate(const vector<ColumnVector>& data,const GaussianWishart& gw0){ - ColumnVector mdat(m_dim); - SymmetricMatrix S(m_dim),SS(m_dim); - - float n = (float)data.size(); - m_dof = gw0.get_dof() + int(n); - m_kappa = gw0.get_kappa() + n; - mdat=0;S=0,SS=0; - for(int i=0;i<int(n);i++){ - SS << data[i]*data[i].t(); - S += SS; - mdat += data[i]; - } - mdat /= n; - - SS << S -n*mdat*mdat.t(); - SS << SS + gw0.get_kappa()*n/m_kappa * (mdat-gw0.get_mu())*(mdat-gw0.get_mu()).t(); - - m_mu = ( gw0.get_kappa()*gw0.get_mu() + n*mdat )/m_kappa; - m_Nu << gw0.get_Nu() + SS; - - sample(); - } - void postupdate(const Matrix& data,const GaussianWishart& gw0){ - ColumnVector mdat(m_dim); - SymmetricMatrix S(m_dim),SS(m_dim); - - float n = (float)data.Nrows(); - m_dof = gw0.get_dof() + int(n); - m_kappa = gw0.get_kappa() + n; - mdat=0;S=0,SS=0; - for(int i=1;i<=int(n);i++){ - SS << data.Row(i).t()*data.Row(i); - S += SS; - mdat += data.Row(i).t(); - } - mdat /= n; - - SS << S -n*mdat*mdat.t(); - SS << SS + gw0.get_kappa()*n/m_kappa * (mdat-gw0.get_mu())*(mdat-gw0.get_mu()).t(); - - m_mu = ( gw0.get_kappa()*gw0.get_mu() + n*mdat )/m_kappa; - m_Nu << gw0.get_Nu() + SS; - - sample(); - } - void sample(ColumnVector& mu,SymmetricMatrix& sigma){ - sigma = iwishrnd(m_Nu.i(),m_dof); - mu = mvnrnd(m_mu.t(),sigma/m_kappa).t(); - } - void sample(){ - m_ssigma = iwishrnd(m_Nu.i(),m_dof); - m_smu = mvnrnd(m_mu.t(),m_ssigma/m_kappa).t(); - } - void print(ostream& os)const{ - os << "Gaussian-InverseWishart distribution" << endl; - os << "mean : " << m_mu.t(); - os << "variance : " << m_Nu.Row(1); - for(int i=2;i<=m_dim;i++) - os << " "<<m_Nu.Row(i); - os << "dof : "<<m_dof<<endl; - os << "kappa : "<<m_kappa<<endl; - os << "sample mu : "<<m_smu.t(); - os << "sample var : "<<m_ssigma.Row(1); - for(int i=2;i<=m_dim;i++) - os << " "<<m_ssigma.Row(i); - os << "-----------------------------------"<<endl; - } - ColumnVector get_smu()const{return m_smu;} - SymmetricMatrix get_ssigma()const{return m_ssigma;} - GaussianWishart& operator=(const GaussianWishart& rhs){ - m_mu = rhs.m_mu; - m_Nu = rhs.m_Nu; - m_kappa = rhs.m_kappa; - m_dof = rhs.m_dof; - m_dim = rhs.m_dim; - - m_smu = rhs.m_smu; - m_ssigma = rhs.m_ssigma; - - return *this; - } - -}; - -//bool compare(const pair<int,float> &p1,const pair<int,float> &p2){ -//return (p1.second < p2.second) ? true : false; -//} - - -class DPM_GibbsSampler : public GibbsSampler -{ - private: - friend std::ostream& operator << (ostream& o,DPM_GibbsSampler& g); - - protected: - DPM::dpmOptions& opts; - - // parameters ------> estimated via gibb's sampling - float m_alpha; - vector<GaussianWishart> m_gw; - vector<int> m_z; - // hyperparameters ------> estimated via gibb's sampling - GaussianWishart m_gw0; - // hyper-hyperparameters ------> these are the only fixed parameters - float m_a0; // a0 = 1 - float m_b0; // b0 = 1 - ColumnVector m_m0; // m0 = mean(data) - SymmetricMatrix m_S0; // S0 = cov(data) - SymmetricMatrix m_N0; // inv(cov(data))/(nu0-d-1)^2 - int m_n0; - // data-related quantities - vector<int> m_classnd; - int m_k; - double m_margintbase; - - vector< pair<float,int> > randindex; - - // samples - vector<float> m_sample_alpha; - vector<int> m_sample_k; - vector<double> m_sample_likelihood; - double m_likelihood; - int m_nsamples; - vector<float> m_mean_z; - - const Matrix& m_data; - -public: - DPM_GibbsSampler(const Matrix& data,int numiter,int burnin,int sampleevery): - GibbsSampler(numiter,burnin,sampleevery), - opts(DPM::dpmOptions::getInstance()),m_data(data){ - m_n = m_data.Nrows(); - m_d = m_data.Ncols(); - - m_nsamples = (int)floor( (numiter - burnin) / sampleevery ); - - m_sample_alpha.resize(m_nsamples); - m_sample_k.resize(m_nsamples); - m_sample_likelihood.resize(m_nsamples); - m_mean_z.resize(m_n); - - Random::Set(rand() / float(RAND_MAX)); - } - ~DPM_GibbsSampler(){} - - // parent class function definitions - void sample_parameters(); - void sample_hyperparameters(); - - // initialisation functions - void init(); - void init_oneperdata(); - void init_onebigclass(); - void init_kmeans(const int k=10); - void init_random(const int k=10); - - // sample model parameters - void sample_z(); - void sample_gw(); - void sample_gw0(); - void sample_alpha(); - - // utils - double marglik(const ColumnVector&,const int); - double margint(const ColumnVector&); - void do_kmeans(); - ReturnMatrix get_dataindex(); - ReturnMatrix get_mldataindex(); - - int get_numclass()const{return m_k;} - - // io - void print(ostream& os){ - os << "-------fixed parameters-------"<<endl; - os << "a0 = "<<m_a0<<endl; - os << "b0 = "<<m_b0<<endl; - os << "nu0 = "<<m_gw0.get_dof()<<endl; - os << "m0 = "<<m_m0.t(); - os << "S0 = "<<m_S0.Row(1); - for(int i=2;i<=m_S0.Ncols();i++) - os << " "<<m_S0.Row(i); - os << "N0 = "<<m_N0.Row(1); - for(int i=2;i<=m_N0.Ncols();i++) - os << " "<<m_N0.Row(i); - os << "-------hyper-parameters-------"<<endl; - os << "k = "<<m_k<<endl; - os << "alpha = "<<m_alpha<<endl; - os << "mu0 = "<<m_gw0.get_mu().t(); - os << "Nu0 = "<<m_gw0.get_Nu().Row(1); - for(int i=2;i<=m_N0.Ncols();i++) - os << " "<<m_gw0.get_Nu().Row(i); - os << "kappa0 = "<<m_gw0.get_kappa()<<endl; - os << "-------class-parameters-------"<<endl; - for(int i=0;i<m_k;i++){ - //os << "cluster "<<i<<endl; - //os << "n\t=\t"<<m_classnd[i]<<endl; - os <<m_classnd[i]<<" "; - //os << m_gw[i]; - //os << endl; - } - os << endl; - } - void save(){ - string logsamples = opts.logfile.value() + ".samples"; - string logmeans = opts.logfile.value() + ".means"; - string logvariances = opts.logfile.value() + ".variances"; - string zzz = opts.logfile.value() + ".z"; - - ofstream of_s(logsamples.c_str()); - ofstream of_m(logmeans.c_str()); - ofstream of_v(logvariances.c_str()); - ofstream of_z(zzz.c_str()); - - double evidence=0; - double maxlog=0; - - of_s << "k\talpha\tlik\n"; - for (unsigned int i=0;i<m_sample_likelihood.size();i++){ - //OUT(i); - of_s << m_sample_k[i] << "\t" - << m_sample_alpha[i] << "\t" - << m_sample_likelihood[i] << "\n"; - if(m_sample_likelihood[i]>maxlog) - maxlog=m_sample_likelihood[i]; - } - // compute evidence - for(unsigned int i=0;i<m_sample_likelihood.size();i++){ - evidence += std::exp(m_sample_likelihood[i]-maxlog); - } - - // store means and variances - for(int k=0;k<m_k;k++){ - of_m << m_gw[k].get_smu().t(); - of_v << m_gw[k].get_ssigma().t(); - } - - evidence = -log((float)m_sample_likelihood.size()) + maxlog + log(evidence); - cout<<m_k<<" "; - cout<<evidence<<endl; - - ColumnVector mlz(m_n); - mlz = get_mldataindex(); - of_z << mlz; - - cout<<"final k="<< mlz.MaximumAbsoluteValue()<<endl; - - } - void record(const int samp){ - //cout<<"record sample "<<samp<<endl; - m_sample_likelihood[samp] = m_likelihood; - m_sample_k[samp] = m_k; - m_sample_alpha[samp] = m_alpha; - for(int i=0;i<m_n;i++) - m_mean_z[i] += m_z[i]; - } - - - -}; - - -#endif diff --git a/gibbs.h b/gibbs.h deleted file mode 100644 index 9cb08a0..0000000 --- a/gibbs.h +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright (C) 2007 University of Oxford */ - -/* CCOPYRIGHT */ - -#if !defined(_GIBBS_H) -#define _GIBBS_H - -#include <stdlib.h> -#include <stdio.h> -#include <iostream.h> - -using namespace std; - -class GibbsSampler -{ - protected: - - int m_numiter; - int m_burnin; - int m_sampleevery; - int m_n; - int m_d; - - public: - GibbsSampler(int numiter,int burnin,int sampleevery): - m_numiter(numiter),m_burnin(burnin),m_sampleevery(sampleevery){} - virtual ~GibbsSampler(){} - - virtual void init() = 0 ; - virtual void record(const int) = 0; - virtual void sample_parameters() = 0; - virtual void sample_hyperparameters() = 0; - - void run(){ - - int recordcount=0; - - // burnin period (no sampling) - cout<<"burnin"<<endl; - for(int i=0;i<m_burnin;i++){ - //cout<<"-----------"<<endl; - sample_parameters(); - sample_hyperparameters(); - - } - - // m_numiter=2; - - // after burnin, sample ervery "sampleevery" - cout<<"gibbs"<<endl; - int samp = 0; - for(int i=m_burnin;i<m_numiter;i++){ - sample_parameters(); - sample_hyperparameters(); - - //print(); - - recordcount++; - - if(recordcount==m_sampleevery){ - //cout<<"record"<<endl; - record(samp);samp++; - recordcount=0; - } - } - - - } - -}; - - - -#endif -- GitLab