From ba42d0b63a9086eca4c111546f6dc1e8c5898a8f Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Fri, 6 Sep 2024 17:26:45 +0100
Subject: [PATCH] ENH: NEW "SoftZeros" extrapolation method, which sets
 out-of-bounds locations to zero, but still interpolates out-of-bounds, to
 allow a smooth transition from in-to-out of bounds

---
 splinterpolator.h | 104 +++++++++++++++++++++++++++++-----------------
 1 file changed, 65 insertions(+), 39 deletions(-)

diff --git a/splinterpolator.h b/splinterpolator.h
index 2fab71c..dfc312b 100644
--- a/splinterpolator.h
+++ b/splinterpolator.h
@@ -22,7 +22,14 @@
 
 namespace SPLINTERPOLATOR {
 
-enum ExtrapolationType {Zeros, Constant, Mirror, Periodic};
+// Controls how the field behaves outside of the FOV:
+//  - Zeros:     Out-of-bounds coordinates are set to zero
+//  - Constant:  Out-of-bounds coordinates are set to the boundary voxels
+//  - Mirror:    The field is mirrored in all directions
+//  - Periodic:  The field is repeated in all directions
+//  - SoftZeros: The field boundary is set to zero, but out-of-bounds voxels
+//               are calculated via interpolation
+enum ExtrapolationType {Zeros, Constant, Mirror, Periodic, SoftZeros};
 
 class SplinterpolatorException: public std::exception
 {
@@ -657,27 +664,35 @@ double Splinterpolator<T>::value_at(const double *coord) const
 
   double       iwgt[8], jwgt[8], kwgt[8], lwgt[8], mwgt[8];
   double       *wgts[] = {iwgt, jwgt, kwgt, lwgt, mwgt};
+  int          start_inds[5];
   int          inds[5];
   unsigned int ni = 0;
   const T      *cptr = coef_ptr();
 
-  ni = get_start_indicies(coord,inds);
-  get_wgts(coord,inds,wgts);
+  ni = get_start_indicies(coord,start_inds);
+  get_wgts(coord,start_inds,wgts);
 
   double val=0.0;
   for (unsigned int m=0, me=(_ndim>4)?ni:1; m<me; m++) {
+    inds[4] = start_inds[4] + m;
+
     for (unsigned int l=0, le=(_ndim>3)?ni:1; l<le; l++) {
+      inds[3] = start_inds[3] + l;
+
       for (unsigned int k=0, ke=(_ndim>2)?ni:1; k<ke; k++) {
+        inds[2] = start_inds[2] + k;
         double wgt1 = wgts[4][m]*wgts[3][l]*wgts[2][k];
-        unsigned int linear1 = indx2linear(inds[2]+k,inds[3]+l,inds[4]+m);
+
         for (unsigned int j=0, je=(_ndim>1)?ni:1; j<je; j++) {
+          inds[1] = start_inds[1] + j;
           double wgt2 = wgt1*wgts[1][j];
-          int linear2 = add2linear(linear1,inds[1]+j);
-          double *iiwgt=iwgt;
+          double *iiwgt = iwgt;
+
           for (unsigned int i=0; i<ni; i++, iiwgt++) {
-	    val += cptr[linear2+indx2indx(inds[0]+i,0)]*(*iiwgt)*wgt2;
+            inds[0] = start_inds[0] + i;
+            val += coef(inds) * (*iiwgt) * wgt2;
           }
-	}
+        }
       }
     }
   }
@@ -696,9 +711,9 @@ double Splinterpolator<T>::value_at(const double *coord) const
 
 template<class T>
 double Splinterpolator<T>::value_and_derivatives_at(const double       *coord,
-						    const unsigned int *deriv,
-						    double             *dval)
-const
+                                                    const unsigned int *deriv,
+                                                    double             *dval)
+  const
 {
   if (should_be_zero(coord)) { memset(dval,0,n_nonzero(deriv)*sizeof(double)); return(0.0); }
 
@@ -708,38 +723,47 @@ const
   double       *dwgts[] = {diwgt, djwgt, dkwgt, dlwgt, dmwgt};
   double       dwgt1[5];
   double       dwgt2[5];
+  int          start_inds[5];
   int          inds[5];
   unsigned int dd[5];
   unsigned int nd = 0;
   unsigned int ni = 0;
   const T      *cptr = coef_ptr();
 
-  ni = get_start_indicies(coord,inds);
-  get_wgts(coord,inds,wgts);
-  get_dwgts(coord,inds,deriv,dwgts);
+  ni = get_start_indicies(coord,start_inds);
+  get_wgts(coord,start_inds,wgts);
+  get_dwgts(coord,start_inds,deriv,dwgts);
   for (unsigned int i=0; i<_ndim; i++) if (deriv[i]) { dd[nd] = i; dval[nd++] = 0.0; }
 
   double val=0.0;
   for (unsigned int m=0, me=(_ndim>4)?ni:1; m<me; m++) {
+    inds[4] = start_inds[4] + m;
+
     for (unsigned int l=0, le=(_ndim>3)?ni:1; l<le; l++) {
+      inds[3] = start_inds[3] + l;
+
       for (unsigned int k=0, ke=(_ndim>2)?ni:1; k<ke; k++) {
+        inds[2] = start_inds[2] + k;
         double wgt1 = wgts[4][m]*wgts[3][l]*wgts[2][k];
         get_dwgt1(wgts,dwgts,dd,nd,k,l,m,wgt1,dwgt1);
-        unsigned int linear1 = indx2linear(inds[2]+k,inds[3]+l,inds[4]+m);
+
         for (unsigned int j=0, je=(_ndim>1)?ni:1; j<je; j++) {
+          inds[1] = start_inds[1] + j;
           double wgt2 = wgt1*wgts[1][j];
           for (unsigned int d=0; d<nd; d++) dwgt2[d] = (dd[d]==1) ? dwgt1[d]*dwgts[1][j] : dwgt1[d]*wgts[1][j];
-          int linear2 = add2linear(linear1,inds[1]+j);
           double *iiwgt=iwgt;
+
           for (unsigned int i=0; i<ni; i++, iiwgt++) {
-            double c = cptr[linear2+indx2indx(inds[0]+i,0)];
+            inds[0] = start_inds[0] + i;
+            double c = coef(inds);
             val += c*(*iiwgt)*wgt2;
+
             for (unsigned int d=0; d<nd; d++) {
               double add = (dd[d]==0) ? c*diwgt[i]*dwgt2[d] : c*(*iiwgt)*dwgt2[d];
               dval[d] += add;
-	    }
-	  }
-	}
+            }
+          }
+        }
       }
     }
   }
@@ -1200,7 +1224,7 @@ inline unsigned int Splinterpolator<T>::indx2indx(int indx, unsigned int d) cons
     if (indx < 0) indx = 0;
     else if (indx >= dim) indx = dim-1;
   }
-  else if (_et[d] == Zeros || _et[d] == Mirror) {
+  else if (_et[d] == Zeros || _et[d] == SoftZeros || _et[d] == Mirror) {
     while (indx < 0) indx = 2*dim*((indx+1)/dim) - 1 - indx;
     while (indx >= dim) indx = 2*dim*(indx/dim) - 1 - indx;
   }
@@ -1323,36 +1347,38 @@ T Splinterpolator<T>::coef(int *indx) const
   for (unsigned int i=0; i<_ndim; i++) {
     if (indx[i] < 0) {
       switch (_et[i]) {
+      case SoftZeros:
       case Zeros:
-	return(static_cast<T>(0));
+        return(static_cast<T>(0));
       case Constant:
-	indx[i] = 0;
-	break;
+        indx[i] = 0;
+        break;
       case Mirror:
-	indx[i] = 1-indx[i];
-	break;
+        indx[i] = 1-indx[i];
+        break;
       case Periodic:
-	indx[i] = _dim[i]+indx[i];
-	break;
+        indx[i] = _dim[i]+indx[i];
+        break;
       default:
-	break;
+        break;
       }
     }
     else if (indx[i] >= static_cast<int>(_dim[i])) {
       switch (_et[i]) {
+      case SoftZeros:
       case Zeros:
-	return(static_cast<T>(0));
+        return(static_cast<T>(0));
       case Constant:
-	indx[i] = _dim[i]-1;
-	break;
+        indx[i] = _dim[i]-1;
+        break;
       case Mirror:
-	indx[i] = 2*_dim[i]-indx[i]-1;
-	break;
+        indx[i] = 2*_dim[i]-indx[i]-1;
+        break;
       case Periodic:
-	indx[i] = indx[i]-_dim[i];
-	break;
+        indx[i] = indx[i]-_dim[i];
+        break;
       default:
-	break;
+        break;
       }
     }
   }
@@ -1648,7 +1674,7 @@ double Splinterpolator<T>::SplineColumn::init_fwd_sweep(double z, ExtrapolationT
     double z2i=z;
     for (unsigned int i=1; i<n; i++, ptr++, z2i*=z) iv += z2i * *ptr;
   }
-  // et == Constant || et == Zeros
+  // et == Constant || et == Zeros || et == SoftZeros
   else {
     double *ptr=&_col[0];
     double z2i=z;
@@ -1683,7 +1709,7 @@ double Splinterpolator<T>::SplineColumn::init_bwd_sweep(double z, double lv, Ext
   else if (et == Mirror) {
     iv = -z/(1.0-z*z) * (2.0*_col[_sz-1] - lv);
   }
-  // et == Constant || et == Zeros
+  // et == Constant || et == Zeros || et == SoftZeros
   else  {
     iv = z / (z - 1) * _col[_sz-1];
   }
-- 
GitLab