Skip to content
Snippets Groups Projects
Commit b2c0dd4a authored by Saad Jbabdi's avatar Saad Jbabdi
Browse files

plot PCA space (and always do PCA before dpm)

parent a07957ed
No related branches found
No related tags found
No related merge requests found
......@@ -480,48 +480,44 @@ ReturnMatrix dimreduce(Matrix& data,const int& numdims){
}
// Jointly reduce the dimensionality of every matrix in `data`.
//
// All matrices are stacked row-wise into one matrix, handed to the
// single-matrix dimreduce() overload, and then split back into per-entry
// matrices that keep the first `ndims` columns of the result. The
// decomposition runs even when no reduction was requested, so `proj` is
// filled by the same call in both cases.
//
//   data : in/out. Each entry is replaced by its reduced matrix. The column
//          count of data[0] is used for every entry.
//   proj : out. Receives the matrix returned by dimreduce(concatdata, ndims).
//          Left untouched when `data` is empty.
//
// The target dimension is read from the --dimred option. A negative value, or
// one larger than the number of columns, keeps every column.
void dimreduce(vector<Matrix>& data,Matrix& proj){
  // Nothing to stack; this also avoids indexing data[0] on an empty vector.
  if(data.empty()) return;

  const int nc = data[0].Ncols();

  // Work out how many dimensions to keep.
  bool do_dimred = true;
  int ndims = dimred.value();
  if(ndims<0 || ndims>nc){
    ndims = nc;
    do_dimred = false;
  }
  // NOTE(review): --dimred=0 is passed through as-is and would give
  // zero-column outputs -- confirm whether 0 should mean "keep all".

  if(!do_dimred && verbose.value())
    cout << "no dimensionality reduction" << endl;
  if(verbose.value()) cout << "requested dimensions: " << ndims << endl;
  if(ndims>10){
    cout << endl;
    cout << "WARNING: We recommend using --dimred=10 or less" << endl << endl;
  }

  // Record each entry's row count so the stacked matrix can be split again.
  vector<int> nr(data.size());
  int nrtotal = 0;
  for(unsigned int i=0;i<data.size();i++){
    nr[i] = data[i].Nrows();
    nrtotal += nr[i];
  }

  // Stack all entries on top of each other (SubMatrix indices start at 1).
  // NOTE(review): assumes every entry has nc columns -- only data[0] is checked.
  Matrix concatdata(nrtotal,nc);
  int currow = 1;
  for(unsigned int i=0;i<data.size();i++){
    concatdata.SubMatrix(currow,currow+nr[i]-1,1,nc) = data[i];
    currow += nr[i];
  }

  // The first `ndims` columns of concatdata are read back below, so the
  // single-matrix overload presumably overwrites its argument in place with
  // the reduced data -- confirm against that overload.
  proj = dimreduce(concatdata,ndims);

  // Split the stacked result back into one matrix per entry.
  currow = 1;
  for(unsigned int i=0;i<data.size();i++){
    data[i] = concatdata.SubMatrix(currow,currow+nr[i]-1,1,ndims);
    currow += nr[i];
  }
}
......@@ -837,8 +833,7 @@ int main (int argc, char *argv[]){
if(beta.value()>0)
gs.add_spatial_prior(adj[0]);
gs.set_spcparam(beta.value());
if(dimred.value()>0)
gs.set_projector(proj);
gs.set_projector(proj);
if(verbose.value())
cout << ".....init......";
......@@ -849,11 +844,6 @@ int main (int argc, char *argv[]){
if(verbose.value())
cout << gs << endl;
// if(save.value()){
// if(verbose.value())
// cout << "save samples";
// gs.save(dirname+"/subject000");
// }
cout<<endl;
if(verbose.value())
cout << "------> postprocessing" << endl;
......@@ -870,9 +860,7 @@ int main (int argc, char *argv[]){
if(report.value()){
if(verbose.value())
cout << "create web report" << endl;
if(dimred.value()>0)
cr.project_data(gs);
cr.set_stats(gs);
cr.set_coord(coord);
cr.create_report(dirname+"/report");
......@@ -888,8 +876,8 @@ int main (int argc, char *argv[]){
if(beta.value()>0)
gs.add_spatial_prior(adj);
gs.set_spcparam(beta.value());
if(dimred.value()>0)
gs.set_projector(proj);
gs.set_projector(proj);
if(verbose.value())
cout << ".....init......";
gs.init(numclass.value(),init_class.value());
......@@ -922,8 +910,7 @@ int main (int argc, char *argv[]){
if(report.value()){
if(verbose.value())
cout << "create web report" << endl;
if(dimred.value()>0)
cr.project_data(gs);
cr.set_stats(gs);
cr.set_coord(coord);
cr.create_report(dirname+"/report");
......
......@@ -184,7 +184,11 @@ void CopainReport::plot_table_fstats(ofstream& htmlfile){
}
void CopainReport::create_report(const string& dirname){
///////////////////////////////////////////////////////////////////////////////////////////////////////
// setup html report file
///////////////////////////////////////////////////////////////////////////////////////////////////////
string htmlfilename = dirname + "/index.html";
ofstream htmlfile(htmlfilename.c_str());
......@@ -204,6 +208,9 @@ void CopainReport::create_report(const string& dirname){
htmlfile << "</font>" << endl;
}
///////////////////////////////////////////////////////////////////////////////////////////////////////
// plot histograms of the raw data
///////////////////////////////////////////////////////////////////////////////////////////////////////
htmlfile << "<hr>" << endl;
htmlfile << "<h3>Raw Data</h3><br><br>" << endl;
......@@ -224,37 +231,70 @@ void CopainReport::create_report(const string& dirname){
}
// plot fits
///////////////////////////////////////////////////////////////////////////////////////////////////////
// plot fits to the PCA processed data
///////////////////////////////////////////////////////////////////////////////////////////////////////
htmlfile << "<hr>" << endl;
htmlfile << "<hr>" << endl;
htmlfile << "<h3>GMM FIT to PCA processed data</h3><br><br>" << endl;
for(int t=0;t<data[0].Ncols();t++){
currow=1;
for(unsigned int s=0;s<subjnames.size();s++){
dat.SubMatrix(currow,currow+rawdata[s].Nrows()-1,1,1) = data[s].SubMatrix(1,data[s].Nrows(),t+1,t+1);
currow += rawdata[s].Nrows();
}
plot_fit(htmlfile,dirname,"data"+num2str(t+1)+"_gmm_pca",
dat,stats.get_means_pca(t+1),stats.get_vars_pca(t+1),stats.get_weights(),
"");
}
///////////////////////////////////////////////////////////////////////////////////////////////////////
// plot fits in data space
///////////////////////////////////////////////////////////////////////////////////////////////////////
htmlfile << "<h3>GMM FIT to preprocessed data</h3><br><br>" << endl;
htmlfile << "Preprocessing consists of log transforming the data and adding gaussian noise in each dimension (std=0.05) <br>" << endl;
htmlfile << "When there is more than one subject, histogram matching is done between subjects in each dimension (i.e. for each target) <br><br>" << endl;
project_data();
for(unsigned int t=0;t<targetnames.size();t++){
currow=1;
for(unsigned int s=0;s<subjnames.size();s++){
dat.SubMatrix(currow,currow+rawdata[s].Nrows()-1,1,1) = data[s].SubMatrix(1,data[s].Nrows(),t+1,t+1);
currow += rawdata[s].Nrows();
}
plot_fit(htmlfile,dirname,"data"+num2str(t+1)+"_gmm",
dat,stats.get_means(t+1),stats.get_vars(t+1),stats.get_weights(),
targetnames[t]);
}
// // plot table summary
///////////////////////////////////////////////////////////////////////////////////////////////////////
// plot table summary
///////////////////////////////////////////////////////////////////////////////////////////////////////
plot_table_clusters(htmlfile);
plot_table_fstats(htmlfile);
///////////////////////////////////////////////////////////////////////////////////////////////////////
// plot clustering
///////////////////////////////////////////////////////////////////////////////////////////////////////
// htmlfile << "<hr>" << endl;
// htmlfile << "<h3>Hard clustering snapshot</h3><br><br>" << endl;
///////////////////////////////////////////////////////////////////////////////////////////////////////
// save coords
///////////////////////////////////////////////////////////////////////////////////////////////////////
save_coord(dirname);
///////////////////////////////////////////////////////////////////////////////////////////////////////
// create subject-wise reports
///////////////////////////////////////////////////////////////////////////////////////////////////////
for(unsigned int i=0;i<subjnames.size();i++){
string ihtmlfilename = dirname + "/copain_report_"+num2str(i)+".html";
ofstream ihtmlfile(ihtmlfilename.c_str());
......
......@@ -19,13 +19,20 @@ class Stats {
public:
Stats () {}
void set_stats(const GWDPM_GibbsSampler& gs){
vector<SymmetricMatrix> s;
vector<SymmetricMatrix> s_pca;
means = gs.get_map_means();
s = gs.get_map_variances();
means_pca = gs.get_map_means();
s_pca = gs.get_map_variances();
means = means_pca;
s = s_pca;
gs.project_classes(means,s);
projector = gs.get_projector();
weights = gs.get_map_proportions();
z = gs.get_map_z();
......@@ -33,12 +40,10 @@ public:
nclasses = (int)s.size();
zs.clear();zs.push_back(z);
fuzzy_zs.clear();fuzzy_zs.push_back(fuzzy_z);
variances.resize(s.size());
ColumnVector w(s.size());
for(unsigned int cl=0;cl<s.size();cl++){
variances[cl].ReSize(s[cl].Nrows());
......@@ -47,15 +52,27 @@ public:
w(cl+1) = weights[cl];
}
variances_pca.resize(s_pca.size());
for(unsigned int cl=0;cl<s_pca.size();cl++){
variances_pca[cl].ReSize(s_pca[cl].Nrows());
for(int tgt=1;tgt<=s_pca[cl].Nrows();tgt++)
variances_pca[cl](tgt) = s_pca[cl](tgt,tgt);
}
compatible_subjects=true;
}
void set_stats(const GWHDPM_GibbsSampler& gs){
vector<SymmetricMatrix> s;
vector<SymmetricMatrix> s_pca;
means = gs.get_map_means();
s = gs.get_map_variances();
means_pca = gs.get_map_means();
s_pca = gs.get_map_variances();
means = means_pca;
s = s_pca;
gs.project_classes(means,s);
projector = gs.get_projector();
weights = gs.get_map_proportions();
zs = gs.get_map_z();
......@@ -98,6 +115,13 @@ public:
w(cl+1) = weights[cl];
}
variances_pca.resize(s_pca.size());
for(unsigned int cl=0;cl<s_pca.size();cl++){
variances_pca[cl].ReSize(s_pca[cl].Nrows());
for(int tgt=1;tgt<=s_pca[cl].Nrows();tgt++)
variances_pca[cl](tgt) = s_pca[cl](tgt,tgt);
}
}
......@@ -138,6 +162,25 @@ public:
ret.Release();
return ret;
}
// Gather component `j` of the PCA-space mean of every class.
//   j : index into each class mean, counted from 1 (callers pass t+1).
// Returns a column vector with nclasses entries, entry k belonging to the
// k-th class.
ReturnMatrix get_means_pca(const int& j)const{
  ColumnVector column(nclasses);
  // The result is addressed from 1, the per-class container from 0.
  for(int k=1;k<=nclasses;k++)
    column(k) = means_pca[k-1](j);
  column.Release();
  return column;
}
// Gather component `j` of the PCA-space variance vector of every class.
//   j : index into each class variance vector, counted from 1 (callers pass t+1).
// Returns a column vector with nclasses entries, entry k belonging to the
// k-th class.
ReturnMatrix get_vars_pca(const int& j)const{
  ColumnVector column(nclasses);
  // The result is addressed from 1, the per-class container from 0.
  for(int k=1;k<=nclasses;k++)
    column(k) = variances_pca[k-1](j);
  column.Release();
  return column;
}
ReturnMatrix get_weights()const{
ColumnVector ret(nclasses);
for(int cl=0;cl<nclasses;cl++){
......@@ -227,9 +270,15 @@ public:
}
// Accessor for the projector matrix stored by set_stats(); returns a copy.
Matrix get_projector()const{
  return projector;
}
private:
vector<ColumnVector> means;
vector<ColumnVector> variances;
vector<ColumnVector> means_pca;
vector<ColumnVector> variances_pca;
vector<float> weights;
int nclasses;
vector<ColumnVector> zs; // all subjects
......@@ -237,6 +286,8 @@ private:
ColumnVector z; // group or one subject
Matrix fuzzy_z;
Matrix projector;
bool compatible_subjects;
};
......@@ -259,8 +310,8 @@ public:
for(unsigned int i=0;i<data.size();i++)
data[i] = (P*data[i].t()).t();
}
void project_data(const GWHDPM_GibbsSampler& gs){
Matrix P = gs.get_projector();
void project_data(){
Matrix P = stats.get_projector();
for(unsigned int i=0;i<data.size();i++)
data[i] = (P*data[i].t()).t();
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment