Skip to content
Snippets Groups Projects
ols.cc 5.33 KiB
/*  ols.cc

    Mark Woolrich, FMRIB Image Analysis Group

    Copyright (C) 1999-2000 University of Oxford  */

/*  CCOPYRIGHT  */

#include "ols.h"
#include "miscmaths.h"
#include "sigproc.h"
#include "Log.h"

#ifndef NO_NAMESPACE
using namespace MISCMATHS;
namespace SIGPROC {
#endif

  // Ols: ordinary-least-squares fit of the design p_x to every column
  // (timeseries) of p_y, with contrasts taken from p_contrasts.
  //
  // All working storage is sized here. Sizes are derived directly from the
  // constructor arguments rather than from the members numTS/sizeTS: members
  // are initialised in their header declaration order (not the order of this
  // list), so reading numTS/sizeTS inside this list would only be safe if the
  // header happened to declare them first.
  Ols::Ols(const Matrix& p_y, const Matrix& p_x, const Matrix& p_contrasts):
    y(p_y),
    x(p_x),
    contrasts(p_contrasts),
    numTS(p_y.Ncols()),                        // number of timeseries (columns of y)
    sizeTS(p_y.Nrows()),                       // samples per timeseries (rows of y)
    r(p_y.Nrows(), p_y.Ncols()),               // residuals, same shape as y
    pinv_x(p_x.Ncols(), p_y.Nrows()),          // pinv(x), filled by ComputeResids()
    var_on_e(0.0),
    cb(p_y.Ncols()),                           // one contrast estimate per timeseries
    b(p_x.Ncols()),                            // parameter estimates for one timeseries
    var(p_y.Ncols()),                          // one variance per timeseries
    dof(p_y.Nrows() - p_x.Ncols()),            // OLS dof until SetupWithV overrides it
    V(p_y.Nrows(), p_y.Nrows()),
    RV(p_y.Nrows(), p_y.Nrows()),
    RMat(p_y.Nrows(), p_y.Nrows()),
    batch_size(BATCHSIZE)
    {
      // NOTE(review): presumably selects row 1 of `contrasts` into `c`; confirm in SetContrast.
      SetContrast(1);
    }
  
  // ConstructV: fill V as the symmetric Toeplitz matrix whose first row is
  // p_vrow, i.e. V(row, col) = p_vrow(|row - col| + 1).
  // p_vrow must hold at least sizeTS entries.
  void Ols::ConstructV(const ColumnVector& p_vrow)
    {
      Tracer ts("ConstructV");
      V = 0;

      for (int row = 1; row <= sizeTS; ++row)
        {
          // Number of lags from the diagonal out to the last column.
          const int nRight = sizeTS - row + 1;

          // Diagonal and everything to its right: lags 0, 1, ..., sizeTS-row.
          V.SubMatrix(row, row, row, sizeTS) = p_vrow.Rows(1, nRight).t();

          // Everything left of (and including) the diagonal: lags row-1, ..., 1, 0.
          V.SubMatrix(row, row, 1, row) = p_vrow.Rows(1, row).Reverse().t();
        }
    }

  float Ols::SetupWithV(const ColumnVector& p_vrow, bool p_justvarone)
    {
      Tracer ts("SetupWithV");
      
      ConstructV(p_vrow);

      // var/e = c'inv(x'x)x'*V*x*inv(x'x)*c
      var_on_e = (c.t()*pinv_x*V*x*(x.t()*x).i()*c).AsScalar();

      if(!p_justvarone)
	{
	  // dof = 2*trace(RV)^2/trace(R*V*R*V);
	  RV = RMat*V;
	  dof = Trace(RV)*Trace(RV)/Trace(RV*RV);
	}

      return dof;
    }
  float Ols::SetupWithV(const ColumnVector& p_vrow, const ColumnVector& p_kfft, ColumnVector& vrow, bool p_justvarone, const int zeropad)
    {
      Tracer ts("SetupWithV");

      // make sure p_vrow is cyclic (even function)
      //ColumnVector vrow(zeropad);
      vrow.ReSize(zeropad);

      vrow = 0;
      vrow.Rows(1,sizeTS/2) = p_vrow.Rows(1,sizeTS/2);
      vrow.Rows(zeropad - sizeTS/2 + 2, zeropad) = p_vrow.Rows(2, sizeTS/2).Reverse();

      // fft vrow
      ColumnVector fft_real;
      ColumnVector fft_im;
      ColumnVector dummy(zeropad);
      dummy = 0;
      
      ColumnVector realifft(zeropad);

      FFT(vrow, dummy, fft_real, fft_im);

      FFTI(SP(fft_real, p_kfft), dummy, realifft, dummy);
      
      vrow = realifft.Rows(1,sizeTS);

      // Normalise vrow:
      vrow = vrow/vrow(1);

      ConstructV(vrow);     

      // var/e = c'inv(x'x)x'*V*x*inv(x'x)*c
      var_on_e = (c.t()*pinv_x*V*x*(x.t()*x).i()*c).AsScalar();
      
      if(!p_justvarone)
	{
	  // dof = 2*trace(RV)^2/trace(R*V*R*V);
	  RV = RMat*V;
	  dof = Trace(RV)*Trace(RV)/Trace(RV*RV);
	}

      return dof;
    }

  // ComputeResids: compute pinv_x = inv(x'x)x', the residual-forming matrix
  // RMat = I - x*pinv_x, and the residuals r = RMat*y.
  // y is processed in column batches of at most batch_size timeseries.
  // Returns a reference to the member r.
  const Matrix& Ols::ComputeResids()
    {
      Tracer ts("ComputeResids");

      // pinv(x) = inv(x'x)x'
      pinv_x = (x.t()*x).i()*x.t();

      // R = I - x*pinv(x)
      Matrix eye(sizeTS, sizeTS);
      Identity(eye);
      RMat = eye - x*pinv_x;

      for (int first = 1; first <= numTS; first += batch_size)
        {
          // Last column of this batch, clamped to the final timeseries.
          const int last = (first + batch_size - 1 > numTS) ? numTS : first + batch_size - 1;
          r.Columns(first, last) = RMat*y.Columns(first, last);
        }

      return r;
    }

  // Computecb: contrast estimates cb = (c'pinv(x)y)' for every timeseries.
  // y is processed in column batches of at most batch_size timeseries.
  // Requires pinv_x (set in ComputeResids()). Returns a reference to the member cb.
  const ColumnVector& Ols::Computecb()
    { 
      Tracer ts("Computecb");

      for (int first = 1; first <= numTS; first += batch_size)
        {
          // Last column of this batch, clamped to the final timeseries.
          const int last = (first + batch_size - 1 > numTS) ? numTS : first + batch_size - 1;
          cb.Rows(first, last) = (c.t()*pinv_x*y.Columns(first, last)).t();
        }

      return cb;
    }

  // Computecb(ind): contrast estimate c'pinv(x)y for the single timeseries in
  // column ind. Stores it in cb(ind) and returns it.
  // Requires pinv_x (set in ComputeResids()).
  const float Ols::Computecb(const int ind)
    { 
      Tracer ts("Computecb");

      // The product is 1x1, so no transpose is needed before taking the scalar.
      cb(ind) = (c.t()*pinv_x*y.Column(ind)).AsScalar();

      return cb(ind);
    }

  // Computeb(ind): parameter estimates b = pinv(x)*y(:,ind) for the timeseries
  // in column ind. Requires pinv_x (set in ComputeResids()).
  // Returns a reference to the member b.
  const ColumnVector& Ols::Computeb(const int ind)
    { 
      Tracer ts("Computeb");

      const ColumnVector series = y.Column(ind);
      b = pinv_x*series;

      return b;
    }

  // ComputeVar: variance of the contrast estimate for every timeseries,
  //   var(i) = e_i * var_on_e,  with e_i = r_i'r_i / trace(RV)
  // where e_i is the residual-based estimate of sigma^2 for timeseries i.
  // Returns a reference to the member var.
  //
  // Only the diagonal of r'r is needed (one residual sum of squares per
  // column). Each entry is therefore computed directly, using the same
  // expression as ComputeVar(ind). The previous version formed a
  // batch_size x batch_size cross-product per batch (O(batch^2 * n) work plus
  // a batch_size^2 buffer) and then discarded every off-diagonal element.
  //
  // NOTE(review): in this file RV is only assigned by
  // SetupWithV(..., p_justvarone=false); confirm callers do that first.
  const ColumnVector& Ols::ComputeVar()
    { 
      Tracer ts("ComputeVar");

      // Loop-invariant: evaluate the trace once rather than once per batch.
      const double trace_rv = Trace(RV);

      for (int i = 1; i <= numTS; i++)
        var(i) = ((r.Column(i).t()*r.Column(i)/trace_rv)*var_on_e).AsScalar();

      return var;
    }

  const float Ols::ComputeVar(const int ind)
    { 
      Tracer ts("ComputeVar");
   
      // var = e*var_on_e
      // e is the estimate of the variance of the timeseries, sigma^2
      var(ind) = ((r.Column(ind).t()*r.Column(ind)/Trace(RV))*var_on_e).AsScalar();
      // var(ind) = (r.Column(ind).t()*r.Column(ind)*var_on_e).AsScalar();

      return var(ind);
    }

#ifndef NO_NAMESPACE
}
#endif