From e2490dec302a41ff751ded191da6e73ed9b3235a Mon Sep 17 00:00:00 2001 From: Mark Woolrich <woolrich@fmrib.ox.ac.uk> Date: Thu, 26 Feb 2004 17:14:37 +0000 Subject: [PATCH] *** empty log message *** --- minimize.cc | 81 ++++++++++++++++++++++++++++------------------------- minimize.h | 4 +-- sparsefn.cc | 17 ++++++----- 3 files changed, 54 insertions(+), 48 deletions(-) diff --git a/minimize.cc b/minimize.cc index 29c2c6c..a6dd504 100644 --- a/minimize.cc +++ b/minimize.cc @@ -367,11 +367,12 @@ void scg(ColumnVector& x,const gEvalFunction& func){ bool success=true; int nsuccess=0; float lambda=1.0; - float lambdamin = 1.0e-15; - float lambdamax = 1.0e100; + float lambdamin = 1.0e-15; + float lambdamax = 1.0e15; int j = 1; float mu=0,kappa=0,sigma=0,gamma=0,alpha=0,delta=0,Delta,beta=0; float eps=1e-16; + float tol=0.0000001; // main loop.. while(j<niters){ @@ -381,15 +382,21 @@ void scg(ColumnVector& x,const gEvalFunction& func){ d=-gradnew; mu=(d.t()*gradnew).AsScalar(); } + kappa=(d.t()*d).AsScalar(); if(kappa<eps){ break; } - sigma=sigma0/std::sqrt(kappa); + + sigma=sigma0/std::sqrt(kappa); xplus = x + sigma*d; + gplus=func.g_evaluate(xplus);gevals++; + gamma = (d.t()*(gplus - gradnew)).AsScalar()/sigma; + } + delta = gamma + lambda*kappa; if (delta <= 0){ delta = lambda*kappa; @@ -397,47 +404,45 @@ void scg(ColumnVector& x,const gEvalFunction& func){ } alpha = - mu/delta; + xnew = x + alpha*d; + fnew=func.evaluate(xnew);fevals++; + Delta = 2*(fnew - fold)/(alpha*mu); - xnew = x + alpha*d; - fnew=func.evaluate(xnew);fevals++; - Delta = 2*(fnew - fold)/(alpha*mu); - if (Delta >= 0){ - success = true; - nsuccess = nsuccess + 1; + if (Delta >= 0){ + success = true; + nsuccess = nsuccess + 1; x = xnew; fnow = fnew;} - else{ - success = false; - fnow = fold; - } - - if (success == 1){ - - - //Test for termination... - if(0==1){ - break; - } - else{ - fold = fnew; - gradold = gradnew; - gradnew=func.g_evaluate(x);gevals++; - if ((gradnew.t()*gradnew).AsScalar() == 0){ - break; - } - } - - } - - if (Delta < 0.25){ - // lambda = min(4.0*lambda, lambdamax); - lambda=4.0*lambda<lambdamax ? 4.0*lambda : lambdamax; - } + else{ + success = false; + fnow = fold; + } + + if (success == 1){ + + //Test for termination... + + if ( (max(abs(d*alpha))).AsScalar() < tol && std::abs(fnew-fold) < tol){ + break; + } + else{ + fold = fnew; + gradold = gradnew; + gradnew=func.g_evaluate(x);gevals++; + if ((gradnew.t()*gradnew).AsScalar() == 0){ + break; + } + } + + } + if (Delta < 0.25){ + // lambda = min(4.0*lambda, lambdamax); + lambda=4.0*lambda<lambdamax ? 4.0*lambda : lambdamax; + } if (Delta > 0.75){ //lambda = max(0.5*lambda, lambdamin); lambda = 0.5*lambda > lambdamin ? 0.5*lambda : lambdamin; } - if (nsuccess == nparams){ d = -gradnew; @@ -451,7 +456,7 @@ void scg(ColumnVector& x,const gEvalFunction& func){ } j++; } - + } diff --git a/minimize.h b/minimize.h index 918fc10..c0c67b9 100644 --- a/minimize.h +++ b/minimize.h @@ -48,15 +48,13 @@ public: virtual float evaluate(const ColumnVector& x) const = 0; //evaluate the function virtual ~EvalFunction(){}; -private: - +private: const EvalFunction& operator=(EvalFunction& par); EvalFunction(const EvalFunction&); }; class gEvalFunction : public EvalFunction {//Function where gradient is analytic (required for scg) - public: gEvalFunction() : EvalFunction(){} // evaluate is inherited from EvalFunction diff --git a/sparsefn.cc b/sparsefn.cc index 0ccdcfe..625d45c 100644 --- a/sparsefn.cc +++ b/sparsefn.cc @@ -345,15 +345,17 @@ namespace MISCMATHS { { Tracer_Plus trace("sparsefns::solvefortracex"); - int every = Max(1,A.Ncols()/50); + int every = Max(1,A.Ncols()/100); + //int every = 1; + OUT(every); float tr = 0.0; // assumes symmetric A and b for(int r = every; r<=A.Ncols(); r+=every) { - cout << float(r)/A.Ncols() << "\r"; - cout.flush(); +// cout << float(r)/A.Ncols() << "\r"; +// cout.flush(); ColumnVector br = b.RowAsColumn(r); ColumnVector xr = x.RowAsColumn(r); @@ -502,11 +504,12 @@ namespace MISCMATHS { } - if(k>20) + if(k>kmax/2.0) { - OUT(std::sqrt(rho(k-1))); - OUT(norm2(b)); - OUT(k); + OUT(std::sqrt(rho(k-1))); + OUT(norm2(b)); + OUT(k); + cout.flush(); } } -- GitLab