Skip to content
Snippets Groups Projects
Commit 8d1b7bcc authored by Saad Jbabdi's avatar Saad Jbabdi
Browse files

added kmeans clustering to ccops

parent 6b347d5a
No related branches found
No related tags found
No related merge requests found
......@@ -70,6 +70,100 @@ using namespace CCOPS;
}
bool compare(const pair<float,int> &r1,const pair<float,int> &r2){
return (r1.first<r2.first);
}
void randomise(vector< pair<float,int> >& r){
for(unsigned int i=1;i<=r.size();i++){
pair<float,int> p(rand()/float(RAND_MAX),i);
r[i-1]=p;
}
sort(r.begin(),r.end(),compare);
}
void do_kmeans(const Matrix& data,ColumnVector& y,const int k){
int numiter=200;
if(data.Nrows() != (int)y.Nrows()){
y.ReSize(data.Nrows());
}
int n = data.Nrows();
int d = data.Ncols();
Matrix means(d,k),newmeans(d,k);
ColumnVector nmeans(k);
means=0;
nmeans=0;
// initialise random
vector< pair<float,int> > rindex(n);
randomise(rindex);
vector<pair<float,int> >::iterator riter;
int nn=0,cl=1,nperclass=(int)(float(n)/float(k));
for(riter=rindex.begin();riter!=rindex.end();++riter){
means.Column(cl) += data.Row((*riter).second).t();
nmeans(cl) += 1;
y((*riter).second)=cl;
nn++;
if(nn>=nperclass && cl<k){
nn=0;
cl++;
}
}
for(int m=1;m<=k;m++)
means.Column(m) /= nmeans(m);
// iterate
for(int iter=0;iter<numiter;iter++){
// loop over datapoints and attribute z for closest mean
newmeans=0;
nmeans=0;
for(int i=1;i<=n;i++){
float mindist=1E20,dist=0;
int mm=1;
for(int m=1;m<=k;m++){
dist = (means.Column(m)-data.Row(i).t()).SumSquare();
if( dist<mindist){
mindist=dist;
mm = m;
}
}
y(i) = mm;
newmeans.Column(mm) += data.Row(i).t();
nmeans(mm) += 1;
}
// compute means
for(int m=1;m<=k;m++){
if(nmeans(m)==0){
cout << "Only found " << k-1 << " clusters!!!" << endl;
do_kmeans(data,y,k-1);
return;
}
newmeans.Column(m) /= nmeans(m);
}
means = newmeans;
}
}
void kmeans_reord(const Matrix& A,ColumnVector& r,ColumnVector& y,const int k){
do_kmeans(A,y,k);
vector< pair<float,int> > myvec2;
for(int i=1;i<=A.Nrows();i++){
pair<int,int> mypair;
mypair.first=y(i);
mypair.second=i;
myvec2.push_back(mypair);
}
sort(myvec2.begin(),myvec2.end());
r.ReSize(A.Nrows());
y.ReSize(A.Nrows());
for(int i=1;i<=A.Nrows();i++){
y(i)=myvec2[i-1].first;
r(i)=myvec2[i-1].second;
}
}
void rem_zrowcol(const Matrix& myOMmat,const Matrix& coordmat,const Matrix& tractcoordmat,const bool coordbool,const bool tractcoordbool,Matrix& newOMmat,Matrix& newcoordmat, Matrix& newtractcoordmat)
{
......@@ -80,7 +174,7 @@ void rem_zrowcol(const Matrix& myOMmat,const Matrix& coordmat,const Matrix& trac
cerr<< "Checking for all zero rows"<<endl;
cout<< "Checking for all zero rows"<<endl;
for(int i=1;i<=myOMmat.Nrows();i++){
dimsum=0;
for(int j=1;j<=myOMmat.Ncols();j++){
......@@ -91,7 +185,7 @@ void rem_zrowcol(const Matrix& myOMmat,const Matrix& coordmat,const Matrix& trac
cerr<< "Checking for all zero cols"<<endl;
cout<< "Checking for all zero cols"<<endl;
for(int j=1;j<=myOMmat.Ncols();j++){
dimsum=0;
for(int i=1;i<=myOMmat.Nrows();i++){
......@@ -110,7 +204,7 @@ void rem_zrowcol(const Matrix& myOMmat,const Matrix& coordmat,const Matrix& trac
}
int zrowcounter=0,zcolcounter=0,nzrowcounter=1,nzcolcounter=1;
cerr<<"Forming New Matrix"<<endl;
cout<<"Forming New Matrix"<<endl;
for(int j=1;j<=myOMmat.Ncols();j++){
zrowcounter=0;
nzrowcounter=1;
......@@ -150,7 +244,7 @@ void rem_zrowcol(const Matrix& myOMmat,const Matrix& coordmat,const Matrix& trac
if(coordbool){
cerr<<"Updating Seed Coordinates"<<endl;
cout<<"Updating Seed Coordinates"<<endl;
zrowcounter=0;nzrowcounter=1;
if(zerorows.size()>0){//Are there any zero rows?
for(int i=1;i<=coordmat.Nrows();i++){
......@@ -169,7 +263,7 @@ void rem_zrowcol(const Matrix& myOMmat,const Matrix& coordmat,const Matrix& trac
}
if(tractcoordbool){
cerr<<"Updating Tract Coordinates"<<endl;
cout<<"Updating Tract Coordinates"<<endl;
zcolcounter=0;nzcolcounter=1;
if(zerocols.size()>0){//Are there any zero cols?
for(int i=1;i<=tractcoordmat.Nrows();i++){
......@@ -344,7 +438,7 @@ int main ( int argc, char **argv ){
}
}
else{
cerr<<"Seed Space Coordinate File Not present - Ignoring"<<endl;
cout<<"Seed Space Coordinate File Not present - Ignoring"<<endl;
}
//Checking For and Loading Up Tract coordinates
......@@ -360,7 +454,7 @@ int main ( int argc, char **argv ){
}
}
else{
cerr<<"Tract Space Coordinate File Not present - Ignoring"<<endl;
cout<<"Tract Space Coordinate File Not present - Ignoring"<<endl;
}
......@@ -426,14 +520,14 @@ int main ( int argc, char **argv ){
cerr<<"Computing correlation"<<endl;
cout<<"Computing correlation"<<endl;
SymmetricMatrix CtCt;
CtCt << corrcoef(newOMmat.t());
CtCt << CtCt+1;
// adding connexity constraint
if(!coordbool){
cerr<<"WARNING !! No coordinates provided. I cannot apply any connexity constraint."<<endl;
cout<<"WARNING !! No coordinates provided. I cannot apply any connexity constraint."<<endl;
}
else{
add_connexity(CtCt,newcoordmat,opts.connexity.value());
......@@ -466,11 +560,18 @@ int main ( int argc, char **argv ){
}
else{
cerr<<"Starting First Reordering"<<endl;
spect_reord(CtCt,r1,y1);
cout<<"Starting First Reordering"<<endl;
if(opts.scheme.value()=="spectral")
spect_reord(CtCt,r1,y1);
else if(opts.scheme.value()=="kmeans")
kmeans_reord(CtCt,r1,y1,opts.nclusters.value());
else{
cerr << "unkown reordering scheme" << endl;
return(-1);
}
cerr<<"Permuting seed CC matrix"<<endl;
cout<<"Permuting seed CC matrix"<<endl;
for(int j=0;j<outCCvol.ysize();j++){
for(int i=0;i<outCCvol.xsize();i++){
outCCvol(i,j,0)=CtCt((int)r1(i+1),(int)r1(j+1));
......@@ -479,7 +580,7 @@ int main ( int argc, char **argv ){
if(coordbool){
cerr<<"Permuting Seed Coordinates"<<endl;
cout<<"Permuting Seed Coordinates"<<endl;
for(int i=0;i<outcoords.xsize();i++){
outcoords(i,0,0)=(int)newcoordmat(int(r1(i+1)),1);
outcoords(i,1,0)=(int)newcoordmat(int(r1(i+1)),2);
......@@ -491,15 +592,60 @@ int main ( int argc, char **argv ){
write_ascii_matrix(y1,base+"y1");
save_volume(outCCvol,"reord_CC_"+base);
save_volume(outcoords,"coords_for_reord_"+base);
// save clustering if kmeans used
if(opts.scheme.value() == "kmeans"){
volume<int> mask;
read_volume(mask,opts.mask.value());
mask = 0;
for(int i=0;i<outcoords.xsize();i++){
mask(outcoords(i,0,0),
outcoords(i,1,0),
outcoords(i,2,0)) = (int)y1(i+1) + 1;
}
save_volume(mask,"reord_mask_"+base);
// save tractspace clustering if specified
volume<int> outmask,tractmask;
read_volume(tractmask,"lookup_tractspace_fdt_matrix2");
outmask=tractmask;
copybasicproperties(tractmask,outmask);
outmask=0;
for(int z=0;z<tractmask.zsize();z++)
for(int y=0;y<tractmask.ysize();y++)
for(int x=0;x<tractmask.xsize();x++){
int j=tractmask(x,y,z);
ColumnVector vals(myOM.xsize());
for(int i=0;i<myOM.xsize();i++){
vals(i+1) = myOM(i,j,0);
}
if(vals.MaximumAbsoluteValue()==0)continue;
int index;
vals.Maximum1(index);
outmask(x,y,z) = (int)y1(index);
}
save_volume(outmask,"tract_clustering_"+base);
}
}
if(opts.reord2.value()){
cerr<<"Starting Second Reordering"<<endl;
cout<<"Starting Second Reordering"<<endl;
SymmetricMatrix CC;
CC << corrcoef(newOMmat);
CC<<CC+1;
spect_reord(CC,r2,y2);
if(opts.scheme.value()=="spectral")
spect_reord(CC,r2,y2);
else if(opts.scheme.value()=="kmeans")
kmeans_reord(CC,r2,y2,opts.nclusters.value());
else{
cerr << "unkown reordering scheme" << endl;
return(-1);
}
write_ascii_matrix(r2,base+"r2");
write_ascii_matrix(y2,base+"y2");
......@@ -508,7 +654,7 @@ int main ( int argc, char **argv ){
volume<int> outtractcoords(newtractcoordmat.Nrows(),3,1);
cerr<<"Permuting Matrix"<<endl;
cout<<"Permuting Matrix"<<endl;
for(int j=0;j<outvol.ysize();j++){
for(int i=0;i<outvol.xsize();i++){
outvol(i,j,0)=(int)newOMmat((int)r1(i+1),(int)r2(j+1));
......@@ -516,7 +662,7 @@ int main ( int argc, char **argv ){
}
if(tractcoordbool){
cerr<<"Permuting Tract Coordinates"<<endl;
cout<<"Permuting Tract Coordinates"<<endl;
for(int i=0;i<outtractcoords.xsize();i++){
outtractcoords(i,0,0)=(int)newtractcoordmat(int(r2(i+1)),1);
outtractcoords(i,1,0)=(int)newtractcoordmat(int(r2(i+1)),2);
......
......@@ -35,7 +35,7 @@ class ccopsOptions {
Option<float> power;
Option<string> mask;
Option<string> scheme;
Option<int> kmeans;
Option<int> nclusters;
bool parse_command_line(int argc, char** argv);
private:
......@@ -88,9 +88,9 @@ class ccopsOptions {
string("brain mask used to output the clustered roi mask"),
false, requires_argument),
scheme(string("-s,--scheme"), "spectral",
string("Reordering algorithm. Can be either spectral (default) or dpm or kmeans"),
string("Reordering algorithm. Can be either spectral (default) or kmeans"),
false, requires_argument),
kmeans(string("-K"), 2,
nclusters(string("-k,--nclusters"), 2,
string("Number of clusters to be used in kmeans"),
false, requires_argument),
options("ccops","")
......@@ -109,7 +109,7 @@ class ccopsOptions {
options.add(power);
options.add(mask);
options.add(scheme);
options.add(kmeans);
options.add(nclusters);
}
catch(X_OptionError& e) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment