Skip to content
Snippets Groups Projects
Commit 6d36f028 authored by Christian Beckmann's avatar Christian Beckmann
Browse files

changed dimest restart heuristics

parent 8e697807
No related branches found
No related tags found
No related merge requests found
......@@ -228,6 +228,7 @@ namespace Melodic{
Smodes.clear();
bool tmpvarnorm = opts.varnorm.value();
// Switch off variance normalisation
opts.varnorm.set_T(false);
Log drO;
......@@ -235,34 +236,34 @@ namespace Melodic{
if(opts.dr_out.value())
drO.makeDir(logger.appendDir("dr"),"dr.log");
Matrix tmpData, alltcs, tmp, pinvdes;
Matrix tmpcont = diag(ones(IC.Nrows(),1)), s1,s2, tmpData, alltcs;
basicGLM tmpglm;
for(int ctr = 0; ctr < numfiles; ctr++){
tmpData = process_file(opts.inputfname.value().at(ctr), numfiles);
outMsize("IC",IC);
//may want to remove the spatial means first
tmpglm.olsfit(remmean(tmpData.t(),1),remmean(IC.t(),1),tmpcont);
s1=tmpglm.get_beta();
tmp = IC*IC.t(); outMsize("tmp",tmp);
pinvdes = tmp.i()*IC;
tmp = tmpData*pinvdes.t(); outMsize(string("tmp ")+opts.inputfname.value().at(ctr),tmp);
if(alltcs.Storage()==0)
alltcs=tmp;
alltcs=s1;
else
alltcs&=tmp;
alltcs&=s1;
// output DR
if(opts.dr_out.value()){
write_ascii_matrix(drO.appendDir("dr_stage1_subject"+num2str(ctr)+".txt"),tmp);
pinvdes = tmp.t()*tmp;
pinvdes = pinvdes.i();
tmp = pinvdes*tmp.t()*tmpData;
save4D(tmp,string("dr/dr_stage2_subject"+num2str(ctr)));
write_ascii_matrix(drO.appendDir("dr_stage1_subject"+num2str(ctr)+".txt"),s1);
//des_norm
s1 = SP(s1,ones(s1.Nrows(),1)*pow(stdev(s1,1),-1));
tmpglm.olsfit(remmean(tmpData),remmean(s1,1),tmpcont);
s2=tmpglm.get_z();
save4D(s2,string("dr/dr_stage2_subject"+num2str(ctr)));
}
}
for(int ctr = 1; ctr <= alltcs.Ncols(); ctr++){
tmp << alltcs.Column(ctr);
add_Tmodes(tmp);
tmpcont << alltcs.Column(ctr);
add_Tmodes(tmpcont);
}
opts.varnorm.set_T(tmpvarnorm);
......
......@@ -83,6 +83,9 @@ int main(int argc, char *argv[]){
melodat.setup();
if (opts.maxRestart.value()<0)
opts.maxRestart.set_T(melodat.data_dim());
do{
//do PCA pre-processing
MelodicPCA pcaobj(melodat,opts,logger,report);
......@@ -95,40 +98,27 @@ int main(int argc, char *argv[]){
no_conv = icaobj.no_convergence;
opts.maxNumItt.set_T(500);
if((opts.approach.value()=="symm")&&(retry > std::min(opts.retrystep,3)))
{
if(no_conv){
retry++;
opts.approach.set_T("defl");
message(endl << "Restarting MELODIC using deflation approach"
<< endl << endl);
}else{
leaveloop = true;
}
}else{
if(no_conv){
retry++;
if(opts.pca_dim.value()-retry*opts.retrystep >
0.1*melodat.data_dim()){
opts.pca_dim.set_T(opts.pca_dim.value()-retry*opts.retrystep);
}
else{
if(opts.pca_dim.value()+retry*opts.retrystep < melodat.data_dim()){
opts.pca_dim.set_T(opts.pca_dim.value()+retry*opts.retrystep);
}else{
leaveloop = true; //stupid, but break does not compile
//on all platforms
}
}
if(!leaveloop){
if(opts.paradigmfname.value().length()>0)
opts.pca_dim.set_T(std::max(opts.pca_dim.value(),melodat.get_param().Ncols()+3*opts.retrystep-1));
message(endl << "Restarting MELODIC using -d "
<< opts.pca_dim.value()
<< endl << endl);
}
}
if(no_conv){
retry++;
if((opts.approach.value()=="symm")&&(retry == opts.maxRestart.value())){
// try final round with defl
opts.approach.set_T("defl");
message(endl << "Restarting MELODIC using deflation approach" << endl << endl);
}
else{
// try using different dim
if((int)opts.pca_dim.value()*opts.retryfactor.value() > (int)(0.05*melodat.data_dim()+1)){
opts.pca_dim.set_T((int)opts.pca_dim.value()*opts.retryfactor.value());
}
else{
if((int)opts.pca_dim.value()/opts.retryfactor.value() > (int)(melodat.data_dim())){
opts.pca_dim.set_T((int)opts.pca_dim.value()/opts.retryfactor.value());
}
else{
leaveloop = TRUE;
}
}
}
}
} while (no_conv && retry<opts.maxRestart.value() && !leaveloop);
......
......@@ -133,7 +133,7 @@ class MelodicOptions {
Option<bool> guess_remderiv;
Option<bool> temporal;
int retrystep;
Option<float> retryfactor;
void parse_command_line(int argc, char** argv, Log& logger, const string &p_version);
......@@ -254,7 +254,7 @@ class MelodicOptions {
maxNumItt(string("--maxit"), 500,
string("\tmaximum number of iterations before restart"),
false, requires_argument),
maxRestart(string("--maxrestart"), 6,
maxRestart(string("--maxrestart"), -1,
string("maximum number of restarts\n"),
false, requires_argument),
rank1interval(string("--rank1interval"), 10,
......@@ -404,7 +404,9 @@ class MelodicOptions {
temporal(string("--temporal"), false,
string("perform temporal ICA"),
false, no_argument, false),
retrystep(3),
retryfactor(string("--retryfactor"), float(0.95),
string("multiplicative factor for determining new dim if estimated dim fails to converge"),
false, requires_argument, false),
options(title, usageexmpl)
{
try {
......@@ -491,6 +493,7 @@ class MelodicOptions {
options.add(rescale_nht);
options.add(guess_remderiv);
options.add(temporal);
options.add(retryfactor);
}
catch(X_OptionError& e) {
options.usage();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment