/* Stephen Smith authored; ols.cc 5.33 KiB */
/* ols.cc
Mark Woolrich, FMRIB Image Analysis Group
Copyright (C) 1999-2000 University of Oxford */
/* CCOPYRIGHT */
#include "ols.h"
#include "miscmaths.h"
#include "sigproc.h"
#include "Log.h"
#ifndef NO_NAMESPACE
using namespace MISCMATHS;
namespace SIGPROC {
#endif
// Ordinary-least-squares fit of the design p_x to every column (time series)
// of p_y, with the contrasts in p_contrasts.
//
// Parameters:
//   p_y         - data, one time series per column (sizeTS rows x numTS cols)
//   p_x         - design matrix, sizeTS rows x (number of regressors) cols
//   p_contrasts - contrast definitions; SetContrast(1) (defined elsewhere)
//                 presumably selects the first one into c -- TODO confirm
//
// All work matrices are sized here from the arguments themselves rather than
// from the members sizeTS/numTS. C++ initializes members in their header
// declaration order, not in the order written below. Reading sizeTS/numTS
// here would therefore be undefined if they were declared after r, V, etc.
// dof starts as the plain OLS value T - p and is later overwritten by
// SetupWithV when an autocorrelation model is supplied.
Ols::Ols(const Matrix& p_y, const Matrix& p_x, const Matrix& p_contrasts):
  y(p_y),
  x(p_x),
  contrasts(p_contrasts),
  numTS(p_y.Ncols()),
  sizeTS(p_y.Nrows()),
  r(p_y.Nrows(), p_y.Ncols()),
  pinv_x(p_x.Ncols(), p_y.Nrows()),
  var_on_e(0.0),
  cb(p_y.Ncols()),
  b(p_x.Ncols()),
  var(p_y.Ncols()),
  dof(p_y.Nrows() - p_x.Ncols()),
  V(p_y.Nrows(), p_y.Nrows()),
  RV(p_y.Nrows(), p_y.Nrows()),
  RMat(p_y.Nrows(), p_y.Nrows()),
  batch_size(BATCHSIZE)
{
  SetContrast(1);
}
// Fill the member V (sizeTS x sizeTS) as the symmetric Toeplitz matrix
//   V(i,j) = p_vrow(|i-j| + 1),
// i.e. p_vrow(1) on the diagonal, p_vrow(2) on the first off-diagonals, etc.
// p_vrow must hold at least sizeTS entries.
void Ols::ConstructV(const ColumnVector& p_vrow)
{
  Tracer ts("ConstructV");
  V = 0;
  for (int row = 1; row <= sizeTS; ++row)
  {
    // Columns row..sizeTS: p_vrow(1), p_vrow(2), ..., p_vrow(sizeTS-row+1)
    const int right_len = sizeTS - row + 1;
    V.SubMatrix(row, row, row, sizeTS) = p_vrow.Rows(1, right_len).t();
    // Columns 1..row: p_vrow(row), ..., p_vrow(2), p_vrow(1)
    // (the diagonal entry is rewritten with the same value p_vrow(1))
    V.SubMatrix(row, row, 1, row) = p_vrow.Rows(1, row).Reverse().t();
  }
}
// Install the autocorrelation model given by p_vrow (lag 0 first) and update
// the quantities that depend on it.
//
// Builds V via ConstructV, then sets
//   var_on_e = c' inv(x'x) x' V x inv(x'x) c
// i.e. the variance of the contrast estimate per unit noise variance.
// Unless p_justvarone is true, it also sets RV = RMat*V and the effective
// degrees of freedom
//   dof = trace(RV)^2 / trace(RV*RV)
// (Satterthwaite: nu = 2 E[Q]^2 / Var[Q], with E[Q] ~ trace(RV) and
// Var[Q] ~ 2 trace(RVRV), so the factor 2 cancels).
// When p_justvarone is true, RV and dof keep their previous values.
//
// Reads pinv_x and RMat, which within this file are only filled by
// ComputeResids -- presumably that must be called first.
// Returns the (possibly unchanged) dof.
float Ols::SetupWithV(const ColumnVector& p_vrow, bool p_justvarone)
{
  Tracer ts("SetupWithV");
  ConstructV(p_vrow);
  var_on_e = (c.t()*pinv_x*V*x*(x.t()*x).i()*c).AsScalar();
  if(!p_justvarone)
  {
    RV = RMat*V;
    const double trace_rv = Trace(RV);
    dof = trace_rv*trace_rv/Trace(RV*RV);
  }
  return dof;
}
float Ols::SetupWithV(const ColumnVector& p_vrow, const ColumnVector& p_kfft, ColumnVector& vrow, bool p_justvarone, const int zeropad)
{
Tracer ts("SetupWithV");
// make sure p_vrow is cyclic (even function)
//ColumnVector vrow(zeropad);
vrow.ReSize(zeropad);
vrow = 0;
vrow.Rows(1,sizeTS/2) = p_vrow.Rows(1,sizeTS/2);
vrow.Rows(zeropad - sizeTS/2 + 2, zeropad) = p_vrow.Rows(2, sizeTS/2).Reverse();
// fft vrow
ColumnVector fft_real;
ColumnVector fft_im;
ColumnVector dummy(zeropad);
dummy = 0;
ColumnVector realifft(zeropad);
FFT(vrow, dummy, fft_real, fft_im);
FFTI(SP(fft_real, p_kfft), dummy, realifft, dummy);
vrow = realifft.Rows(1,sizeTS);
// Normalise vrow:
vrow = vrow/vrow(1);
ConstructV(vrow);
// var/e = c'inv(x'x)x'*V*x*inv(x'x)*c
var_on_e = (c.t()*pinv_x*V*x*(x.t()*x).i()*c).AsScalar();
if(!p_justvarone)
{
// dof = 2*trace(RV)^2/trace(R*V*R*V);
RV = RMat*V;
dof = Trace(RV)*Trace(RV)/Trace(RV*RV);
}
return dof;
}
// Fit the design to all time series and return the residuals.
//
// Sets pinv_x = inv(x'x) x' and the residual-forming matrix
// RMat = I - x*pinv_x, then r = RMat*y. y is processed in blocks of at most
// batch_size columns; the last block is shortened to end at column numTS.
// Within this file, pinv_x and RMat are only filled here, yet SetupWithV,
// Computecb and Computeb read them -- presumably call this first.
// Returns a reference to the member r.
const Matrix& Ols::ComputeResids()
{
  Tracer ts("ComputeResids");

  pinv_x = (x.t()*x).i()*x.t();

  Matrix I(sizeTS, sizeTS);
  Identity(I);
  RMat = I - x*pinv_x;

  for (int first = 1; first <= numTS; first += batch_size)
  {
    const int last = (first + batch_size - 1 > numTS) ? numTS : first + batch_size - 1;
    r.Columns(first, last) = RMat*y.Columns(first, last);
  }
  return r;
}
// Contrast of parameter estimates, cb_i = c' pinv(x) y_i, for every time
// series i. y is processed in blocks of at most batch_size columns; the last
// block is shortened to end at column numTS. Reads pinv_x, which within this
// file is filled by ComputeResids.
// Returns a reference to the member cb (numTS entries).
const ColumnVector& Ols::Computecb()
{
  Tracer ts("Computecb");
  for (int first = 1; first <= numTS; first += batch_size)
  {
    const int last = (first + batch_size - 1 > numTS) ? numTS : first + batch_size - 1;
    // (1 x p)(p x T)(T x n) gives a row of contrasts; store it as a column.
    cb.Rows(first, last) = (c.t()*pinv_x*y.Columns(first, last)).t();
  }
  return cb;
}
const float Ols::Computecb(const int ind)
{
Tracer ts("Computecb");
cb(ind) = ((c.t()*pinv_x*y.Column(ind)).t()).AsScalar();
return cb(ind);
}
// Least-squares parameter estimates for the single time series `ind`
// (1-based column of y): b = pinv(x) y_ind.
// Overwrites the member b and returns a reference to it.
const ColumnVector& Ols::Computeb(const int ind)
{
  Tracer ts("Computeb");
  ColumnVector series(y.Column(ind));
  b = pinv_x*series;
  return b;
}
// Variance of the contrast estimate for every time series.
//
// For series i: var(i) = e_i * var_on_e, where e_i = r_i' r_i / trace(RV)
// estimates the noise variance sigma^2 from the residuals r_i.
// Only the per-series residual sum of squares r_i' r_i is needed. The
// previous batched version formed the full r_b' r_b cross-product of each
// batch (O(batch_size^2 * sizeTS) per batch) and then discarded every
// off-diagonal element. Here each diagonal entry is computed directly, in
// O(sizeTS) per series, exactly as ComputeVar(int) does.
//
// Reads r (from ComputeResids), RV (from SetupWithV with
// p_justvarone == false) and var_on_e.
// Returns a reference to the member var (numTS entries).
const ColumnVector& Ols::ComputeVar()
{
  Tracer ts("ComputeVar");
  // Loop-invariant: trace of the current RV.
  const double trace_rv = Trace(RV);
  for (int i = 1; i <= numTS; i++)
  {
    var(i) = ((r.Column(i).t()*r.Column(i)/trace_rv)*var_on_e).AsScalar();
  }
  return var;
}
const float Ols::ComputeVar(const int ind)
{
Tracer ts("ComputeVar");
// var = e*var_on_e
// e is the estimate of the variance of the timeseries, sigma^2
var(ind) = ((r.Column(ind).t()*r.Column(ind)/Trace(RV))*var_on_e).AsScalar();
// var(ind) = (r.Column(ind).t()*r.Column(ind)*var_on_e).AsScalar();
return var(ind);
}
#ifndef NO_NAMESPACE
}
#endif