/*  MELODIC - Multivariate exploratory linear optimized decomposition into 
              independent components
    
    melodic.cc - main program file

    Christian F. Beckmann, FMRIB Image Analysis Group
    
    Copyright (C) 1999-2002 University of Oxford */

/*  CCOPYRIGHT  */

#include <cmath>
#include <cstdlib>
#include <sstream>
#include <string>

#include "newmat/newmatap.h"
#include "newmat/newmatio.h"
#include "newimage/newimageall.h"
#include "miscmaths/miscmaths.h"
#include "miscmaths/miscprob.h"
#include "utils/options.h"
#include "utils/log.h"
#include "meloptions.h"
#include "meldata.h"
#include "melpca.h"
#include "melica.h"
#include "melodic.h"
#include "melreport.h"
#include "melgmix.h"

using namespace Utilities;
using namespace NEWMAT;
using namespace NEWIMAGE;
using namespace Melodic;
using namespace MISCPLOT;

/*  Format a float as a string.

      f        : value to format
      width    : minimum field width; ignored when <= 0.  Padding is
                 inserted between the sign and the digits (ios::internal).
      prec     : extra digits; the stream precision is set to
                 int(|log10|f||) + 1 + |prec|
      scientif : use scientific notation

    Zero and non-finite values (where log10 is -inf/NaN and the int
    conversion would be undefined) use a magnitude term of 1.
    Uses std::ostringstream: the deprecated ostrstream::str() froze its
    buffer, which was then never released.
*/
std::string myfloat2str(float f, int width, int prec, bool scientif)
{
  int redw = 1;
  if(f != 0.0f && std::isfinite(f))
    redw = int(std::abs(std::log10(std::abs(f)))) + 1;

  std::ostringstream os;
  if(width > 0)
    os.width(width);
  if(scientif)
    os.setf(std::ios::scientific);
  os.precision(redw + std::abs(prec));
  os.setf(std::ios::internal, std::ios::adjustfield);
  os << f;
  return os.str();
}

Matrix mmall(Log& logger, MelodicOptions& opts,
	     MelodicData& melodat, MelodicReport& report, Matrix& probs);
void repall(Log& logger, MelodicOptions& opts, MelodicData& melodat, 
	    MelodicReport& report, Matrix mmres, Matrix probs);
void mmonly(Log& logger, MelodicOptions& opts,
	    MelodicData& melodat, MelodicReport& report);

/*  MELODIC entry point.

    After parsing the command line (the parser also writes the arguments
    to the log), exactly one mode runs:
      - opts.filtermode set       : set up the data and remove components
                                    from a previous run;
      - opts.filtermix given      : mixture modelling only (mmonly);
      - otherwise (standard PICA) : PCA + whitening, ICA, save results, then
                                    optionally mixture modelling (mmall) and
                                    an HTML report.

    In standard mode a non-converging ICA is retried while
    retry < opts.maxRestart:
      - with the symmetric approach, once retry exceeds
        min(opts.retrystep,3), the approach is switched to deflation;
      - otherwise the PCA dimension (-d) is lowered by retry*opts.retrystep;
      - if lowering would leave it at or below 10% of data_dim(), it is
        raised by that amount instead, as long as the raised value stays
        below data_dim(); otherwise the loop gives up.

    Returns 0 on success and also when ICA finally fails to converge
    (a message is printed).  Returns 1 after reporting a caught Exception
    or X_OptionError.
*/
int main(int argc, char *argv[])
{
  try{
    // Setup logging:
    Log& logger = LogSingleton::getInstance();

    // parse command line - will output arguments to logfile
    MelodicOptions& opts = MelodicOptions::getInstance();
    opts.parse_command_line(argc, argv, logger, Melodic::version);

    // set up data and report objects
    MelodicData melodat(opts,logger);
    MelodicReport report(melodat,opts,logger);

    if (opts.filtermode || opts.filtermix.value().length()>0){
      if(opts.filtermode){ // just filter out some noise from a previous run
        melodat.setup();
        melodat.remove_components();
      }
      else // only mixture-model an existing decomposition
        mmonly(logger,opts,melodat,report);
    }
    else{
      // standard PICA now
      int retry = 0;
      bool no_conv = true;
      // Set instead of using `break`, which does not compile here on all
      // platforms.
      bool leaveloop = false;

      melodat.setup();

      do{
        // PCA pre-processing and whitening
        MelodicPCA pcaobj(melodat,opts,logger,report);
        pcaobj.perf_pca(melodat.get_DataVN());
        pcaobj.perf_white(melodat.get_Data());

        // ICA on the whitened data
        MelodicICA icaobj(melodat,opts,logger,report);
        icaobj.perf_ica(melodat.get_white()*melodat.get_Data());

        no_conv = icaobj.no_convergence;

        opts.maxNumItt.set_T(500);
        if((opts.approach.value()=="symm")&&(retry > std::min(opts.retrystep,3))){
          // symmetric approach keeps failing: fall back to deflation
          if(no_conv){
            retry++;
            opts.approach.set_T("defl");
            message(endl << "Restarting MELODIC using deflation approach"
                    << endl << endl);
          }
          else{
            leaveloop = true;
          }
        }
        else{
          if(no_conv){
            retry++;
            const int step = retry*opts.retrystep;
            if(opts.pca_dim.value()-step > 0.1*melodat.data_dim()){
              // lower the dimensionality
              opts.pca_dim.set_T(opts.pca_dim.value()-step);
            }
            else{
              // Lowering would go to <= 10% of the data dimension: raise
              // the dimensionality instead, provided the *raised* value
              // stays below data_dim().  Testing pca_dim-step here would
              // always succeed (for a positive data_dim) and make giving
              // up unreachable.
              if(opts.pca_dim.value()+step < melodat.data_dim()){
                opts.pca_dim.set_T(opts.pca_dim.value()+step);
              }else{
                leaveloop = true;
              }
            }
            if(!leaveloop){
              message(endl << "Restarting MELODIC using -d "
                      << opts.pca_dim.value()
                      << endl << endl);
            }
          }
        }
      } while (no_conv && retry<opts.maxRestart.value() && !leaveloop);

      if(!no_conv){
        // first save raw IC results
        melodat.save();

        Matrix pmaps;

        if(opts.perf_mm.value())
          mmall(logger,opts,melodat,report,pmaps);
        else if( bool(opts.genreport.value()) ){
          // simple per-IC report pages without mixture modelling
          message(endl
                  << "Creating web report in " << report.getDir()
                  << " " << endl);
          for(int ctr=1; ctr<= melodat.get_IC().Nrows(); ctr++){
            string prefix = "IC_"+num2str(ctr);
            message("  " << ctr);
            report.IC_simplerep(prefix,ctr,melodat.get_IC().Nrows());
          }

          message(endl << endl <<
                  " To view the output report point your web browser at " <<
                  report.getDir() + "/00index.html" << endl<< endl);
        }

        if( bool(opts.genreport.value()) ){
          report.analysistxt();
          report.PPCA_rep();
        }

        message("finished!" << endl << endl);
      } else {
        message(endl <<"No convergence -- giving up " << endl <<
                "please contact fsl@fmrib.ox.ac.uk " << endl);
      }
    }
  }
  catch(Exception& e)
    {
      cerr << endl << e.what() << endl;
      return 1;
    }
  catch(X_OptionError& e)
    {
      cerr << endl << e.what() << endl;
      return 1;
    }

  return 0;
}

/*  Mixture-model-only mode (main() calls this when opts.filtermix is given
    and opts.filtermode is not set).

    Rebuilds melodat from the outputs of a previous run instead of running
    PCA/ICA:
      - opts.inputfname : 4D data; only its volume info (ICvolInfo) and its
                          temporal mean (meanvol) are kept;
      - opts.ICsfname   : 4D IC maps, turned into a matrix under a mask;
      - opts.filtermix  : ASCII mixing matrix; the program exits with
                          status 2 if the loaded matrix has no storage.
    melodat then receives the volume info, mask, mean, ICs, mixing matrix,
    calc_FFT() of the mixing matrix (set_fmix) and its pseudo-inverse
    (set_unmix).  If opts.perf_mm is set, mmall() is run; its returned
    matrix is not used.
*/
void mmonly(Log& logger, MelodicOptions& opts,
	   MelodicData& melodat, MelodicReport& report){

  Matrix ICs;
  Matrix mixMatrix;
  Matrix fmixMatrix;
  volumeinfo ICvolInfo;
  volume<float> Mask;
  volume<float> Mean;
  
  // Scoped so that RawData is destroyed once its header info and mean
  // have been taken.
  {
    volume4D<float> RawData;
    message("Reading data file " << opts.inputfname.value() << "  ... ");
    read_volume4D(RawData,opts.inputfname.value(),ICvolInfo);
    message(" done" << endl);
    Mean = meanvol(RawData);
  }

  // Scoped so that RawIC is destroyed once the IC matrix is extracted.
  {
    volume4D<float> RawIC;
    message("Reading components " << opts.ICsfname.value() << "  ... ");
    read_volume4D(RawIC,opts.ICsfname.value());
    message(" done" << endl);

    message("Creating mask   ... ");
    // Initial mask: first IC volume over its full [min,max] range.
    // NOTE(review): assumes binarise(v,lo,hi) marks voxels with
    // lo <= v <= hi -- confirm against newimage.
    Mask = binarise(RawIC[0],float(RawIC[0].min()),float(RawIC[0].max()));

    ICs = RawIC.matrix(Mask);
    if(ICs.Nrows()>1){
      // Several ICs: build a map of stdev(ICs) (one value per in-mask
      // column, presumably per voxel -- confirm stdev() is column-wise) and
      // keep voxels whose value lies in
      // [max(mean-3*std, 0.01*mean), max of the map].
      Matrix DStDev=stdev(ICs);
      
      volume4D<float> tmpMask;
      tmpMask.setmatrix(DStDev,Mask);
      
      float tMmax;
      volume<float> tmpMask2;
      tmpMask2 = tmpMask[0];
      tMmax = tmpMask2.max();
      double st_mean = DStDev.Sum()/DStDev.Ncols();
      double st_std  = stdev(DStDev.t()).AsScalar();
      
      // NOTE: the inner C-style casts bind only to st_mean / 0.01, so both
      // max() arguments are double expressions; only the outer cast
      // produces the float threshold.
      Mask = binarise(tmpMask2,(float) max((float) st_mean-3*st_std,
					   (float) 0.01*st_mean),tMmax);  
      ICs = RawIC.matrix(Mask);
    }
    else{
      // Single IC: union of voxels >= 0.001 and voxels <= -0.001 in it.
      Mask = binarise(RawIC[0],float(0.001),float(RawIC[0].max())) 
	+ binarise(RawIC[0],float(RawIC[0].min()),float(-0.001));
      ICs = RawIC.matrix(Mask);
    }

    //cerr << "ICs : " << ICs.Ncols() << ICs.Nrows() << endl;
    message(" done" << endl);
  }

  message("Reading mixing matrix " << opts.filtermix.value());
  mixMatrix = read_ascii_matrix(opts.filtermix.value());
  // An empty matrix means the file could not be read: abort the program.
  if (mixMatrix.Storage()<=0) {
    cerr <<" Please specify the mixing matrix correctly" << endl;
    exit(2);
  }
  message(" done" << endl);

  melodat.tempInfo = ICvolInfo;
  melodat.set_mask(Mask);
  melodat.set_mean(Mean);
  melodat.set_IC(ICs);
  melodat.set_mix(mixMatrix);
  fmixMatrix = melodat.calc_FFT(mixMatrix);
  melodat.set_fmix(fmixMatrix);
  // fmixMatrix is reused here to hold the unmixing (pseudo-inverse) matrix.
  fmixMatrix = pinv(mixMatrix);
  melodat.set_unmix(fmixMatrix);

  //  write_ascii_matrix("ICs",ICs);
  
  // mmall()'s result is not used any further in this mode.
  Matrix mmres;
  Matrix pmaps;//(ICs);
  if(opts.perf_mm.value())
    mmres = mmall(logger,opts,melodat,report,pmaps);
}

/*  Placeholder for a combined mixture-model report over all ICs.

    Does nothing unless opts.genreport is set.  Otherwise it prints the
    report directory and then, for each row of probs, prints the row number
    and constructs a MelGMix.  Loading that model and calling
    report.IC_rep() are not implemented, so mmpars and melodat are
    currently unused.
*/
void repall(Log& logger, MelodicOptions& opts, MelodicData& melodat, 
	    MelodicReport& report, Matrix& mmpars, Matrix& probs)
{
  if( !bool(opts.genreport.value()) )
    return;

  message(endl << "Creating report in " << report.getDir() << endl);

  const int numMaps = probs.Nrows();
  for(int icIdx = 1; icIdx <= numMaps; ++icIdx){
    message("  " << icIdx);
    // TODO: load the fitted model into this MelGMix and pass it to
    // report.IC_rep(model, icIdx, melodat.get_IC().Nrows()).
    MelGMix model(opts, logger);
  }
}

/*  Run mixture modelling on every IC map held in melodat.

    For each IC ctr (1-based row of melodat.get_IC()):
      - the map is multiplied element-wise by melodat.get_stdNoisei() when
        that matrix is non-empty (the progress message calls the result a
        Z-transformed map);
      - a MelGMix is set up, fitted with fit("GGM") and thresholded with
        opts.mmthresh;
      - if both the positive and the negative sum of the first row of the
        thresholded maps are below 0.01, the model is re-thresholded with
        "0.05n " prepended to opts.mmthresh, and the sums are recomputed
        from the new thresholded map;
      - if the negative sum exceeds the positive sum, the IC and the model
        are sign-flipped via flipres(ctr);
      - with opts.output_MMstats, the probability map, thresholded map and
        model parameters are written under the "stats" directory;
      - with opts.genreport, a report page is created for the IC.
    Afterwards, in a standard PICA run (no filtermode and no filtermix),
    melodat is saved again with verbose output temporarily disabled.

    pmaps  : currently not written to (the code that filled it is disabled).
    return : a (5*#ICs x 5) matrix of zeros.  The code that would fill it
             with mixture parameters is disabled (kept commented out below).
*/
Matrix mmall(Log& logger, MelodicOptions& opts,
             MelodicData& melodat, MelodicReport& report, Matrix& pmaps)
{
  Matrix mmpars(5*melodat.get_IC().Nrows(),5);
  mmpars = 0;

  Log stats;
  if(opts.output_MMstats.value()){
    stats.makeDir(logger.appendDir("stats"),"stats.log");
  }

  message(endl
          << "Running Mixture Modelling on Z-transformed IC maps ..."
          << endl);

  for(int ctr=1; ctr <= melodat.get_IC().Nrows(); ctr++){
    MelGMix mixmod(opts, logger);

    message("  IC map " << ctr << " ... "<< endl);

    Matrix ICmap;
    if(melodat.get_stdNoisei().Storage()>0)
      ICmap = SP(melodat.get_IC().Row(ctr),melodat.get_stdNoisei());
    else
      ICmap = melodat.get_IC().Row(ctr);

    // model output goes to the report directory when a report is made
    string wherelog;
    if(opts.genreport.value())
      wherelog = report.getDir();
    else
      wherelog = logger.getDir();

    mixmod.setup( ICmap, melodat.tempInfo,
                  wherelog,ctr,
                  melodat.get_mask(),
                  melodat.get_mean(),3);
    mixmod.fit("GGM");

    if(opts.output_MMstats.value()){
      melodat.save4D(mixmod.get_probmap(),
                     string("stats/probmap_")+num2str(ctr));
    }

    message("   thresholding ... "<< endl);
    mixmod.threshold(opts.mmthresh.value());

    // Positive and (negated) negative sums of the thresholded map, used to
    // detect an (almost) empty result and to choose the IC's sign.
    Matrix tmp;
    tmp=(mixmod.get_threshmaps().Row(1));
    float posint = SP(tmp,gt(tmp,zeros(1,tmp.Ncols()))).Sum();
    float negint = -SP(tmp,lt(tmp,zeros(1,tmp.Ncols()))).Sum();

    if((posint<0.01)&&(negint<0.01)){
      // (Almost) nothing survived: retry with an extra "0.05n" threshold.
      mixmod.clear_infstr();
      mixmod.threshold("0.05n "+opts.mmthresh.value());
      // Re-read the new thresholded map; the previous tmp is stale.
      tmp=(mixmod.get_threshmaps().Row(1));
      posint = SP(tmp,gt(tmp,zeros(1,tmp.Ncols()))).Sum();
      negint = -SP(tmp,lt(tmp,zeros(1,tmp.Ncols()))).Sum();
    }
    if(negint>posint){ // flip map so the dominant part is positive
      melodat.flipres(ctr);
      mixmod.flipres(ctr);
    }

    // save mixture model stats
    if(opts.output_MMstats.value()){
      stats << " IC " << num2str(ctr) << " " << mixmod.get_type() << endl
            << " Means :  " << mixmod.get_means() << endl
            << " Vars. :  " << mixmod.get_vars()  << endl
            << " Prop. :  " << mixmod.get_pi()    << endl << endl;
      melodat.save4D(mixmod.get_threshmaps(),
                     string("stats/thresh_zstat")+num2str(ctr));
    }

    // Filling of mmpars is disabled; intended layout per IC (5 rows):
    //   mmpars((ctr-1)*5+1,1) = ctr;
    //   mmpars((ctr-1)*5+1,2) = mixmod.get_means().Ncols();
    //   rows +2/+3/+4 : means / vars / proportions (one column each)
    //   mmpars((ctr-1)*5+5,1) = mixmod.get_offset();

    if( bool(opts.genreport.value()) ){
      message("   creating report page ... ");
      report.IC_rep(mixmod,ctr,melodat.get_IC().Nrows());
      message("   done" << endl);
    }
  }
  if( bool(opts.genreport.value()) ){
    message(endl << endl <<
            " To view the output report point your web browser at " <<
            report.getDir() + "/00index.html" << endl << endl);
  }
  if(!opts.filtermode&&opts.filtermix.value().length()==0){
    // now save new data (ICs may have been flipped above), quietly
    bool what = opts.verbose.value();
    opts.verbose.set_T(false);
    melodat.save();
    opts.verbose.set_T(what);
  }
  return mmpars;
}