Commit 81f416bb authored by William Clarke's avatar William Clarke
Browse files

Merge branch 'bf/improved_basis_interp' into 'master'

Bf/improved basis interpolation

See merge request fsl/fsl_mrs!41
parents b1a83e76 39d9c2e0
Pipeline #12651 waiting for manual action with stages
in 20 minutes and 47 seconds
......@@ -4,6 +4,7 @@ This document contains the FSL-MRS release history in reverse chronological orde
----------------------------------
- Updates to fsl_mrs_preproc_edit
- Updated install documentation.
- Implemented new fft based interpolation of basis sets. Improves suppression of interpolation aliasing.
1.1.9 (Tuesday 30th November 2021)
----------------------------------
......
.. _conda:
:orphan:
===========
......@@ -8,13 +9,13 @@ Conda Guide
This is a short guide on setting up conda for the first time.
1. Download and install a python 3.7 version of Miniconda from the `package website <https://docs.conda.io/en/latest/miniconda.html>`_.
2. Create a conda enviroment.
2. Create a conda environment.
::
conda create --name fsl_mrs -c conda-forge python=3.7
3. Activate the enviroment.
3. Activate the environment.
::
......
......@@ -80,6 +80,9 @@ class Basis:
# This only has bearing on the plotting currently
self._nucleus = '1H'
# Default interpolation is Fourier Transform based.
self._use_fourier_interp = True
@classmethod
def from_file(cls, filepath):
"""Create a Basis object from a path
......@@ -174,6 +177,17 @@ class Basis:
"""Set the nucleus string - only affects plotting"""
self._nucleus = nucleus
@property
def use_fourier_interp(self):
"""Return interpolation state"""
return self._use_fourier_interp
@use_fourier_interp.setter
def use_fourier_interp(self, true_false):
"""Set to true to use FFT based interpolation (default)
Or set to False to use time domain linear interpolation."""
self._use_fourier_interp = true_false
def save(self, out_path, overwrite=False, info_str=''):
"""Saves basis held in memory to a directory in FSL-MRS format.
......@@ -299,10 +313,16 @@ class Basis:
coverage than the FID.
"""
try:
basis = misc.ts_to_ts(self._raw_fids,
self.original_dwell,
target_dwell,
target_points)
if self.use_fourier_interp:
basis = misc.ts_to_ts_ft(self._raw_fids,
self.original_dwell,
target_dwell,
target_points)
else:
basis = misc.ts_to_ts(self._raw_fids,
self.original_dwell,
target_dwell,
target_points)
except misc.InsufficentTimeCoverageError:
raise BasisHasInsufficentCoverage('The basis spectra covers too little time. '
'Please reduce the dwelltime, number of points or pad this basis.')
......
......@@ -123,6 +123,40 @@ def test_formatting():
assert np.isclose(np.linalg.norm(np.mean(no_scale * rescale[0], axis=1)), 100)
def test_formatting_linear_interp():
original = basis_mod.Basis.from_file(fsl_basis_path)
original.use_fourier_interp = False
with pytest.raises(basis_mod.BasisHasInsufficentCoverage) as exc_info:
original.get_formatted_basis(2000, 2048)
assert exc_info.type is basis_mod.BasisHasInsufficentCoverage
assert exc_info.value.args[0] == 'The basis spectra covers too little time. '\
'Please reduce the dwelltime, number of points or pad this basis.'
basis = original.get_formatted_basis(2000, 1024)
assert basis.shape == (1024, 21)
basis = original.get_formatted_basis(2000, 1024, ignore=['Ins', 'Cr'])
assert basis.shape == (1024, 19)
basis = original.get_formatted_basis(2000, 1024, ignore=['Ins', 'Cr'], scale_factor=100)
assert np.isclose(np.linalg.norm(np.mean(basis, axis=1)), 100)
names = original.get_formatted_names(ignore=['Ins', 'Cr'])
assert 'Ins' not in names
assert 'Cr' not in names
basis = original.get_formatted_basis(2000, 1024, ignore=['Ins', 'Cr'], scale_factor=100, indept_scale=['Mac'])
index = original.get_formatted_names(ignore=['Ins', 'Cr']).index('Mac')
assert np.isclose(np.linalg.norm(np.mean(np.delete(basis, index, axis=1), axis=1)), 100)
assert np.isclose(np.linalg.norm(basis[:, index]), 100)
# Test rescale
rescale = original.get_rescale_values(2000, 1024, ignore=['Ins', 'Cr'], scale_factor=100)
no_scale = original.get_formatted_basis(2000, 1024, ignore=['Ins', 'Cr'])
assert np.isclose(np.linalg.norm(np.mean(no_scale * rescale[0], axis=1)), 100)
def test_add_fid():
original = basis_mod.Basis.from_file(fsl_basis_path)
......
......@@ -110,3 +110,36 @@ def test_parse_metab_groups():
# List of integers
assert misc.parse_metab_groups(mrs, [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0])\
== [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0]
def test_interpolation():
target_bw = 2000
target_n = 1024
fid_full, hdr_full = synth.syntheticFID(bandwidth=8000, points=8192, noisecovariance=[[0.0]])
fid_reduced, hdr_reduced = synth.syntheticFID(bandwidth=target_bw, points=target_n, noisecovariance=[[0.0]])
interp_lin = misc.ts_to_ts(fid_full[0], 1 / 8000, 1 / target_bw, target_n)
interp_ft = misc.ts_to_ts_ft(fid_full[0], 1 / 8000, 1 / target_bw, target_n)
# import matplotlib.pyplot as plt
# plt.plot(hdr_full['taxis'], np.squeeze(np.real(fid_full)), '-x')
# plt.plot(hdr_reduced['taxis'], np.squeeze(np.real(fid_reduced)), '--x')
# plt.plot(hdr_reduced['taxis'], np.squeeze(np.real(interp_lin)), ':x')
# plt.plot(hdr_reduced['taxis'], np.squeeze(np.real(interp_ft)), ':x')
# plt.xlim([-0.001, 0.1])
# plt.show()
# fig = plt.figure(figsize=(15,6))
# plt.plot(hdr_full['faxis'], np.real(plot.FID2Spec(np.asarray(np.squeeze(fid_full)))), '-')
# plt.plot(hdr_reduced['faxis'], np.real(plot.FID2Spec(np.asarray(np.squeeze(fid_reduced)))), '-')
# plt.plot(hdr_reduced['faxis'], np.squeeze(np.real(plot.FID2Spec(np.asarray(interp_lin)))), ':')
# plt.plot(hdr_reduced['faxis'], np.squeeze(np.real(plot.FID2Spec(np.asarray(interp_ft)))), ':')
# plt.xlim([-500,0])
# plt.show()
assert np.allclose(interp_lin, fid_reduced[0])
# We know the first few points are corrupted in the fft version, but that will appear at edge
# of the spectrum
assert np.allclose(interp_ft[10:-10], np.asarray(fid_reduced[0])[10:-10], atol=1E-1)
......@@ -216,6 +216,75 @@ def ts_to_ts(old_ts, old_dt, new_dt, new_n):
return new_ts
def ts_to_ts_ft(old_ts, old_dt, new_dt, new_n):
"""Temporal resampling using Fourier transform based resampling
Using the method implemented in LCModel:
1. Data is padded or truncated in the spectral domain to match the bandwidth of the target.
The ifft then returns the time domain data with the right overall length.
2. The data is then padded or truncated in the time domain to the length of the target.
If the data is then FFT it return the interpolated data.
:param old_ts: Input time-domain data
:type old_ts: numpy.ndarray
:param old_dt: Input dwelltime
:type old_dt: float
:param new_dt: Target dwell time
:type new_dt: float
:param new_n: Target number of points
:type new_n: int
:rtype: numpy.ndarray
"""
old_n = old_ts.shape[0]
old_t = np.linspace(old_dt, old_dt * old_n, old_n) - old_dt
new_t = np.linspace(new_dt, new_dt * new_n, new_n) - new_dt
# Round to nanoseconds
old_t = np.round(old_t, 9)
new_t = np.round(new_t, 9)
if new_t[-1] > old_t[-1]:
raise InsufficentTimeCoverageError('Input data covers less time than is requested by interpolation.'
' Change interpolation points or dwell time.')
def f2s(x):
return np.fft.fftshift(np.fft.fft(x, axis=0), axes=0)
def s2f(x):
return np.fft.ifft(np.fft.ifftshift(x, axes=0), axis=0)
# Input data to frequency domain
old_fs = f2s(old_ts)
# Step 1: pad or truncate in the frequency domain
new_bw = 1 / new_dt
old_bw = 1 / old_dt
npoints_f = (new_bw - old_bw) / (old_bw / old_ts.shape[0])
npoints_f_half = int(np.round(npoints_f / 2))
# scale_factor = np.abs(float(npoints_f_half) * 2.0) / new_n
if npoints_f_half < 0:
# New bandwidth is smaller than old. Truncate
npoints_f_half *= -1
step1 = s2f(old_fs[npoints_f_half:-npoints_f_half])
elif npoints_f_half > 0:
# New bandwidth is larger than old. Pad
step1 = s2f(np.pad(old_fs, ((npoints_f_half, npoints_f_half), (0, 0)), 'constant', constant_values=(0j, 0j)))
else:
step1 = s2f(old_fs)
# Scaling for different length fft/ifft
step1 = step1 * step1.shape[0] / old_fs.shape[0]
# Step 2: pad or truncate in the temporal domain
if step1.shape[0] < new_n:
step2 = np.pad(step1, ((0, new_n - step1.shape[0]), (0, 0)), 'constant', constant_values=(0j, 0j))
else:
step2 = step1[:new_n]
return step2
# Numerical differentiation (light)
# Gradient Function
def gradient(x, f):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment