Commit 2a4cc33c authored by William Clarke's avatar William Clarke
Browse files

Add tests for new fft based interpolation and corect scaling in fft interp.

parent 1b114763
......@@ -123,6 +123,40 @@ def test_formatting():
assert np.isclose(np.linalg.norm(np.mean(no_scale * rescale[0], axis=1)), 100)
def test_formatting_linear_interp():
original = basis_mod.Basis.from_file(fsl_basis_path)
original.use_fourier_interp = False
with pytest.raises(basis_mod.BasisHasInsufficentCoverage) as exc_info:
original.get_formatted_basis(2000, 2048)
assert exc_info.type is basis_mod.BasisHasInsufficentCoverage
assert exc_info.value.args[0] == 'The basis spectra covers too little time. '\
'Please reduce the dwelltime, number of points or pad this basis.'
basis = original.get_formatted_basis(2000, 1024)
assert basis.shape == (1024, 21)
basis = original.get_formatted_basis(2000, 1024, ignore=['Ins', 'Cr'])
assert basis.shape == (1024, 19)
basis = original.get_formatted_basis(2000, 1024, ignore=['Ins', 'Cr'], scale_factor=100)
assert np.isclose(np.linalg.norm(np.mean(basis, axis=1)), 100)
names = original.get_formatted_names(ignore=['Ins', 'Cr'])
assert 'Ins' not in names
assert 'Cr' not in names
basis = original.get_formatted_basis(2000, 1024, ignore=['Ins', 'Cr'], scale_factor=100, indept_scale=['Mac'])
index = original.get_formatted_names(ignore=['Ins', 'Cr']).index('Mac')
assert np.isclose(np.linalg.norm(np.mean(np.delete(basis, index, axis=1), axis=1)), 100)
assert np.isclose(np.linalg.norm(basis[:, index]), 100)
# Test rescale
rescale = original.get_rescale_values(2000, 1024, ignore=['Ins', 'Cr'], scale_factor=100)
no_scale = original.get_formatted_basis(2000, 1024, ignore=['Ins', 'Cr'])
assert np.isclose(np.linalg.norm(np.mean(no_scale * rescale[0], axis=1)), 100)
def test_add_fid():
original = basis_mod.Basis.from_file(fsl_basis_path)
......
......@@ -110,3 +110,36 @@ def test_parse_metab_groups():
# List of integers
assert misc.parse_metab_groups(mrs, [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0])\
== [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0]
def test_interpolation():
target_bw = 2000
target_n = 1024
fid_full, hdr_full = synth.syntheticFID(bandwidth=8000, points=8192, noisecovariance=[[0.0]])
fid_reduced, hdr_reduced = synth.syntheticFID(bandwidth=target_bw, points=target_n, noisecovariance=[[0.0]])
interp_lin = misc.ts_to_ts(fid_full[0], 1 / 8000, 1 / target_bw, target_n)
interp_ft = misc.ts_to_ts_ft(fid_full[0], 1 / 8000, 1 / target_bw, target_n)
# import matplotlib.pyplot as plt
# plt.plot(hdr_full['taxis'], np.squeeze(np.real(fid_full)), '-x')
# plt.plot(hdr_reduced['taxis'], np.squeeze(np.real(fid_reduced)), '--x')
# plt.plot(hdr_reduced['taxis'], np.squeeze(np.real(interp_lin)), ':x')
# plt.plot(hdr_reduced['taxis'], np.squeeze(np.real(interp_ft)), ':x')
# plt.xlim([-0.001, 0.1])
# plt.show()
# fig = plt.figure(figsize=(15,6))
# plt.plot(hdr_full['faxis'], np.real(plot.FID2Spec(np.asarray(np.squeeze(fid_full)))), '-')
# plt.plot(hdr_reduced['faxis'], np.real(plot.FID2Spec(np.asarray(np.squeeze(fid_reduced)))), '-')
# plt.plot(hdr_reduced['faxis'], np.squeeze(np.real(plot.FID2Spec(np.asarray(interp_lin)))), ':')
# plt.plot(hdr_reduced['faxis'], np.squeeze(np.real(plot.FID2Spec(np.asarray(interp_ft)))), ':')
# plt.xlim([-500,0])
# plt.show()
assert np.allclose(interp_lin, fid_reduced[0])
# We know the first few points are corrupted in the fft version, but that will appear at edge
# of the spectrum
assert np.allclose(interp_ft[10:-10], np.asarray(fid_reduced[0])[10:-10], atol=1E-1)
......@@ -262,20 +262,23 @@ def ts_to_ts_ft(old_ts, old_dt, new_dt, new_n):
npoints_f = (new_bw - old_bw) / (old_bw / old_ts.shape[0])
npoints_f_half = int(np.round(npoints_f / 2))
# scale_factor = np.abs(float(npoints_f_half) * 2.0) / new_n
if npoints_f_half < 0:
# New bandwidth is smaller than old. Truncate
npoints_f_half *= -1
step1 = s2f(old_fs[npoints_f_half:-npoints_f_half])
elif npoints_f_half > 0:
# New bandwidth is larger than old. Pad
npoints_f_half
step1 = s2f(np.pad(old_fs, (npoints_f_half, npoints_f_half), 'constant', constant_values=(0j, 0j)))
step1 = s2f(np.pad(old_fs, ((npoints_f_half, npoints_f_half), (0, 0)), 'constant', constant_values=(0j, 0j)))
else:
step1 = s2f(old_fs)
# Scaling for different length fft/ifft
step1 = step1 * step1.shape[0] / old_fs.shape[0]
# Step 2: pad or truncate in the temporal domain
if step1.shape[0] < new_n:
step2 = np.pad(step1, (0, new_n - step1.shape[0]), 'constant', constant_values=(0j, 0j))
step2 = np.pad(step1, ((0, new_n - step1.shape[0]), (0, 0)), 'constant', constant_values=(0j, 0j))
else:
step2 = step1[:new_n]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment