diff --git a/Makefile b/Makefile
index b298aec232f8e6d5068626d753957195059db325..8ffd5c643ea05e979d9c3b4ecf78a147a5a762ad 100644
--- a/Makefile
+++ b/Makefile
@@ -29,9 +29,9 @@ TEST=testfile
 ORDVEC=reorder_dyadic_vectors
 DPM=dpm
 
-
+DPMOBJS=dpm.o dpm_gibbs.o dpmOptions.o
 DTIFITOBJS=dtifit.o dtifitOptions.o
-CCOPSOBJS=ccops.o ccopsOptions.o
+CCOPSOBJS=ccops.o ccopsOptions.o dpm_gibbs.o dpmOptions.o
 PTOBJS=probtrack.o probtrackOptions.o pt_alltracts.o pt_matrix.o pt_seeds_to_targets.o pt_simple.o pt_twomasks.o pt_matrix_mesh.o
 PTXOBJS=probtrackx.o probtrackxOptions.o streamlines.o ptx_simple.o ptx_seedmask.o ptx_twomasks.o ptx_nmasks.o
 FTBOBJS=find_the_biggest.o
@@ -47,7 +47,6 @@ FMOOBJS=fdt_matrix_ops.o
 INDEXEROBJS=indexer.o
 TESTOBJS=testfile.o
 ORDVECOBJS=reorder_dyadic_vectors.o heap.o
-DPMOBJS=dpm.o dpmOptions.o
 
 SGEBEDPOST =sge_bedpost  sge_bedpost_postproc.sh  sge_bedpost_preproc.sh  sge_bedpost_single_slice.sh
 SGEBEDPOSTX=sge_bedpostX sge_bedpostX_postproc.sh sge_bedpostX_preproc.sh sge_bedpostX_single_slice.sh
@@ -56,7 +55,7 @@ SCRIPTS = eddy_correct bedpost bedpost_proc bedpost_cleanup bedpost_kill_all bed
 FSCRIPTS=correct_and_average ocmr_preproc bedpostX bedpostX_proc bedpostX_cleanup bedpostX_kill_all \
 	${SGEBEDPOST} ${SGEBEDPOSTX}
 
-XFILES = dtifit ccops probtrack find_the_biggest medianfilter diff_pvm make_dyadic_vectors proj_thresh
+XFILES = dpm dtifit ccops probtrack find_the_biggest medianfilter diff_pvm make_dyadic_vectors proj_thresh
 FXFILES = reord_OM sausages replacevols fdt_matrix_ops probtrackx xfibres indexer
 
 
diff --git a/ccops.cc b/ccops.cc
index f9f98327b85a34f2b736f638310df4efd4bd0c24..349f89feee477263216cbf58cd84f83bec5c20f8 100644
--- a/ccops.cc
+++ b/ccops.cc
@@ -9,11 +9,13 @@
 #include "ccopsOptions.h"
 #include <vector>
 #include <algorithm>
+#include "dpm_gibbs.h"
 
 using namespace std;
 using namespace NEWIMAGE;
 using namespace NEWMAT;
 using namespace CCOPS;
+
   void spect_reord(SymmetricMatrix& A,ColumnVector& r,ColumnVector& y){
     SymmetricMatrix Q=-A;
       DiagonalMatrix t(Q.Nrows());
@@ -70,6 +72,117 @@ using namespace CCOPS;
       
   } 
 
+// calculate the PCA decomposition using the covariance method
+// X is the n x d data matrix: n observations (rows), d variables (columns)
+ReturnMatrix pcacov(Matrix& X,const float& perc){
+  int n=X.Nrows();
+  int d=X.Ncols();
+  // de-mean
+  cout<<"de-mean"<<endl;
+  ColumnVector Xmean(d);
+  for(int j=1;j<=d;j++)
+    Xmean(j) = X.Column(j).Sum()/n;
+  for(int j=1;j<=d;j++){
+    for(int i=1;i<=n;i++)
+      X(i,j) -= Xmean(j);
+  }
+
+  // calculate covariance
+  cout<<"covariance"<<endl;
+  SymmetricMatrix C(d);
+
+  if(d<n)
+    C << (X.t()*X)/n;
+  else
+    C << X*X.t()/n;
+
+  // eigenvalues
+  cout<<"eigenvalues"<<endl;
+  Matrix V;
+  DiagonalMatrix D;
+  EigenValues(C,D,V);
+
+  // select subset
+  cout<<"subset"<<endl;
+  float cumsum=0,total=D.Trace();
+  int dim=0;
+  for(int i=D.Nrows();i>=1;i--){
+    cumsum += D(i);
+    dim++;
+    if(cumsum/total >= perc){break;}
+  }
+  if(dim<=2) dim=2;
+
+  dim=5; // NB: hard-coded override of the variance-based selection above
+
+  Matrix v(V.Nrows(),dim);
+  ColumnVector lam(dim);
+  for(int j=1;j<=dim;j++){
+    v.Column(j) = V.Column(V.Ncols()-j+1);
+    lam(j)=D(V.Ncols()-j+1);
+  }
+
+  //cout<<"zscores"<<endl;
+  // convert to z-scores
+  //for(int i=1;i<=d;i++)
+  //X.Row(i) /= sqrt(C(i,i));
+  // reconstruct data
+  cout<<"data"<<endl;
+  Matrix data(n,dim); // X*v is n x dim
+
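+  // if d >= n we diagonalised the n x n Gram matrix X*X'/n above; map its
+  // eigenvectors back to covariance eigenvectors via v <- X'*v/sqrt(n*lambda)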
+  if(!(d<n)){
+    v = (X.t() * v);
+    for(int j=1;j<=dim;j++)
+      v.Column(j) /= sqrt(n*lam(j));
+  }
+  data = X*v;
+  
+
+  data.Release();
+  return data;
+}
+
+void dpm_reord(Matrix& A,ColumnVector& r,ColumnVector& y){
+  cout << "start dpm reordering" <<endl;
+  // pca preprocess
+  cout << "pca preprocessing" <<endl;
+  float perc=.90;
+  Matrix data;
+  OUT(A.Nrows());
+  OUT(A.Ncols());
+  write_ascii_matrix(A,"nonpreprocessed_data");
+  data = pcacov(A,perc);
+  write_ascii_matrix(data,"preprocessed_data");
+
+  // dpm
+  cout << "dpm clustering" <<endl;
+  int numiter=2000;
+  int burnin=1000;
+  int sampleevery=1;
+  DPM_GibbsSampler gs(data,numiter,burnin,sampleevery);
+  gs.init();
+  gs.run();
+  // save
+  int n = data.Nrows();
+  vector< pair<float,int> > myvec;
+  ColumnVector z(n);
+  z = gs.get_dataindex();
+  //z = gs.get_mldataindex();
+  for(int i=1;i<=n;i++){
+    pair<float,int> mypair;
+    mypair.first  = z(i);
+    mypair.second = i;
+    myvec.push_back(mypair);
+  }
+  sort(myvec.begin(),myvec.end());
+  r.ReSize(n);y.ReSize(n);
+  for(int i=1;i<=n;i++){
+    y(i)=myvec[i-1].first;
+    r(i)=myvec[i-1].second;
+  }
+}
 
 void rem_zrowcol(const Matrix& myOMmat,const Matrix& coordmat,const Matrix& tractcoordmat,const bool coordbool,const bool tractcoordbool,Matrix& newOMmat,Matrix& newcoordmat, Matrix& newtractcoordmat)
 {
@@ -304,6 +417,8 @@ int main ( int argc, char **argv ){
   string ip=opts.inmatrix.value();
   make_basename(ip);
   
+  srand(time(NULL));
+
   ColumnVector y1,r1,y2,r2;
   volume<int> myOM;
   volume<int> coordvol;
@@ -410,9 +525,9 @@ int main ( int argc, char **argv ){
 
  
   rem_zrowcol(myOMmat,mycoordmat,mytractcoordmat,coordbool,tractcoordbool,newOMmat,newcoordmat,newtractcoordmat);
-  //   cerr<<"NOW"<<endl;
-  //   cerr<<myOMmat.MaximumAbsoluteValue()<<endl;
-  //   cerr<<newOMmat.MaximumAbsoluteValue()<<endl;
+  //cerr<<"NOW"<<endl;
+  //cerr<<myOMmat.MaximumAbsoluteValue()<<endl;
+  //cerr<<newOMmat.MaximumAbsoluteValue()<<endl;
  
   //write_ascii_matrix("ncm",newcoordmat);
   // write_ascii_matrix("nctm",newtractcoordmat);
@@ -427,12 +542,12 @@ int main ( int argc, char **argv ){
 
 
   cerr<<"Computing correlation"<<endl;
-  SymmetricMatrix CtCt;
-  CtCt << corrcoef(newOMmat.t());
-  CtCt << CtCt+1;
+  SymmetricMatrix CtCt(newOMmat.Nrows());
+  CtCt = 0; // NB: the correlation step is disabled; CtCt now only holds the connexity term added below
+  //CtCt << corrcoef(newOMmat.t());
+  //CtCt << CtCt+1;
 
   // adding connexity constraint
-  if(!coordbool){
+  if(opts.connexity.value()!=0 && !coordbool){
     cerr<<"WARNING !! No coordinates provided. I cannot apply any connexity constraint."<<endl;
   }
   else{
@@ -467,7 +582,29 @@ int main ( int argc, char **argv ){
   }
   else{
     cerr<<"Starting First Reordering"<<endl;
-    spect_reord(CtCt,r1,y1);
+    if(opts.scheme.value()=="dpm"){
+      OUT(myOMmat.Ncols());
+      OUT(myOMmat.Nrows());
+      OUT(newOMmat.Ncols());
+      OUT(newOMmat.Nrows());
+      dpm_reord(newOMmat,r1,y1);
+      //Matrix A;
+      //A=CtCt;
+      //dpm_reord(A,r1,y1);
+    }
+    else if(opts.scheme.value()=="kmeans"){
+      int kk=opts.kmeans.value();
+      if(kk<=1){
+	cerr << "Error using kmeans. Number of clusters must be >=2" << endl;
+	return -1;
+      }
+      cout << "... using kmeans" << endl;
+      //kmeans_reord(CtCt,r1,y1,kk); // NB: not implemented yet - r1/y1 are left unset on this path
+    }
+    else 
+      spect_reord(CtCt,r1,y1);
    
    
     cerr<<"Permuting seed CC matrix"<<endl;
@@ -491,7 +628,18 @@ int main ( int argc, char **argv ){
     write_ascii_matrix(y1,base+"y1");
     save_volume(outCCvol,"reord_CC_"+base);
     save_volume(outcoords,"coords_for_reord_"+base);
- 
+    if(opts.scheme.value()=="dpm" || opts.scheme.value()=="kmeans"){
+      volume<int> mask;
+      read_volume(mask,opts.mask.value());
+      mask = 0;
+      for(int i=0;i<outcoords.xsize();i++){
+	mask(outcoords(i,0,0),
+	     outcoords(i,1,0),
+	     outcoords(i,2,0)) = (int)y1(i+1) + 1;
+      }
+      save_volume(mask,"reord_mask_"+base);
+    }
+
   }
 
   if(opts.reord2.value()){
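
Review note: for clarity, here is the label-sort step of dpm_reord in a minimal standalone form (the cluster labels below are invented; only the C++ standard library is assumed):

```cpp
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;

int main(){
  // pretend gs.get_dataindex() returned these cluster labels for 6 rows
  float z[6] = {2,0,1,0,2,1};
  vector< pair<float,int> > v;
  for(int i=1;i<=6;i++) v.push_back(make_pair(z[i-1],i));
  sort(v.begin(),v.end());            // sorts by label first, grouping rows by cluster
  for(size_t i=0;i<v.size();i++)      // r = permutation, y = sorted labels
    cout << "r("<<i+1<<")="<<v[i].second
         << "  y("<<i+1<<")="<<v[i].first << endl;
  return 0;
}
```

Rows with equal labels become contiguous after permuting by r, and y carries the label for each reordered row, which is what the reord_mask output in main consumes.
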
diff --git a/ccopsOptions.h b/ccopsOptions.h
index efb79429cb0245f36ebff08430a67ea7570138c2..24478b1f1233d309bb4f1ad23fe7fb667114b900 100644
--- a/ccopsOptions.h
+++ b/ccopsOptions.h
@@ -28,11 +28,14 @@ class ccopsOptions {
   Option<string> inmatrix;
   Option<string> basename;
   Option<string> excl_mask;
-  Option<bool> reord1;
-  Option<bool> reord2;
+  Option<bool>  reord1;
+  Option<bool>  reord2;
   Option<float> connexity;
-  Option<int> bin;
+  Option<int>   bin;
   Option<float> power;
+  Option<string> mask;
+  Option<string> scheme;
+  Option<int>    kmeans;
   bool parse_command_line(int argc, char** argv);
   
  private:
@@ -81,6 +84,15 @@ class ccopsOptions {
    power(string("-p,--power"), 1, 
 	 string("power to raise the correlation matrix to (default 1)"), 
 	 false, requires_argument),
+   mask(string("-m,--mask"), "", 
+	 string("brain mask defining the space of the output cluster mask"), 
+	 false, requires_argument),
+   scheme(string("-s,--scheme"), "spectral", 
+	 string("reordering algorithm: spectral (default), dpm, or kmeans"), 
+	 false, requires_argument),
+   kmeans(string("-K"), 2, 
+	  string("Number of clusters to be used in kmeans"), 
+	  false, requires_argument),
    options("ccops","")
    {
      
@@ -95,6 +107,9 @@ class ccopsOptions {
        options.add(connexity);
        options.add(bin);
        options.add(power);
+       options.add(mask);
+       options.add(scheme);
+       options.add(kmeans);
        
      }
      catch(X_OptionError& e) {
diff --git a/dpm.cc b/dpm.cc
index c6d94e9351ad1a35c630c4e8ff5f93be58d6335e..815b971a6cad343f87a21834a605fd09a5c302b0 100644
--- a/dpm.cc
+++ b/dpm.cc
@@ -4,15 +4,13 @@
 
 /*  CCOPYRIGHT  */
 
-#include <iostream>
-#include <cmath>
-
+#include <stdio.h>
 #include "dpm_gibbs.h"
 
-using namespace NEWMAT;
 using namespace DPM;
 using namespace Utilities;
 
+
 int main (int argc, char *argv[]){
 
   Log& logger = LogSingleton::getInstance();
diff --git a/dpmOptions.h b/dpmOptions.h
index d2f9b8c5944220e4b8f43a19a691ef99e3244ab8..d069c66d078643d17f47339585347253a61a4c9f 100644
--- a/dpmOptions.h
+++ b/dpmOptions.h
@@ -19,6 +19,8 @@ class dpmOptions {
   ~dpmOptions() { delete gopt; }
 
   Option<bool>   help;
+  Option<bool>   verbose;
+
   Option<string> datafile;
   Option<string> logfile;
   Option<string> init_class;
@@ -52,20 +54,23 @@ class dpmOptions {
    help(string("-h,--help"), false,
 	string("display this message"),
 	false,no_argument),
+   verbose(string("-V,--verbose"), false,
+	string("switch on verbose output"),
+	false,no_argument),
    datafile(string("-d,--data"), string(""),
 	    string("data file"),
 	    true,requires_argument),
    logfile(string("-o,--out"), string(""),
 	    string("output file"),
 	    true, requires_argument),
-   init_class(string("--ic,--initclass"), "oneperdata",
+   init_class(string("--ic,--initclass"), "random",
 	    string("data labelling initialisation"),
 	    false, requires_argument),
    numclass(string("-k,--numclass"),-1,
 	    string("fix number of classes - default=infinite"),
 	    false,requires_argument),
    numiter(string("--ni,--numiter"),2000,
-	    string("number of iterations - default=2000"),
+	   string("number of iterations - default=2000"),
 	   false,requires_argument),
    burnin(string("--bi,--burnin"),1000,
 	  string("number of iterations before sampling - default=1000"),
@@ -74,11 +79,12 @@ class dpmOptions {
 	       string("sampling frequency - default=1"),
 	       false,requires_argument),
    options("dpm","dpm -d data -o logfile")
-   {
+     {
      
     
      try {
        options.add(help);
+       options.add(verbose);
        options.add(datafile);
        options.add(logfile);
        options.add(init_class);
diff --git a/dpm_gibbs.cc b/dpm_gibbs.cc
new file mode 100644
index 0000000000000000000000000000000000000000..30f205e4dbd61cd67e920ac04b55561b0c1f44ff
--- /dev/null
+++ b/dpm_gibbs.cc
@@ -0,0 +1,430 @@
+#include "dpm_gibbs.h"
+
+bool compare(const pair<float,int> &r1,const pair<float,int> &r2){
+  return (r1.first<r2.first);
+}
+
+void randomise(vector< pair<float,int> >& r){
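+  // random permutation of 0..r.size()-1: pair each index with a uniform
+  // random key and sort on the keys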
+  for(unsigned int i=0;i<r.size();i++){
+    pair<float,int> p(rand()/float(RAND_MAX),i);
+    r[i]=p;
+  }
+  sort(r.begin(),r.end(),compare);
+  
+}
+std::ostream& operator << (ostream& o,DPM_GibbsSampler& g){
+  g.print(o);
+  return o;
+}
+std::ostream& operator << (ostream& o,GaussianWishart& g){
+  g.print(o);
+  return o;
+}
+
+
+void DPM_GibbsSampler::init(){
+  // set the fixed hyperparameters
+  m_a0       = 1.0;
+  m_b0       = 1.0E8;
+  m_S0       << 1000.0*Identity(m_d);//cov(m_data);
+  m_N0       << Identity(m_d);///(m_nu0-m_d-1);
+  m_m0       = mean(m_data,1).t();
+  m_n0       = 1; 
+
+  // initialise all other parameters
+  m_alpha    = 1.0;
+  m_k        = opts.numclass.value();
+
+  // class hyper parameters
+  float kappa0   = 1.0;
+  int nu0        = m_d;
+  SymmetricMatrix Nu0(m_d);
+  Nu0 << m_n0*m_N0;//.01*m_d*Identity(m_d);//cov(m_data);//*(m_nu0-m_d-1);
+  ColumnVector mu0(m_d); 
+  mu0 = m_m0;
+  
+  m_gw0      = GaussianWishart(mu0,Nu0,nu0,kappa0);
+  
+  // class parameters
+  if(opts.numclass.value() < 0){ // infinite mixture case
+    if(opts.init_class.value() == "oneperdata"){
+      if(opts.verbose.value())
+	cout << "Initialise with one class per data"<<endl;
+      init_oneperdata();
+    }
+    else if (opts.init_class.value() == "one"){
+      if(opts.verbose.value())
+	cout << "Initialise with one big class"<<endl;
+      init_onebigclass();
+    }
+    else if (opts.init_class.value() == "kmeans"){
+      if(opts.verbose.value())
+	cout << "Initialise using kmeans" << endl;
+      init_kmeans();
+    }
+    else{ // random
+      if(opts.verbose.value())
+	cout << "Random initialisation using 10 classes" << endl;
+      init_random();
+    }
+  }
+  else{ // finite mixture case
+    init_kmeans(opts.numclass.value());
+  }
+
+  // calculate part of the marginalisation over class mean/variance
+  // this part doesn't change through the iterations
+  m_margintbase = m_d/2.0*(log(m_gw0.get_kappa()/(1+m_gw0.get_kappa()))-log(M_PI)) 
+    + lgam(float(nu0+1)/2.0) - lgam(float(nu0+1-m_d)/2.0);
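+  // i.e. the data-independent part of the Student-t prior predictive:
+  // d/2*[log(kappa0/(1+kappa0)) - log(pi)] + lgam((nu0+1)/2) - lgam((nu0+1-d)/2)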
+    
+  // randomised index for loop over data items
+  randindex.resize(m_n);
+  
+  //cout << *this;
+}
+// different initialisation schemes
+void DPM_GibbsSampler::init_oneperdata(){
+  m_k = m_n;
+  // set parameters
+  for(int i=1;i<=m_n;i++){
+    GaussianWishart gw(m_d);
+    gw.postupdate(m_data.SubMatrix(i,i,1,m_d),m_gw0);
+    m_gw.push_back(gw);
+    m_z.push_back(i-1);
+    m_classnd.push_back(1);
+  }
+}
+void DPM_GibbsSampler::init_onebigclass(){
+  m_k = 1;
+  GaussianWishart gw(m_d);
+  gw.postupdate(m_data,m_gw0);
+  m_gw.push_back(gw);
+  for(int i=0;i<m_data.Nrows();i++)m_z.push_back(0);
+  m_classnd.push_back(m_data.Nrows());
+}
+void DPM_GibbsSampler::init_kmeans(const int k){
+  m_k=k;
+  m_z.resize(m_n);
+  do_kmeans();
+  for(int k=1;k<=m_k;k++){
+    GaussianWishart gw(m_d);
+    vector<ColumnVector> dat;
+    for(int i=1;i<=m_n;i++)
+      if(m_z[i-1] == k){
+	dat.push_back(m_data.Row(i).t());
+	m_z[i-1] -- ;
+      }
+    gw.postupdate(dat,m_gw0);
+    m_gw.push_back(gw);
+    m_classnd.push_back((int)dat.size());
+  }
+}
+void DPM_GibbsSampler::init_random(const int k){
+  m_k=k;
+  m_z.resize(m_n);
+  vector< pair<float,int> > rindex(m_n);
+  randomise(rindex);
+  vector<pair<float,int> >::iterator riter;
+  int nn=0,cl=1,nperclass=(int)(float(m_n)/float(m_k));
+  for(riter=rindex.begin();riter!=rindex.end();++riter){
+    m_z[(*riter).second]=cl;
+    nn++;
+    if(nn>=nperclass && cl<m_k){
+      nn=0;
+      cl++;
+    }
+  }
+  for(int k=1;k<=m_k;k++){
+    GaussianWishart gw(m_d);
+    vector<ColumnVector> dat;
+    for(int i=1;i<=m_n;i++)
+      if(m_z[i-1] == k){
+	dat.push_back(m_data.Row(i).t());
+	m_z[i-1] -- ;
+      }
+    gw.postupdate(dat,m_gw0);
+    m_gw.push_back(gw);
+    m_classnd.push_back((int)dat.size());
+  }    
+}
+  
+
+void DPM_GibbsSampler::sample_parameters(){
+  if(opts.verbose.value())
+    cout << *this;
+
+  // sample indicators
+  //cout<<"sample z"<<endl;
+  sample_z();
+  // sample mean and variance of each class
+  //cout<<"sample gw"<<endl;
+  sample_gw();
+}
+void DPM_GibbsSampler::sample_hyperparameters(){
+  // sample hyperpriors
+  //cout<<"sample gw0"<<endl;
+  sample_gw0();
+  // sample alpha
+  //cout<<"sample alpha"<<endl;
+  sample_alpha();
+}
+// sample indicator variables
+void DPM_GibbsSampler::sample_z(){
+  ColumnVector datapoint(m_d);
+  randomise(randindex);
+
+  // if finite gaussian mixture, do not add new classes
+  float extra_finite1 = opts.numclass.value() < 0 ? 0.0 : m_alpha/float(m_k);
+  float extra_finite2 = opts.numclass.value() < 0 ? 1.0 : 0.0;
+  
+  vector< pair<float,int> >::iterator iter;
+  for(iter=randindex.begin(); iter!=randindex.end(); ++iter){
+    ColumnVector cumsum(m_k+1);
+    ColumnVector w(m_k+1);
+    
+    datapoint=m_data.Row((*iter).second+1).t();
+    int oldz=m_z[(*iter).second],newz=oldz;
+    m_classnd[oldz] -= 1;
+    
+    // compute class weights
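+    // existing class k gets weight n_k * exp(marglik); a potential new class
+    // gets weight alpha * exp(margint) (infinite mode only)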
+    double sum=0.0;
+    for(int k=0;k<m_k;k++){
+      w(k+1) = exp(log(m_classnd[k]+extra_finite1)+marglik(datapoint,k));
+      sum += w(k+1);
+      cumsum(k+1) = sum;
+    }
+    w(m_k+1) = m_alpha*exp(margint(datapoint)) * extra_finite2;
+    sum += w(m_k+1);
+    cumsum(m_k+1) = sum;
+    // sample z using the weights
+    float U=rand()/float(RAND_MAX);
+    U *= sum;
+    for(int k=1;k<=m_k+1;k++){
+      if(U<cumsum(k)){
+	newz=k-1;
+	break;
+      }
+    }
+    m_z[(*iter).second] = newz;
+
+    if( newz >= m_k ){ // add a new class
+      m_k++;
+      m_classnd.push_back(1);
+      GaussianWishart gw(m_d);
+      gw.postupdate(datapoint.t(),m_gw0);
+      m_gw.push_back(gw);
+    }
+    else{
+      m_classnd[newz] += 1;
+    }
+    //cout << " chosen cluster: "<<(*iter).second<<",oldz="<<oldz<<",newz="<<newz;
+    //cout << ",w="<<w(newz+1)<<",nold="<<m_classnd[oldz]<<"n="<<m_classnd[newz]<<endl;
+
+  }// end loop over data points
+  //cout<<"end data"<<endl<<endl;
+
+  // delete empty classes if in infinite mode
+  if(opts.numclass.value()<0){
+    for(int k=m_k-1;k>=0;k--){
+      if(m_classnd[k] == 0){
+	for(int i=0;i<m_n;i++)
+	  if(m_z[i]>k)m_z[i]--;
+	for(int kk=k;kk<m_k-1;kk++){
+	  m_classnd[kk]=m_classnd[kk+1];
+	  m_gw[kk]=m_gw[kk+1];
+	}
+	m_classnd.pop_back();
+	m_gw.pop_back();
+	m_k--;
+      }
+    }
+  }
+}
+
+void DPM_GibbsSampler::sample_gw(){
+  // update classes posteriors
+  vector< vector<ColumnVector> > data;
+  data.resize(m_k);    
+  
+  // calculate likelihood
+  m_likelihood = 0;
+  for(int i=0;i<m_n;i++){
+    data[ m_z[i] ].push_back(m_data.Row(i+1).t());
+    m_likelihood += -marglik(m_data.Row(i+1).t(),m_z[i]);
+  }
+
+  for(int k=0;k<m_k;k++){
+    if(data[k].size()>0)
+      m_gw[k].postupdate(data[k],m_gw0);
+  }
+}
+
+
+void DPM_GibbsSampler::sample_gw0(){
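+  // resample the base measure (mu0,Nu0,kappa0) from its conditional given
+  // the current per-class sampled means and covariances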
+  SymmetricMatrix Nu0(m_d),A(m_d),S(m_d);
+  ColumnVector a(m_d),mu0(m_d);    
+  float B=0;
+
+  A=0;a=0;
+  for(int k=0;k<m_k;k++){
+    S = m_gw[k].get_ssigma().i();
+    a += S*m_gw[k].get_smu();
+    A << A+S;
+    B += ((m_gw[k].get_smu()-m_gw0.get_mu()).t()*S*(m_gw[k].get_smu()-m_gw0.get_mu())).AsScalar();
+  }
+  S << A+m_N0.i();
+  A << (A+m_S0.i()).i();
+  a = A*(a+m_S0.i()*m_m0);
+  
+  Nu0 = wishrnd(S.i(),(m_k+1)*m_gw0.get_dof());
+  mu0 = mvnrnd(a.t(),A).t();
+  
+  m_gw0.set_Nu(Nu0);
+  m_gw0.set_mu(mu0);
+
+  Gamma G(1+m_k*m_d/2.0); //G.Set(rand()/float(RAND_MAX));
+  m_gw0.set_kappa(G.Next()*2/(1+B));
+  //m_gw0.set_kappa(1.0);
+
+}
+
+// sample from alpha using additional variable eta
+void DPM_GibbsSampler::sample_alpha(){
+  float eta,prop;
+  float ak=m_a0+m_k-1,bn;
+  
+  Gamma G1(ak+1);       //G1.Set(rand()/float(RAND_MAX));
+  Gamma G2(ak);         //G2.Set(rand()/float(RAND_MAX));
+  Gamma B1(m_alpha+1);  //B1.Set(rand()/float(RAND_MAX));
+  Gamma B2(m_n);        //B2.Set(rand()/float(RAND_MAX));
+  
+  eta  = B1.Next();
+  eta /= (eta+B2.Next());
+  bn   = m_b0-std::log(eta);
+  
+  prop=ak/(ak+m_n*bn);
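+  // NB: this takes a prop-weighted average of the two gamma draws; Escobar &
+  // West's scheme instead selects ONE of the two components with probability prop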
+  m_alpha=(prop*G1.Next()+(1-prop)*G2.Next())/bn;
+  //m_alpha=.00000001;
+}
+  
+double DPM_GibbsSampler::marglik(const ColumnVector& data,const int k){
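+  // log Gaussian density of data under class k's current (smu,ssigma) sample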
+  double res=0.0;
+  LogAndSign ld=(2*M_PI*m_gw[k].get_ssigma()).LogDeterminant();
+
+  res -= 0.5*(ld.LogValue()
+	      +((data-m_gw[k].get_smu()).t()
+	      *m_gw[k].get_ssigma().i()
+		*(data-m_gw[k].get_smu())).AsScalar());
+
+  return res;
+}
+double DPM_GibbsSampler::margint(const ColumnVector& data){
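+  // log prior predictive of data under the Gaussian-Wishart base measure
+  // (a multivariate Student-t density evaluated at data)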
+  LogAndSign ld;
+  double res=m_margintbase;
+  
+  ld = m_gw0.get_Nu().LogDeterminant();
+  res += ld.LogValue()*m_gw0.get_dof()/2;
+  
+  SymmetricMatrix A(m_d);
+  A << m_gw0.get_Nu()+m_gw0.get_kappa()/(1+m_gw0.get_kappa())*(data-m_gw0.get_mu())*(data-m_gw0.get_mu()).t();
+  ld = A.LogDeterminant();
+  res -= ld.LogValue()*(m_gw0.get_dof()+1)/2;
+
+  return res;
+}
+
+
+// utils
+void DPM_GibbsSampler::do_kmeans(){
+  int numiter = 100;
+  
+  Matrix means(m_d,m_k),newmeans(m_d,m_k);
+  ColumnVector nmeans(m_k);
+  
+  means=0;
+  nmeans=0;
+  
+  //    cout<<"inside kmeans"<<endl;
+  // initialise random
+  vector< pair<float,int> > rindex(m_n);
+  randomise(rindex);
+  vector<pair<float,int> >::iterator riter;
+  int nn=0,cl=1,nperclass=(int)(float(m_n)/float(m_k));
+  for(riter=rindex.begin();riter!=rindex.end();++riter){
+    means.Column(cl) += m_data.Row((*riter).second+1).t();
+    nmeans(cl) += 1;
+    m_z[(*riter).second]=cl;
+    nn++;
+    if(nn>=nperclass && cl<m_k){
+      nn=0;
+      cl++;
+    }
+  }
+  for(int m=1;m<=m_k;m++)
+    means.Column(m) /= nmeans(m);
+
+  //cout<<"kmeans init"<<endl;
+  //for(int i=0;i<n;i++)
+  //cout<<z[i]<<" ";
+  //cout<<endl;
+  
+  // iterate
+  for(int iter=0;iter<numiter;iter++){
+    // loop over datapoints and attribute z for closest mean
+    newmeans=0;
+    nmeans=0;
+    for(int i=1;i<=m_n;i++){
+      float mindist=1E20,dist=0;
+      int mm=1;
+      for(int m=1;m<=m_k;m++){
+	dist = (means.Column(m)-m_data.Row(i).t()).SumSquare();
+	if( dist<mindist){
+	  mindist=dist;
+	  mm = m;
+	}
+      }
+      m_z[i-1] = mm; // m_z is 0-based
+      newmeans.Column(mm) += m_data.Row(i).t();
+      nmeans(mm) += 1;
+    }
+    
+    // compute means
+    for(int m=1;m<=m_k;m++){
+      if(nmeans(m)==0){
+	if(opts.numclass.value()<0) m_k--;
+	do_kmeans();
+	return;
+      }
+      newmeans.Column(m) /= nmeans(m);
+    }
+    means = newmeans;
+  }
+  
+  
+  //cout<<"kmeans end"<<endl;
+  //for(int i=0;i<n;i++)
+  //cout<<z[i]<<" ";
+  //cout<<endl;
+}    
+
+ReturnMatrix DPM_GibbsSampler::get_dataindex(){
+    ColumnVector index(m_n);
+    for(unsigned int i=0;i<m_z.size();i++)
+      index(i+1) = m_z[i];
+    index.Release();
+    return index;
+}
+ReturnMatrix DPM_GibbsSampler::get_mldataindex(){
+  ColumnVector index(m_n);
+  double lik,tmplik;
+  for(int i=1;i<=m_n;i++){
+    lik=-1e30;tmplik=0;index(i) = 0;
+    for(int k=0;k<m_k;k++){
+      // marglik is a log-density, so weight it by the log of the class size
+      tmplik = std::log((double)m_classnd[k])+marglik(m_data.Row(i).t(),k);
+      if(tmplik>lik && m_classnd[k]>3){ // ignore near-empty classes
+	lik = tmplik;
+	index(i) = k+1;
+      }
+    }
+  }
+  index.Release();
+  return index;
+}
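
Review note: sample_z above is the standard collapsed Gibbs update for a DP mixture (Neal 2000, Alg. 3). With $n_k^{-i}$ the occupancy of class $k$ excluding point $i$ (m_classnd), marglik returning $\log p(x_i\mid\theta_k)$, and margint the log prior predictive under the Gaussian-Wishart base measure $G_0$:

$$
p(z_i=k\mid z_{-i},x_i)\ \propto\
\begin{cases}
n_k^{-i}\,p(x_i\mid\theta_k) & k\le K \ \text{(existing class)}\\
\alpha\int p(x_i\mid\theta)\,dG_0(\theta) & k=K+1 \ \text{(new class)}
\end{cases}
$$

In finite mode (fixed numclass) extra_finite1 adds $\alpha/K$ to each occupancy and extra_finite2 zeroes the new-class term, recovering the symmetric-Dirichlet finite mixture update.
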
diff --git a/dpm_gibbs.h b/dpm_gibbs.h
index d91cfce592a221258219ef23b26d3855ed3d9936..3cc04d1218e9e39a322103e4c8fe9862c587030e 100644
--- a/dpm_gibbs.h
+++ b/dpm_gibbs.h
@@ -3,36 +3,36 @@
 
 #include "gibbs.h"
 #include "dpmOptions.h"
-#include "miscmaths/miscmaths.h"
+#include "newran/newran.h"
 #include "miscmaths/miscprob.h"
 #include <stdlib.h>
 #include <stdio.h>
-#include <newmat.h>
-#include <newran.h>
 #include <cmath>
 
+
 using namespace NEWMAT;
 using namespace NEWRAN;
 using namespace MISCMATHS;
 using namespace DPM;
 using namespace std;
 
+
 // Gaussian-InverWishart distribution
 // p(mu,sigma)=det(sigma)^(-(nu+d)/2-1)exp(-trace(Nu*inv(sigma))/2 -kappa/2*(mu-m_mu)'inv(sigma)(mu-m_mu))
 class GaussianWishart{
- private:
+  private:
   friend std::ostream& operator << (ostream& o,GaussianWishart& g);
-  
  protected:
   ColumnVector       m_mu;
   SymmetricMatrix    m_Nu;
   float              m_kappa;
   int                m_dof;
   int                m_dim;
-
+  
   ColumnVector       m_smu;     // sample mean
   SymmetricMatrix    m_ssigma;  // sample covariance
 
+  
  public:
   GaussianWishart(){}
   GaussianWishart(const int dim):m_dim(dim){
@@ -56,7 +56,7 @@ class GaussianWishart{
   void postupdate(const vector<ColumnVector>& data,const GaussianWishart& gw0){
     ColumnVector mdat(m_dim);
     SymmetricMatrix S(m_dim),SS(m_dim);
-
+    
     float n = (float)data.size();
     m_dof   = gw0.get_dof()   + int(n);
     m_kappa = gw0.get_kappa() + n;
@@ -67,19 +67,19 @@ class GaussianWishart{
       mdat += data[i];
     }
     mdat /= n;
-
+    
     SS << S -n*mdat*mdat.t();
     SS << SS + gw0.get_kappa()*n/m_kappa * (mdat-gw0.get_mu())*(mdat-gw0.get_mu()).t();
-
+    
     m_mu    = ( gw0.get_kappa()*gw0.get_mu() + n*mdat )/m_kappa;
     m_Nu   << gw0.get_Nu() + SS;
-
+    
     sample();
   }
   void postupdate(const Matrix& data,const GaussianWishart& gw0){
     ColumnVector mdat(m_dim);
     SymmetricMatrix S(m_dim),SS(m_dim);
-
+    
     float n = (float)data.Nrows();
     m_dof   = gw0.get_dof()   + int(n);
     m_kappa = gw0.get_kappa() + n;
@@ -90,13 +90,13 @@ class GaussianWishart{
       mdat += data.Row(i).t();
     }
     mdat /= n;
-
+    
     SS << S -n*mdat*mdat.t();
     SS << SS + gw0.get_kappa()*n/m_kappa * (mdat-gw0.get_mu())*(mdat-gw0.get_mu()).t();
-
+    
     m_mu    = ( gw0.get_kappa()*gw0.get_mu() + n*mdat )/m_kappa;
     m_Nu   << gw0.get_Nu() + SS;
-
+    
     sample();
   }
   void sample(ColumnVector& mu,SymmetricMatrix& sigma){
@@ -107,7 +107,7 @@ class GaussianWishart{
     m_ssigma = iwishrnd(m_Nu.i(),m_dof);
     m_smu    = mvnrnd(m_mu.t(),m_ssigma/m_kappa).t();
   }
-  void print(ostream& os)const{
+  void print(ostream& os)const{ 
     os << "Gaussian-InverseWishart distribution" << endl;
     os << "mean       : " << m_mu.t();
     os << "variance   : " << m_Nu.Row(1);
@@ -129,7 +129,7 @@ class GaussianWishart{
     m_kappa  = rhs.m_kappa;
     m_dof    = rhs.m_dof;
     m_dim    = rhs.m_dim;
-
+    
     m_smu    = rhs.m_smu;
     m_ssigma = rhs.m_ssigma;
 
@@ -137,18 +137,17 @@ class GaussianWishart{
   }
 
 };
-std::ostream& operator << (ostream& o,GaussianWishart& g){
-  g.print(o);
-  return o;
-}
 
-bool compare(const pair<int,float> &p1,const pair<int,float> &p2){
-  return (p1.second < p2.second) ? true : false;
-}
+//bool compare(const pair<int,float> &p1,const pair<int,float> &p2){
+//return (p1.second < p2.second) ? true : false;
+//}
 
 
 class DPM_GibbsSampler : public GibbsSampler
 {
+ private:
+  friend std::ostream& operator << (ostream& o,DPM_GibbsSampler& g);
+  
  protected:
   DPM::dpmOptions& opts;
 
@@ -168,9 +167,9 @@ class DPM_GibbsSampler : public GibbsSampler
   // data-related quantities
   vector<int>              m_classnd;
   int                      m_k;
-  float                    m_margintbase;
+  double                   m_margintbase;
   
-  vector< pair<int,float> > randindex;
+  vector< pair<float,int> > randindex;
 
   // samples
   vector<float>            m_sample_alpha;
@@ -178,307 +177,55 @@ class DPM_GibbsSampler : public GibbsSampler
   vector<double>           m_sample_likelihood;
   double                   m_likelihood;
   int                      m_nsamples;
+  vector<float>            m_mean_z;
+
+  const Matrix&            m_data;
 
 public:
   DPM_GibbsSampler(const Matrix& data,int numiter,int burnin,int sampleevery):
-    GibbsSampler(data,numiter,burnin,sampleevery),
-    opts(DPM::dpmOptions::getInstance()){
+    GibbsSampler(numiter,burnin,sampleevery),
+      opts(DPM::dpmOptions::getInstance()),m_data(data){
+      m_n = m_data.Nrows();
+      m_d = m_data.Ncols();
+
     m_nsamples = (int)floor( (numiter - burnin) / sampleevery );
 
     m_sample_alpha.resize(m_nsamples);
     m_sample_k.resize(m_nsamples);
     m_sample_likelihood.resize(m_nsamples);
-  }
-  ~DPM_GibbsSampler(){}
-
-  void init(){
-    // fix fixed parameters
-    m_a0       = 1.0;
-    m_b0       = 1.0;
-    m_S0       << 1000.0*Identity(m_d);//cov(m_data);
-    m_N0       << Identity(m_d);///(m_nu0-m_d-1);
-    m_m0       = mean(m_data,1).t();
-    m_n0       = 1; 
-
-    // initialise all other parameters
-    m_alpha    = 1.0;
-    m_k        = opts.numclass.value();
-
-    float kappa0   = 1.0;
-    int nu0        = m_d;
-    SymmetricMatrix Nu0(m_d);
-    Nu0 << m_n0*m_N0;//.01*m_d*Identity(m_d);//cov(m_data);//*(m_nu0-m_d-1);
-    ColumnVector mu0(m_d); 
-    mu0 = m_m0;
-    
-    m_gw0      = GaussianWishart(mu0,Nu0,nu0,kappa0);
-
-    
-    if(m_k < 0){
-      if(opts.init_class.value() == "oneperdata"){
-	cout << "Initialise with one class per data"<<endl;
-	m_k = m_n;
-	// set parameters
-	for(int i=1;i<=m_n;i++){
-	  GaussianWishart gw(m_d);
-	  gw.postupdate(m_data.SubMatrix(i,i,1,m_d),m_gw0);
-	  m_gw.push_back(gw);
-	  m_z.push_back(i-1);         // one data point per class
-	  m_classnd.push_back(1);
-	}
-	cout << *this;
-      }
-      else if (opts.init_class.value() == "one"){
-	cout << "initialise with one big class"<<endl;
-	m_k = 1;
-	// initialise with one big class
-	GaussianWishart gw(m_d);
-	gw.postupdate(m_data,m_gw0);
-	m_gw.push_back(gw);
-	for(int i=0;i<m_data.Nrows();i++)m_z.push_back(0);
-	m_classnd.push_back(m_data.Nrows());
-	cout << *this;
-      }
-      else{ // kmeans initialisation
-	cout << "Initialise using kmeans" << endl;
-	m_z.resize(m_n);
-	m_k = 10;
-	kmeans(m_data,m_z,m_k);
-	cout<<"done"<<endl;
-	for(int k=1;k<=m_k;k++){
-	  GaussianWishart gw(m_d);
-	  vector<ColumnVector> dat;
-	  for(int i=1;i<=m_n;i++)
-	    if(m_z[i-1] == k){
-	      dat.push_back(m_data.Row(i).t());
-	      m_z[i-1] -- ;
-	    }
-	  gw.postupdate(dat,m_gw0);
-	  m_gw.push_back(gw);
-	  m_classnd.push_back((int)dat.size());
-	}
-	cout << *this;
-      }
-    }
-    else{
-      m_z.resize(m_n);
-      kmeans(m_data,m_z,m_k);
-
-      for(int k=1;k<=m_k;k++){
-	GaussianWishart gw(m_d);
-	vector<ColumnVector> dat;
-	for(int i=1;i<=m_n;i++)
-	  if(m_z[i-1] == k){
-	    dat.push_back(m_data.Row(i).t());
-	    m_z[i-1] -- ;
-	  }
-	//OUT(dat.size());
-	gw.postupdate(dat,m_gw0);
-	m_gw.push_back(gw);
-	m_classnd.push_back((int)dat.size());
-      }
-
-    }
-
-    m_margintbase = m_d/2*(log(m_gw0.get_kappa()/(1+m_gw0.get_kappa()))-log(M_PI)) 
-      + lgam(float(nu0+1)/2.0) -lgam(float(nu0+1-m_d)/2.0);
-
-    OUT(m_margintbase);
-    //print();
-
-
-    // randomised index for loop over data items
-    randindex.resize(m_n);
-
-  }
-  void sample_parameters(){
-    // sample indicators
-    //cout<<"sample z"<<endl;
-    sample_z();
-    // sample mean and variance of each class
-    //cout<<"sample gw"<<endl;
-    sample_gw();
-
-    cout << *this;
-
-  }
-  void sample_hyperparameters(){
-    // sample hyperpriors
-    //cout<<"sample gw0"<<endl;
-    sample_gw0();
-    // sample alpha
-    //cout<<"sample alpha"<<endl;
-    sample_alpha();
-  }
-  void randomise(vector< pair<int,float> >& r){
-    for(int i=0;i<m_n;i++){
-      pair<int,float> p(i,rand());
-      r[i]=p;
-    }
-    sort(r.begin(),r.end(),compare);
-  }
-  // sample indicator variables
-  void sample_z(){
-    ColumnVector datapoint(m_d);
-
-    randomise(randindex);
-
-    float extra_finite1 = opts.numclass.value() < 0 ? 0.0 : m_alpha/float(m_k);
-    float extra_finite2 = opts.numclass.value() < 0 ? 1.0 : 0.0;
-
-    vector< pair<int,float> >::iterator iter;
-    for(iter=randindex.begin(); iter!=randindex.end(); iter++){
-      ColumnVector cumsum(m_k+1);
-
-      datapoint=m_data.Row((*iter).first+1).t();
-      int oldz=m_z[(*iter).first],newz=oldz;
-      m_classnd[oldz] -= 1;
-
-      //cout<<"-----"<<endl;
-      // compute class weights
-      float sum=0.0;
-      for(int k=0;k<m_k;k++){
-	sum += (m_classnd[k]+extra_finite1)*marglik(datapoint,k);
-	cumsum(k+1) = sum;
-      }
-      sum += m_alpha*margint(datapoint) * extra_finite2;
-      cumsum(m_k+1) = sum;
-      // sample z using the weights
-      float U=rand()/float(RAND_MAX);
-      U *= sum;
-      for(int k=1;k<=m_k+1;k++){
-	if(U<cumsum(k)){
-	  newz=k-1;
-	  break;
-	}
-      }
-      m_z[(*iter).first] = newz;
-      
-      //cout<<"-----"<<endl;
-
-      if( newz >= m_k ){ // add a new class
-	//cout<<"ADD A NEW CLASS"<<endl;
-	m_k++;
-	m_classnd.push_back(1);
-	GaussianWishart gw(m_d);
-	gw.postupdate(datapoint.t(),m_gw0);
-	m_gw.push_back(gw);
-      }
-      else{
-	m_classnd[newz] += 1;
-      }
-    }// end loop over data points
-
-    // delete empty classes if in infinite mode
-    if(opts.numclass.value()<0){
-      for(int k=m_k-1;k>=0;k--){
-	if(m_classnd[k] == 0){
-	  for(int i=0;i<m_n;i++)
-	    if(m_z[i]>k)m_z[i]--;
-	  for(int kk=k;kk<m_k-1;kk++){
-	    m_classnd[kk]=m_classnd[kk+1];
-	    m_gw[kk]=m_gw[kk+1];
-	  }
-	  m_classnd.pop_back();
-	  m_gw.pop_back();
-	  m_k--;
-	}
-      }
-    }
+    m_mean_z.resize(m_n);
 
+    Random::Set(rand() / float(RAND_MAX));
   }
+  ~DPM_GibbsSampler(){}
 
-  void sample_gw(){
-    // update classes posteriors
-    vector< vector<ColumnVector> > data;
-    data.resize(m_k);    
-    
-    m_likelihood = 0;
-    for(int i=0;i<m_n;i++){
-      data[ m_z[i] ].push_back(m_data.Row(i+1).t());
-      m_likelihood += -std::log(marglik(m_data.Row(i+1).t(),m_z[i]));
-    }
-
-    for(int k=0;k<m_k;k++){
-      if(data[k].size()>0)
-	m_gw[k].postupdate(data[k],m_gw0);
-    }
-    
-  }
-
-
-  void sample_gw0(){
-    SymmetricMatrix Nu0(m_d),A(m_d),S(m_d);
-    ColumnVector a(m_d),mu0(m_d);    
-    float B=0;
-    
-    A=0;a=0;
-    for(int k=0;k<m_k;k++){
-      S = m_gw[k].get_ssigma().i();
-      a += S*m_gw[k].get_smu();
-      A << A+S;
-      B += ((m_gw[k].get_smu()-m_gw0.get_mu()).t()*S*(m_gw[k].get_smu()-m_gw0.get_mu())).AsScalar();
-    }
-    S << A+m_N0.i();
-    A << (A+m_S0.i()).i();
-    a = A*(a+m_S0.i()*m_m0);
-    
-
-    Nu0 = wishrnd(S.i(),(m_k+1)*m_gw0.get_dof());
-    mu0 = mvnrnd(a.t(),A).t();
-
-    m_gw0.set_Nu(Nu0);
-    m_gw0.set_mu(mu0);
-      
-    Gamma G(1+m_k*m_d/2); //G.Set(rand()/float(RAND_MAX));
-    m_gw0.set_kappa(G.Next()*2/(1+B));
-    //m_gw0.set_kappa(1.0);
-
-    
-  }
-
-  // sample from alpha using additional variable eta
-  void sample_alpha(){
-    float eta,prop;
-    float ak=m_a0+m_k-1,bn;
-
-    Gamma G1(ak+1);       //G1.Set(rand()/float(RAND_MAX));
-    Gamma G2(ak);         //G2.Set(rand()/float(RAND_MAX));
-    Gamma B1(m_alpha+1);  //B1.Set(rand()/float(RAND_MAX));
-    Gamma B2(m_n);        //B2.Set(rand()/float(RAND_MAX));
-
-    eta  = B1.Next();
-    eta /= (eta+B2.Next());
-    bn   = m_b0-std::log(eta);
+  // parent class function definitions
+  void sample_parameters();
+  void sample_hyperparameters();
 
-    prop=ak/(ak+m_n*bn);
-    m_alpha=(prop*G1.Next()+(1-prop)*G2.Next())*bn;
-  }
+  // initialisation functions
+  void init();
+  void init_oneperdata();
+  void init_onebigclass();
+  void init_kmeans(const int k=10);
+  void init_random(const int k=10);
   
-  float marglik(const ColumnVector& data,const int k){
-    float res;
-    res=normpdf(data,m_gw[k].get_smu(),m_gw[k].get_ssigma());
-
-    //OUT(res);
-
-    return res;
-  }
-  float margint(const ColumnVector& data){
-    LogAndSign ld;
-    float res=m_margintbase;
-
-    ld = m_gw0.get_Nu().LogDeterminant();
-    res += ld.LogValue()*m_gw0.get_dof()/2;
-
-    SymmetricMatrix A(m_d);
-    A << m_gw0.get_Nu()+m_gw0.get_kappa()/(1+m_gw0.get_kappa())*(data-m_gw0.get_mu())*(data-m_gw0.get_mu()).t();
-    ld = A.LogDeterminant();
-    res -= ld.LogValue()*(m_gw0.get_dof()+1)/2;
-
-    //OUT(exp(res));
-
-    return std::exp(res);
-  }
+  // sample model parameters
+  void sample_z();
+  void sample_gw();
+  void sample_gw0();
+  void sample_alpha();
+
+  // utils
+  double marglik(const ColumnVector&,const int);
+  double margint(const ColumnVector&);
+  void do_kmeans();
+  ReturnMatrix get_dataindex();
+  ReturnMatrix get_mldataindex();
+
+  int get_numclass()const{return m_k;}
+
+  // io
   void print(ostream& os){
     os << "-------fixed parameters-------"<<endl;
     os << "a0     = "<<m_a0<<endl;
@@ -501,25 +248,28 @@ public:
     os << "kappa0 = "<<m_gw0.get_kappa()<<endl;
     os << "-------class-parameters-------"<<endl;
     for(int i=0;i<m_k;i++){
-      os << "cluster "<<i<<endl;
-      os << "n\t=\t"<<m_classnd[i]<<endl;
-      os << m_gw[i];
-      os << endl;
+      //os << "cluster "<<i<<endl;
+      //os << "n\t=\t"<<m_classnd[i]<<endl;
+      os <<m_classnd[i]<<" ";
+      //os << m_gw[i];
+      //os << endl;
     }
-  }
-
-  
+    os << endl;
+  }  
   void save(){
     string logsamples   = opts.logfile.value() + ".samples";
     string logmeans     = opts.logfile.value() + ".means";
     string logvariances = opts.logfile.value() + ".variances";
-
+    string logz = opts.logfile.value() + ".z";
+    
     ofstream of_s(logsamples.c_str());
     ofstream of_m(logmeans.c_str());
     ofstream of_v(logvariances.c_str());
+    ofstream of_z(logz.c_str());
+    
     double evidence=0;
     double maxlog=0;
-
+    
     of_s << "k\talpha\tlik\n";
     for (unsigned int i=0;i<m_sample_likelihood.size();i++){
       //OUT(i);
@@ -533,16 +283,22 @@ public:
     for(unsigned int i=0;i<m_sample_likelihood.size();i++){
       evidence += std::exp(m_sample_likelihood[i]-maxlog);
     }
-
+    
     // store means and variances
     for(int k=0;k<m_k;k++){
       of_m << m_gw[k].get_smu().t();
       of_v << m_gw[k].get_ssigma().t();
     }
-
+    
     evidence = -log((float)m_sample_likelihood.size()) + maxlog + log(evidence);
     cout<<m_k<<" ";
-    cout<<evidence<<endl;;
+    cout<<evidence<<endl;
+    
+    ColumnVector mlz(m_n);
+    mlz = get_mldataindex();
+    of_z << mlz;
+    
+    cout<<"final k="<< mlz.MaximumAbsoluteValue()<<endl;
     
   }
   void record(const int samp){
@@ -550,86 +306,12 @@ public:
     m_sample_likelihood[samp] = m_likelihood;
     m_sample_k[samp]          = m_k;
     m_sample_alpha[samp]      = m_alpha;
+    for(int i=0;i<m_n;i++)
+      m_mean_z[i] += m_z[i];
   }
-
-  void kmeans(const Matrix& data,vector<int>& z,const int k){
-    int numiter = 100;
-    int n = data.Nrows();
-    int d = data.Ncols();
-
-    Matrix means(d,k),newmeans(d,k);
-    ColumnVector nmeans(k);
-    //z.resize(n);
-    
-    means=0;
-    nmeans=0;
-
-    //    cout<<"inside kmeans"<<endl;
-    // initialise random
-    vector< pair<int,float> > rindex(n);
-    randomise(rindex);
-    vector<pair<int,float> >::iterator riter;
-    int nn=0,cl=1,nperclass=(int)(float(n)/float(k));
-    for(riter=rindex.begin();riter!=rindex.end();riter++){
-      means.Column(cl) += data.Row((*riter).first+1).t();
-      nmeans(cl) += 1;
-      z[(*riter).first]=cl;
-      nn++;
-      if(nn>=nperclass && cl<k){
-	nn=0;
-	cl++;
-      }
-    }
-    for(int m=1;m<=k;m++)
-      means.Column(m) /= nmeans(m);
-
-    //cout<<"kmeans init"<<endl;
-    //for(int i=0;i<n;i++)
-    //cout<<z[i]<<" ";
-    //cout<<endl;
-
-    // iterate
-    for(int iter=0;iter<numiter;iter++){
-      // loop over datapoints and attribute z for closest mean
-      newmeans=0;
-      nmeans=0;
-      for(int i=1;i<=n;i++){
-	float mindist=1E20,dist=0;
-	int mm=1;
-	for(int m=1;m<=k;m++){
-	  dist = (means.Column(m)-data.Row(i).t()).SumSquare();
-	  if( dist<mindist){
-	    mindist=dist;
-	    mm = m;
-	  }
-	}
-	z[i] = mm;
-	newmeans.Column(mm) += data.Row(i).t();
-	nmeans(mm) += 1;
-      }
-
-      // compute means
-      for(int m=1;m<=k;m++){
-	if(nmeans(m)==0){
-	  kmeans(data,z,k);
-	  return;
-	}
-	newmeans.Column(m) /= nmeans(m);
-      }
-      means = newmeans;
-    }
-
-    //OUT(n);
-    //OUT(d);
-    //OUT(nmeans.t());
-    //OUT(newmeans);
-
-    //cout<<"kmeans end"<<endl;
-    //for(int i=0;i<n;i++)
-    //cout<<z[i]<<" ";
-    //cout<<endl;
-  }    
-
+  
+  
+  
 };
 
 
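Review note: with the sampler logic moved out of the header, the class can be driven as below. This is a sketch assuming the dpmOptions defaults (numclass=-1, initclass="random") apply without command-line parsing; the data are random placeholders.

```cpp
#include <cstdlib>
#include "dpm_gibbs.h"

int main(){
  srand(42);                             // DPM_GibbsSampler seeds newran from rand()
  Matrix data(200,5);
  for(int i=1;i<=200;i++)
    for(int j=1;j<=5;j++)
      data(i,j) = rand()/float(RAND_MAX);

  DPM_GibbsSampler gs(data,2000,1000,1); // numiter, burnin, sampleevery
  gs.init();
  gs.run();

  ColumnVector z = gs.get_dataindex();   // one cluster label per data row
  cout << "final number of classes: " << gs.get_numclass() << endl;
  return 0;
}
```
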
diff --git a/gibbs.h b/gibbs.h
index c78e4c8355853823c4cbbd141776973b487ff713..9cb08a0585ff4e733a11d674639c83e0f5afdf31 100644
--- a/gibbs.h
+++ b/gibbs.h
@@ -5,17 +5,14 @@
 #if !defined(_GIBBS_H)
 #define _GIBBS_H
 
-#include "newmat.h"
-#include "newran.h"
-#include "miscmaths/miscmaths.h"
+#include <stdlib.h>
+#include <stdio.h>
+#include <iostream>
 
-using namespace NEWMAT;
-using namespace NEWRAN;
+using namespace std;
 
 class GibbsSampler
 {
- private:
-  friend std::ostream& operator << (ostream& o,GibbsSampler& g);
  protected:
 
   int m_numiter;
@@ -24,25 +21,17 @@ class GibbsSampler
   int m_n;
   int m_d;
 
-  const Matrix& m_data;
-
  public:
-  GibbsSampler(const Matrix& data,int numiter,int burnin,int sampleevery):
-    m_numiter(numiter),m_burnin(burnin),m_sampleevery(sampleevery),m_data(data){
-    
-    m_n=data.Nrows();
-    m_d=data.Ncols();
-  }
+  GibbsSampler(int numiter,int burnin,int sampleevery):
+    m_numiter(numiter),m_burnin(burnin),m_sampleevery(sampleevery){}
   virtual ~GibbsSampler(){}
 
   virtual void init() = 0 ;
   virtual void record(const int) = 0;
   virtual void sample_parameters() = 0;
   virtual void sample_hyperparameters() = 0;
-  virtual void print(ostream&) = 0;
 
   void  run(){
-    Random::Set(rand() / float(RAND_MAX));
 
     int recordcount=0;
 
@@ -52,7 +41,7 @@ class GibbsSampler
       //cout<<"-----------"<<endl;
       sample_parameters();
       sample_hyperparameters();
-      //print();
+      
     }
 
     // m_numiter=2;
@@ -80,10 +69,6 @@ class GibbsSampler
 
 };
 
-std::ostream& operator << (ostream& o,GibbsSampler& g){
-  g.print(o);
-  return o;
-}
 
 
 #endif