/*  film_ols.cc

    Mark Woolrich, FMRIB Image Analysis Group

    Copyright (C) 1999-2000 University of Oxford  */

/*  CCOPYRIGHT  */

#include <iostream>
#include <fstream>
#include <sstream>
#define WANT_STREAM
#define WANT_MATH

#include "newmatap.h"
#include "newmatio.h"
#include "miscmaths/volumeseries.h"
#include "miscmaths/volume.h"
#include "glim.h"
#include "miscmaths/miscmaths.h"
#include "gaussComparer.h"
#include "utils/log.h"
#include "AutoCorrEstimator.h"
#include "paradigm.h"
#include "FilmOlsOptions.h"
#include <string>

using namespace NEWMAT;
using namespace FILM;
using namespace Utilities;

// Entry point for film_ols.
//
// Fits an ordinary-least-squares GLM to every time series of the input
// VolumeSeries, then applies autocorrelation corrections (per series or
// global) before saving results through Glim::Save().
//
// Pipeline, driven by FilmOlsOptions:
//   1. Parse the command line twice: once to learn the log directory,
//      then again so the arguments are echoed into the log stream.
//   2. Read the input series. If spatial smoothing of the autocorrelation
//      estimates is requested, write volume 12 out as an integer EPI image.
//      Then threshold the series, which also removes each series' mean.
//   3. Load the design matrix from the paradigm file; optionally detrend.
//   4. Compute OLS residuals and parameter estimates.
//   5. Estimate autocorrelations in one of three ways:
//        - fit an autoregressive model;
//        - compute raw estimates, optionally smooth them, then run PAVA;
//        - when globalopts.noest is set, skip estimation and use a mean
//          estimate that is 1 at the first lag and 0 elsewhere.
//   6. Compute sigma-squared for every series, using either the
//      per-series estimates or the global (mean) estimate, and save.
//
// Returns 0 on success, or 1 if an exception was caught (its message is
// printed to cerr).
int main(int argc, char *argv[])
{
  try{
    // Return value is discarded; kept from the original code.
    // NOTE(review): purpose unclear, possibly a leftover - confirm.
    rand();

    // Parse the command line to find out the directory name for logging.
    // out2 is never opened, so nothing is echoed on this first pass.
    ofstream out2;
    FilmOlsOptions& globalopts = FilmOlsOptions::getInstance();
    globalopts.parse_command_line(argc, argv, out2);

    // Set up logging:
    Log& logger = Log::getInstance();
    logger.makeDir(globalopts.datadir);

    // Parse the command line again, this time echoing arguments to the log.
    globalopts.parse_command_line(argc, argv, logger.str());

    // Load the (non-temporally filtered) data.
    VolumeSeries x;
    x.read(globalopts.inputfname);

    // If needed, write out volume 12 for later use by spatial smoothing.
    // NOTE(review): assumes the series has at least 12 volumes - confirm.
    Volume epivol;
    if(globalopts.smoothACEst)
      {
	epivol = x.getVolume(12).AsColumn();
	epivol.setDims(x.getDims());

	epivol.writeAsInt(logger.getDir() + "/" + globalopts.epifname);
      }

    // This also removes the mean from each of the time series:
    x.thresholdSeries(globalopts.thresh, true);

    // If needed later, threshold the EPI volume with the same voxel mask.
    if(globalopts.smoothACEst)
      {
	epivol.setPreThresholdPositions(x.getPreThresholdPositions());
	epivol.threshold();
      }

    int sizeTS = x.getNumVolumes();
    int numTS = x.getNumSeries();

    // Load the paradigm (design matrix):
    Paradigm parad;
    parad.load(globalopts.paradigmfname, "", "",false, sizeTS);

    // Sort out detrending:
    if(globalopts.detrend)
      {
	// Do detrending separately as a preprocessing step:
	MISCMATHS::detrend(x, false);
      }

    if(globalopts.verbose)
      {
	logger.out("Gc", parad.getDesignMatrix());
      }

    // Set up the OLS GLM for the (non-temporally filtered) data:
    Glim glim(x, parad.getDesignMatrix());

    cerr << "Computing parameter estimates... ";
    const VolumeSeries& res = glim.ComputeResids();
    glim.ComputePes();
    // Presumably lets newmat free x's storage now that the fit is done.
    // NOTE(review): relies on Glim not using x afterwards - confirm.
    x.Release();
    cerr << "Completed" << endl;

    if(globalopts.verbose)
      {
	logger.out("res", res);
      }

    ColumnVector meanACEstimate(sizeTS);
    AutoCorrEstimator acEst(res);

    if(!globalopts.noest)
      {
	// Estimate autocorrelations:
	if(globalopts.fitAutoRegressiveModel)
	  {
	    acEst.fitAutoRegressiveModel();
	    if(globalopts.verbose)
	      {
		// Raw estimates come from a separate estimator so that acEst
		// keeps the AR-model fit.
		AutoCorrEstimator acEstForLogging(res);
		acEstForLogging.calcRaw();
		logger.out("rawac", acEstForLogging.getEstimates());
	      }
	    // Always logged. Verbose mode used to write this file twice.
	    logger.out("autoregac", acEst.getEstimates());
	  }
	else
	  {
	    acEst.calcRaw();

	    if(globalopts.verbose)
	      {
		logger.out("rawac", acEst.getEstimates());
	      }

	    // Smooth raw estimates:
	    if(globalopts.smoothACEst)
	      {
		acEst.spatiallySmooth(logger.getDir() + "/" + globalopts.epifname, epivol, globalopts.ms, globalopts.epifname, 0);
	      }

	    // Apply constraints to the autocorrelation estimates:
	    acEst.pava();

	    if(globalopts.verbose)
	      {
		logger.out("threshac", acEst.getEstimates());
	      }
	  }
	// Get the mean estimate across all series:
	acEst.getMeanEstimate(meanACEstimate);

      }
    else // no estimation of autocorrelations
      {
	// Estimates are only resized, not filled in.
	// NOTE(review): presumably so later consumers see the expected
	// shape - confirm.
	acEst.getEstimates().ReSize(sizeTS, numTS);
	// 1 at the first lag, 0 at every other lag.
	meanACEstimate = 0;
	meanACEstimate(1) = 1;
      }

    // Set the global Vrow:
    glim.SetGlobalVrow(meanACEstimate);

    if(globalopts.verbose)
      {
	logger.out("meanACEstimate", meanACEstimate);
      }

    if(!globalopts.globalEst && !globalopts.noest)
      {
	int co = 1;

	// Loop through voxels calculating per-series corrections:
	cerr << "Calculating auto correlation corrections for "  << numTS << " time series..." << endl;

	for(int i = 1; i <= numTS; i++)
	  {
	    // Put this series' autocorrelation estimate into Glim:
	    glim.SetVrow(acEst.getEstimates().getSeries(i),i);
	    glim.ComputeSigmaSquared(i);

	    // Log progress roughly every 100 series:
	    if(co > 100)
	      {
		cerr << i << ",";
		co = 1;
	      }
	    else
	      co++;
	  }
	cerr << " Completed" << endl;
      }
    else
      {
	// Use the single global (mean) estimate for every series.
	logger.out("globalvrow", meanACEstimate);
	glim.UseGlobalVrow();
	for(int i = 1; i <= numTS; i++)
	  {
	    glim.ComputeSigmaSquared(i);
	  }
      }

    // Write out the necessary data:
    cerr << "Saving results... ";
    glim.Save();
    cerr << "Completed" << endl;

  }
  catch(Exception& p_excp)
    {
      // Catch by reference to avoid copying/slicing the exception.
      cerr << p_excp.what() << endl;
      // Report failure to the caller instead of exiting with success.
      return 1;
    }
  catch(...)
    {
      cerr << "Image error" << endl;
      return 1;
    }
  return 0;
}