Skip to content
Snippets Groups Projects
fsldecorr4d.cc 9.01 KiB
/*  fsldecorr4d.cc

    Mark Jenkinson, FMRIB Image Analysis Group

    Copyright (C) 2007 University of Oxford  */

/*  CCOPYRIGHT  */


// Decorrelates a set of separable 4D components from
//  a 4D dataset

#include "newimage/newimageall.h"
#include "miscmaths/miscmaths.h"
#include "utils/options.h"

using namespace NEWIMAGE;
using namespace MISCMATHS;

using namespace Utilities;

// Title and example-usage text handed to the OptionParser in main(); both are
//  printed as part of the help / usage message.

std::string title="fsldecorr4d (Version 1.0)\nCopyright(c) 2004, University of Oxford (Mark Jenkinson)\nRemoves 4D components (by decorrelation) from a 4D dataset.\nEach component needs to be specified by a separate timecourse and spatial map.\nThe spatial maps are input as a single 4D file and the timecourses as a text matrix (with each column being a timecourse with the same ordering as the corresponding spatial maps).\n";
std::string examples="fsldecorr4d -s <spatial maps> -t <timecourses> -i <input> -o <output> [-m mask]";

// Each global Option object below specifies one command-line option and can
//  be read anywhere in this file (since they are global).  The constructor
//  arguments are: name(s) of the option, default value, help message,
//  whether it is compulsory, whether it requires an argument.
// Note that each option must also be registered with the OptionParser in
//  main() (options.add(...)) or it will not be active.
// At least one of -s / -t must be given; do_work() rejects the run otherwise.

// -v/--verbose: print progress and intermediate values in the decorr*D routines
Option<bool> verbose(string("-v,--verbose"), false, 
		     string("switch on diagnostic messages"), 
		     false, no_argument);
// -h/--help: main() prints the usage message and exits when this is set
Option<bool> help(string("-h,--help"), false,
		  string("display this message"),
		  false, no_argument);
// -s: 4D file of spatial maps, one component per volume (optional)
Option<string> smapname(string("-s"), string(""),
		  string("~<filename>\tinput set of spatial maps (4D)"),
		  false, requires_argument);
// -t: text matrix of timecourses, one component per column (optional)
Option<string> tcname(string("-t"), string(""),
		  string("~<filename>\tinput set of timecourses (text matrix)"),
		  false, requires_argument);
// -m: 3D mask (optional; do_work() uses every voxel when it is absent)
Option<string> maskname(string("-m"), string(""),
		  string("~<filename>\tinput 3D mask"),
		  false, requires_argument);
// -o: output 4D dataset (compulsory)
Option<string> outname(string("-o"), string(""),
		  string("~<filename>\toutput 4D dataset"),
		  true, requires_argument);
// -i: input 4D dataset (compulsory)
Option<string> inname(string("-i"), string(""),
		  string("~<filename>\tinput 4D dataset"),
		  true, requires_argument);

// Value returned by options.parse_command_line() in main(); it is not read
//  anywhere else in this file.
int nonoptarg;


// Removes the timecourses in tc from every masked voxel of vol by voxelwise
// least-squares regression (no spatial maps are involved).
//   vol  : 4D data, modified in place; the timeseries of each voxel with
//          mask > 0.5 is replaced by its residual after fitting tc
//   tc   : nt x nc matrix, one timecourse per column; nt must equal the
//          number of volumes in vol (do_work checks this before calling)
//   mask : 3D mask; voxels with value <= 0.5 are left untouched
// Returns 0.
int decorr1D(volume4D<float>& vol, const Matrix& tc, const volume<float>& mask)
{
  int nt=tc.Nrows();
  int nc=tc.Ncols();
  Matrix XX(nc,nc);
  XX = tc.t() * tc;
  // The least-squares operator pinv(X'X) X' depends only on tc, so build it
  // once here rather than recomputing the pseudo-inverse for every voxel.
  // (NEWMAT evaluates pinv(XX) * tc.t() * Y left to right, so this gives the
  // same arithmetic as the per-voxel expression it replaces.)
  Matrix betaOp;
  betaOp = pinv(XX) * tc.t();
  ColumnVector Y(nt), Beta(nc);
  if (verbose.value()) { cout << "Processing slices "; }
  for (int z=vol[0].minz(); z<=vol[0].maxz(); z++) {
    for (int y=vol[0].miny(); y<=vol[0].maxy(); y++) {
      for (int x=vol[0].minx(); x<=vol[0].maxx(); x++) {
	if (mask(x,y,z)>0.5) {
	  // copy the voxel timeseries (NEWMAT vectors are 1-based)
	  for (int t=0; t<nt; t++) { Y(t+1) = vol(x,y,z,t); }
	  Beta = betaOp * Y;
	  Y -= tc * Beta;   // residual after removing the fitted timecourses
	  for (int t=0; t<nt; t++) { vol(x,y,z,t) = Y(t+1); }
	}
      }
    }
    if (verbose.value()) { cout << "."; }
  }
  if (verbose.value()) { cout << endl; }
  return 0;
}

// Removes nc separable components (spatial map times timecourse) from vol
// with a single joint least-squares fit over all voxels and timepoints.
//   vol   : 4D data, modified in place (fitted components are subtracted)
//   tc    : nt x nc matrix; column k is the timecourse of component k
//   smaps : 4D volume; volume k-1 (0-based) is the spatial map of component k
//   mask  : not referenced here (do_work masks smaps before calling)
// Returns 0.
int decorr4D(volume4D<float>& vol, const Matrix& tc, 
	     const volume4D<float>& smaps, const volume<float>& mask)
{
  const int nt = tc.Nrows();
  const int nc = tc.Ncols();

  // Normal equations of the joint fit: XY(k) = <vol, regressor k>,
  // XX(j,k) = <regressor j, regressor k>, all summed over space and time.
  Matrix XX(nc,nc);
  ColumnVector XY(nc);
  XX = 0.0;
  XY = 0.0;

  if (verbose.value()) { cout << "Calculating matrix values:" << endl; }
  for (int comp=1; comp<=nc; comp++) {
    if (verbose.value()) { cout << "Component #" << comp << endl; }
    const volume<float>& smap = smaps[comp-1];

    // correlation of the data with this component's 4D regressor
    for (int t=0; t<nt; t++) {
      for (int z=smap.minz(); z<=smap.maxz(); z++) {
	for (int y=smap.miny(); y<=smap.maxy(); y++) {
	  for (int x=smap.minx(); x<=smap.maxx(); x++) {
	    XY(comp) += smaps(x,y,z,comp-1) * tc(t+1,comp) * vol(x,y,z,t);
	  }
	}
      }
    }

    // Regressors are separable, so their 4D inner product factorises into
    // (spatial dot product) * (temporal dot product).
    for (int other=1; other<=comp; other++) {
      volume<float> prod;
      prod = smaps[comp-1] * smaps[other-1];
      double tdot=0.0;
      for (int t=0; t<nt; t++) {
	tdot += tc(t+1,comp) * tc(t+1,other);
      }
      XX(other,comp) = prod.sum();
      XX(other,comp) *= tdot;
      XX(comp,other) = XX(other,comp);   // symmetric
    }
  }

  if (verbose.value()) {
    cout << "XY = " << XY.t()/nt << endl;
    cout << "XX = " << (XX/nt)/nt << endl;
  }

  // amplitude of each component
  if (verbose.value()) { cout << "Finding amplitudes" << endl; }
  ColumnVector Beta(nc);
  Beta = pinv(XX)*XY;
  if (verbose.value()) { cout << "Amplitudes = " << Beta.t() << endl << endl; }

  // subtract every fitted component from the data
  if (verbose.value()) { cout << "Removing components" << endl; }
  for (int comp=1; comp<=nc; comp++) {
    const volume<float>& smap = smaps[comp-1];
    for (int t=0; t<nt; t++) {
      for (int z=smap.minz(); z<=smap.maxz(); z++) {
	for (int y=smap.miny(); y<=smap.maxy(); y++) {
	  for (int x=smap.minx(); x<=smap.maxx(); x++) {
	    vol(x,y,z,t) -= Beta(comp) * smaps(x,y,z,comp-1) * tc(t+1,comp);
	  }
	}
      }
    }
  }

  return 0;
}


// Removes the spatial maps in smaps from each volume of vol independently:
// for every timepoint the 3D volume is regressed on all maps at once and the
// fitted combination is subtracted.
//   vol   : 4D data, modified in place
//   smaps : 4D volume holding one spatial map per volume
//   mask  : not referenced here (do_work masks smaps before calling)
// Returns 0.
int decorr3D(volume4D<float>& vol, const volume4D<float>& smaps, 
	     const volume<float>& mask)
{
  const int nt = vol.tsize();
  const int nc = smaps.tsize();

  // Gram matrix of the spatial maps: XX(i,j) = sum over voxels of map_i*map_j.
  // It is the same for every timepoint, so its pseudo-inverse is built once.
  Matrix XX(nc,nc);
  XX = 0.0;
  for (int i=1; i<=nc; i++) {
    for (int j=1; j<=i; j++) {
      volume<float> prod;
      prod = smaps[i-1] * smaps[j-1];
      XX(j,i) = prod.sum();
      XX(i,j) = XX(j,i);   // symmetric
    }
  }
  if (verbose.value()) { cout << "XX = " << (XX/nt)/nt << endl; }
  Matrix pinvXX;
  pinvXX = pinv(XX);

  if (verbose.value()) { cout << "Calculating matrix values:" << endl; }
  ColumnVector XY(nc), Beta(nc);
  for (int t=0; t<nt; t++) {
    if (verbose.value()) { cout << "."; }

    // inner product of this timepoint's volume with every map
    XY = 0.0;
    for (int i=1; i<=nc; i++) {
      const volume<float>& smap = smaps[i-1];
      for (int z=smap.minz(); z<=smap.maxz(); z++) {
	for (int y=smap.miny(); y<=smap.maxy(); y++) {
	  for (int x=smap.minx(); x<=smap.maxx(); x++) {
	    XY(i) += smaps(x,y,z,i-1) * vol(x,y,z,t);
	  }
	}
      }
    }

    // fitted amplitudes, then subtract the fitted maps from this volume
    Beta = pinvXX*XY;
    for (int i=1; i<=nc; i++) {
      const volume<float>& smap = smaps[i-1];
      for (int z=smap.minz(); z<=smap.maxz(); z++) {
	for (int y=smap.miny(); y<=smap.maxy(); y++) {
	  for (int x=smap.minx(); x<=smap.maxx(); x++) {
	    vol(x,y,z,t) -= Beta(i) * smaps(x,y,z,i-1);
	  }
	}
      }
    }
  }
  if (verbose.value()) { cout << endl; }

  return 0;
}



int do_work(int argc, char *argv[])
{
  volume4D<float> vin;
  read_volume4D(vin,inname.value());

  // ** MASK ** //

  volume<float> mask;
  if (maskname.set()) {
    read_volume(mask,maskname.value());
  } else {
    mask = vin[0];
    mask = 1.0;
  }

  if (!samesize(vin[0],mask)) {
    cerr << "ERROR: Mask and Input volumes have different (x,y,z) size." 
	 << endl;
    return 2;
  }

  mask.binarise(1e-8);  // arbitrary "0" threshold


  // ** TIME COURSES ** //

  Matrix tc;
  if (tcname.set()) {
    tc = read_ascii_matrix(tcname.value());
    if (tc.Nrows() != vin.tsize()) {
      cerr << "ERROR: Different number of timepoints in timecourse file and input volume." << endl;
      return 3;
    }
    // demean timecourses (to stop any correlation with the mean)
    tc = remmean(tc);
  }

  // ** SPATIAL MAPS ** //

  // if no smaps specified then go for simple 1D decorrelation
  volume4D<float> smaps;
  if (smapname.set()) { 
    // if smaps are specified then do the full 4D decorrelation
    read_volume4D(smaps,smapname.value());
    if (!samesize(vin[0],smaps[0])) {
      cerr << "ERROR: Spatial maps and Input volumes have different (x,y,z) size." 
	   << endl;
      return 1;
    }
    if (tcname.set() && (tc.Ncols() != smaps.tsize())) {
      cerr << "ERROR: Different number of components for timecourses and spatial maps." << endl;
      return 4;
    }
    
    // mask spatial maps
    smaps *= mask;
    // demean spatial maps (wrt mask)
    for (int n=0; n<smaps.tsize(); n++) {
      smaps[n] -= ((float) smaps[n].mean())*mask;
    }
  }


  // ** DECORRELATION ** //

  int retval=0;
  if (smapname.set() && tcname.set()) {
    retval = decorr4D(vin,tc,smaps,mask);
  }
  if (!smapname.set() && tcname.set()) {
    retval = decorr1D(vin,tc,mask);
  }
  if (smapname.set() && !tcname.set()) {
    retval = decorr3D(vin,smaps,mask);
  }
  if (!smapname.set() && !tcname.set()) {
    cerr << "ERROR: Must pass in either spatial maps or timecourses or both"
	 << endl;
    return 3;
  }

  // save the result
  if (retval==0) save_volume4D(vin,outname.value());
  
  return retval;
}


int main(int argc,char *argv[])
{

  Tracer tr("main");
  OptionParser options(title, examples);

  options.add(inname);
  options.add(outname);
  options.add(smapname);
  options.add(tcname);
  options.add(maskname);
  options.add(verbose);
  options.add(help);
  
  nonoptarg = options.parse_command_line(argc, argv);
  
  // line below stops the program if the help was requested or 
  //  a compulsory option was not set
  if ( (help.value()) || (!options.check_compulsory_arguments(true)) )
    {
      options.usage();
      exit(EXIT_FAILURE);
    }
  

  // OK, now do the job ...
  return do_work(argc,argv);
}