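// test_splinterpolator.cc
//
// Regression tests for SPLINTERPOLATOR::Splinterpolator: a small 3D volume is
// sampled (values and, for spline order > 1, gradients) and the results are
// compared against reference NIfTI images in the test data directory.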
#define EXPOSE_TREACHEROUS

#include "miscmaths/splinterpolator.h"
#include "NewNifti/NewNifti.h"

#include <armadillo>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <filesystem>
#include <memory>
#include <stdlib.h>
#include <string>
#include <vector>

#define BOOST_TEST_MODULE test_splinterpolator
#include <boost/test/included/unit_test.hpp>

namespace fs  = std::filesystem;
namespace BTF = boost::unit_test::framework;
namespace SPL = SPLINTERPOLATOR;

/// Shared fixture for the Splinterpolator regression tests.
///
/// Holds a 5x5x5 test volume, a regular grid of sample coordinates that
/// extends beyond the volume's index range (so the extrapolation modes are
/// exercised), and storage for the values / gradients computed by each test
/// case. run_test() compares those results against reference NIfTI images
/// whose file names are derived from the current test case name.
class TestFixture {

public:

  // 5x5x5 volume data to be interpolated
  std::vector<unsigned int> dims{5, 5, 5};
  std::vector<float>        data{
    1,   2,   3,   4,   5,   6,   7,   8,   9,   10,
    11,  12,  13,  14,  15,  16,  17,  18,  19,  20,
    21,  22,  23,  24,  25,  26,  27,  28,  29,  30,
    31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
    41,  42,  43,  44,  45,  46,  47,  48,  49,  50,
    51,  52,  53,  54,  55,  56,  57,  58,  59,  60,
    61,  62,  63,  64,  65,  66,  67,  68,  69,  70,
    71,  72,  73,  74,  75,  76,  77,  78,  79,  80,
    81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
    91,  92,  93,  94,  95,  96,  97,  98,  99,  100,
    101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
    111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
    121, 122, 123, 124, 125
  };

  // sample points, initialised in constructor
  size_t             npts;
  std::vector<float> xpts;
  std::vector<float> ypts;
  std::vector<float> zpts;

  // size of up-sampled volume, initialised in constructor
  size_t             xsz;
  size_t             ysz;
  size_t             zsz;

  // sampled values and gradient, initialised in constructor,
  // populated in test cases
  std::vector<float>              values;
  std::vector<std::vector<float>> gradient;


  TestFixture() {
    // The data is up-sampled by 3x: along each axis the sample coordinates
    // run from -4 to 9 (inclusive) in steps of 1/3, i.e. 13*3 + 1 = 40
    // points per axis, stored with x varying fastest.
    //
    // NOTE: the coordinates are produced by float accumulation, with the
    // value closest to zero snapped to exactly 0. The results are compared
    // with the reference images at a tight tolerance, so this scheme is
    // kept as-is; regenerating the points as k/3.0f could produce slightly
    // different coordinates.
    for (float z = -4; z < 9.01; z+=1/3.0) {
    for (float y = -4; y < 9.01; y+=1/3.0) {
    for (float x = -4; x < 9.01; x+=1/3.0) {
      if (std::abs(x) < 0.01) x = 0;
      if (std::abs(y) < 0.01) y = 0;
      if (std::abs(z) < 0.01) z = 0;
      xpts.push_back(x);
      ypts.push_back(y);
      zpts.push_back(z);
    }}}
    npts = xpts.size();
    xsz  = 40;
    ysz  = 40;
    zsz  = 40;

    // pre-allocate space to store test outputs: one value per sample point
    // and one x / y / z derivative per sample point
    values.assign(npts, 0.0f);
    gradient.assign(3, std::vector<float>(npts));
  }

  // Directory holding the reference images. MISCMATHS_TEST_DIRECTORY is not
  // defined in this file; presumably it is supplied by the build.
  fs::path test_data_dir() const {
    return fs::path(MISCMATHS_TEST_DIRECTORY) / "test_splinterpolator";
  }

  // Load a NIfTI image from the test data directory and return its voxel
  // data as a flat vector of hdr.nElements() floats.
  // NOTE(review): the image is assumed to be stored as 32-bit float; the
  // header's datatype is not checked here - TODO confirm for new fixtures.
  std::vector<float> load_nifti(const std::string& filename) {
    char* buf = nullptr;
    std::vector<NiftiIO::NiftiExtension> exts;

    auto hdr = NiftiIO::loadImage(test_data_dir() / filename, buf, exts);

    // Own the buffer (released with delete[], as before) so that it is
    // freed even if allocating the result vector throws.
    std::unique_ptr<char[]> owner(buf);

    const std::size_t nelems = static_cast<std::size_t>(hdr.nElements());
    std::vector<float> voxels(nelems);

    // Copy the raw bytes instead of reading the char buffer through a
    // float pointer, which would violate strict aliasing.
    if (nelems > 0) {
      std::memcpy(voxels.data(), owner.get(), nelems * sizeof(float));
    }

    return voxels;
  }

  // Element-wise comparison of two float vectors with absolute tolerance
  // `tol`. A size mismatch is reported as a failure, and only the
  // overlapping prefix is compared so that a mismatched reference cannot
  // cause an out-of-bounds read.
  void compare(const std::vector<float>& a,
               const std::vector<float>& b,
               float tol=1e-5) {
    BOOST_CHECK_EQUAL(a.size(), b.size());
    const std::size_t n = std::min(a.size(), b.size());
    for (std::size_t i = 0; i < n; i++) {
      BOOST_CHECK_SMALL(a[i] - b[i], tol);
    }
  }

  // Run the test with a Splinterpolator instance: sample it at every grid
  // point (value via operator(), derivatives via ValAndDerivs), then compare
  // against "<test case name>_values.nii.gz" and, for spline order > 1,
  // against "<test case name>_gradient.nii.gz".
  void run_test(const SPL::Splinterpolator<float>& spl) {

    const std::string test_name = BTF::current_test_case().p_name;

    std::vector<float> deriv{0, 0, 0};

    for (std::size_t i = 0; i < npts; i++) {
      values[i] = spl(xpts[i], ypts[i], zpts[i]);
      spl.ValAndDerivs(xpts[i], ypts[i], zpts[i], deriv);
      gradient[0][i] = deriv[0];
      gradient[1][i] = deriv[1];
      gradient[2][i] = deriv[2];
    }

    compare(values, load_nifti(test_name + "_values.nii.gz"));

    // derivatives are only valid for interpolation of order > 1
    if (spl.Order() > 1) {
      // the reference is compared against the x, y and z derivatives
      // concatenated in that order
      std::vector<float> gradvals;
      gradvals.reserve(3 * npts);
      for (const auto& component : gradient) {
        gradvals.insert(gradvals.end(), component.begin(), component.end());
      }
      compare(gradvals, load_nifti(test_name + "_gradient.nii.gz"));
    }
  }
};


// Each case builds a 3D Splinterpolator over the fixture volume with one
// extrapolation mode and spline order, then hands it to run_test(). The test
// case name selects the reference images that run_test() loads, so the names
// below must stay in sync with the files in the test data directory.
BOOST_FIXTURE_TEST_SUITE(test_splinterpolator, TestFixture)


// ---- spline order 1 ----

BOOST_AUTO_TEST_CASE(order_1_3d_extrap_zeros) {
  const SPL::Splinterpolator<float> interp(data.data(), dims, SPL::Zeros, 1);
  run_test(interp);
}

BOOST_AUTO_TEST_CASE(order_1_3d_extrap_soft_zeros) {
  const SPL::Splinterpolator<float> interp(data.data(), dims, SPL::SoftZeros, 1);
  run_test(interp);
}

BOOST_AUTO_TEST_CASE(order_1_3d_extrap_constant) {
  const SPL::Splinterpolator<float> interp(data.data(), dims, SPL::Constant, 1);
  run_test(interp);
}

// ---- spline order 3 ----

BOOST_AUTO_TEST_CASE(order_3_3d_extrap_zeros) {
  const SPL::Splinterpolator<float> interp(data.data(), dims, SPL::Zeros, 3);
  run_test(interp);
}

BOOST_AUTO_TEST_CASE(order_3_3d_extrap_soft_zeros) {
  const SPL::Splinterpolator<float> interp(data.data(), dims, SPL::SoftZeros, 3);
  run_test(interp);
}

BOOST_AUTO_TEST_CASE(order_3_3d_extrap_constant) {
  const SPL::Splinterpolator<float> interp(data.data(), dims, SPL::Constant, 3);
  run_test(interp);
}

BOOST_AUTO_TEST_SUITE_END()