Skip to content
Snippets Groups Projects
Commit e2490dec authored by Mark Woolrich's avatar Mark Woolrich
Browse files

*** empty log message ***

parent 897d824e
No related branches found
No related tags found
No related merge requests found
...@@ -367,11 +367,12 @@ void scg(ColumnVector& x,const gEvalFunction& func){ ...@@ -367,11 +367,12 @@ void scg(ColumnVector& x,const gEvalFunction& func){
bool success=true; bool success=true;
int nsuccess=0; int nsuccess=0;
float lambda=1.0; float lambda=1.0;
float lambdamin = 1.0e-15; float lambdamin = 1.0e-15;
float lambdamax = 1.0e100; float lambdamax = 1.0e15;
int j = 1; int j = 1;
float mu=0,kappa=0,sigma=0,gamma=0,alpha=0,delta=0,Delta,beta=0; float mu=0,kappa=0,sigma=0,gamma=0,alpha=0,delta=0,Delta,beta=0;
float eps=1e-16; float eps=1e-16;
float tol=0.0000001;
// main loop.. // main loop..
while(j<niters){ while(j<niters){
...@@ -381,15 +382,21 @@ void scg(ColumnVector& x,const gEvalFunction& func){ ...@@ -381,15 +382,21 @@ void scg(ColumnVector& x,const gEvalFunction& func){
d=-gradnew; d=-gradnew;
mu=(d.t()*gradnew).AsScalar(); mu=(d.t()*gradnew).AsScalar();
} }
kappa=(d.t()*d).AsScalar(); kappa=(d.t()*d).AsScalar();
if(kappa<eps){ if(kappa<eps){
break; break;
} }
sigma=sigma0/std::sqrt(kappa);
sigma=sigma0/std::sqrt(kappa);
xplus = x + sigma*d; xplus = x + sigma*d;
gplus=func.g_evaluate(xplus);gevals++; gplus=func.g_evaluate(xplus);gevals++;
gamma = (d.t()*(gplus - gradnew)).AsScalar()/sigma; gamma = (d.t()*(gplus - gradnew)).AsScalar()/sigma;
} }
delta = gamma + lambda*kappa; delta = gamma + lambda*kappa;
if (delta <= 0){ if (delta <= 0){
delta = lambda*kappa; delta = lambda*kappa;
...@@ -397,47 +404,45 @@ void scg(ColumnVector& x,const gEvalFunction& func){ ...@@ -397,47 +404,45 @@ void scg(ColumnVector& x,const gEvalFunction& func){
} }
alpha = - mu/delta; alpha = - mu/delta;
xnew = x + alpha*d;
fnew=func.evaluate(xnew);fevals++;
Delta = 2*(fnew - fold)/(alpha*mu);
xnew = x + alpha*d; if (Delta >= 0){
fnew=func.evaluate(xnew);fevals++; success = true;
Delta = 2*(fnew - fold)/(alpha*mu); nsuccess = nsuccess + 1;
if (Delta >= 0){
success = true;
nsuccess = nsuccess + 1;
x = xnew; x = xnew;
fnow = fnew;} fnow = fnew;}
else{ else{
success = false; success = false;
fnow = fold; fnow = fold;
} }
if (success == 1){ if (success == 1){
//Test for termination...
//Test for termination...
if(0==1){ if ( (max(abs(d*alpha))).AsScalar() < tol && std::abs(fnew-fold) < tol){
break; break;
} }
else{ else{
fold = fnew; fold = fnew;
gradold = gradnew; gradold = gradnew;
gradnew=func.g_evaluate(x);gevals++; gradnew=func.g_evaluate(x);gevals++;
if ((gradnew.t()*gradnew).AsScalar() == 0){ if ((gradnew.t()*gradnew).AsScalar() == 0){
break; break;
} }
} }
} }
if (Delta < 0.25){
if (Delta < 0.25){ // lambda = min(4.0*lambda, lambdamax);
// lambda = min(4.0*lambda, lambdamax); lambda=4.0*lambda<lambdamax ? 4.0*lambda : lambdamax;
lambda=4.0*lambda<lambdamax ? 4.0*lambda : lambdamax; }
}
if (Delta > 0.75){ if (Delta > 0.75){
//lambda = max(0.5*lambda, lambdamin); //lambda = max(0.5*lambda, lambdamin);
lambda = 0.5*lambda > lambdamin ? 0.5*lambda : lambdamin; lambda = 0.5*lambda > lambdamin ? 0.5*lambda : lambdamin;
} }
if (nsuccess == nparams){ if (nsuccess == nparams){
d = -gradnew; d = -gradnew;
...@@ -451,7 +456,7 @@ void scg(ColumnVector& x,const gEvalFunction& func){ ...@@ -451,7 +456,7 @@ void scg(ColumnVector& x,const gEvalFunction& func){
} }
j++; j++;
} }
} }
......
...@@ -48,15 +48,13 @@ public: ...@@ -48,15 +48,13 @@ public:
virtual float evaluate(const ColumnVector& x) const = 0; //evaluate the function virtual float evaluate(const ColumnVector& x) const = 0; //evaluate the function
virtual ~EvalFunction(){}; virtual ~EvalFunction(){};
private: private:
const EvalFunction& operator=(EvalFunction& par); const EvalFunction& operator=(EvalFunction& par);
EvalFunction(const EvalFunction&); EvalFunction(const EvalFunction&);
}; };
class gEvalFunction : public EvalFunction class gEvalFunction : public EvalFunction
{//Function where gradient is analytic (required for scg) {//Function where gradient is analytic (required for scg)
public: public:
gEvalFunction() : EvalFunction(){} gEvalFunction() : EvalFunction(){}
// evaluate is inherited from EvalFunction // evaluate is inherited from EvalFunction
......
...@@ -345,15 +345,17 @@ namespace MISCMATHS { ...@@ -345,15 +345,17 @@ namespace MISCMATHS {
{ {
Tracer_Plus trace("sparsefns::solvefortracex"); Tracer_Plus trace("sparsefns::solvefortracex");
int every = Max(1,A.Ncols()/50); int every = Max(1,A.Ncols()/100);
//int every = 1;
OUT(every);
float tr = 0.0; float tr = 0.0;
// assumes symmetric A and b // assumes symmetric A and b
for(int r = every; r<=A.Ncols(); r+=every) for(int r = every; r<=A.Ncols(); r+=every)
{ {
cout << float(r)/A.Ncols() << "\r"; // cout << float(r)/A.Ncols() << "\r";
cout.flush(); // cout.flush();
ColumnVector br = b.RowAsColumn(r); ColumnVector br = b.RowAsColumn(r);
ColumnVector xr = x.RowAsColumn(r); ColumnVector xr = x.RowAsColumn(r);
...@@ -502,11 +504,12 @@ namespace MISCMATHS { ...@@ -502,11 +504,12 @@ namespace MISCMATHS {
} }
if(k>20) if(k>kmax/2.0)
{ {
OUT(std::sqrt(rho(k-1))); OUT(std::sqrt(rho(k-1)));
OUT(norm2(b)); OUT(norm2(b));
OUT(k); OUT(k);
cout.flush();
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment