Commit a661c177 authored by William Clarke's avatar William Clarke
Browse files

Update basis tools add to match other functions.

parent c0e9458b
......@@ -198,6 +198,7 @@ def convert(args):
def add(args):
from fsl_mrs.utils.mrs_io import read_basis
from fsl_mrs.utils import basis_tools
import json
import numpy as np
......@@ -217,13 +218,23 @@ def add(args):
bw = 1 / json_dict['basis_dwell']
width = json_dict['basis_width']
basis_tools.add_basis(
fid, name, cf, bw, args.target,
# Check that target exists
target = Path(args.target)
if not target.is_dir():
raise NotADirectoryError('Target must be a directory of FSL-MRS basis (json) files')
# Load target
target_basis = read_basis(target)
target_basis = basis_tools.add_basis(
fid, name, cf, bw, target_basis,
scale=args.scale,
width=width,
conj=args.conj,
pad=args.pad,
sim_info=args.info)
pad=args.pad)
# Write to json without overwriting existing files
target_basis.save(target, info_str=args.info)
def shift(args):
......
......@@ -6,7 +6,6 @@ Test basis tools.
Copyright Will Clarke, University of Oxford, 2021
'''
from pathlib import Path
from shutil import copytree
import pytest
import numpy as np
......@@ -34,9 +33,8 @@ def test_convert_lcmodel(tmp_path):
assert np.isclose(basis.cf, new_basis.cf)
def test_add_basis(tmp_path):
out_loc = tmp_path / 'test_basis'
copytree(fsl_basis_path, out_loc)
def test_add_basis():
basis = mrs_io.read_basis(fsl_basis_path)
mac_in = fsl_io.readJSON(extra_basis)
fid = np.asarray(mac_in['basis_re']) + 1j * np.asarray(mac_in['basis_im'])
......@@ -44,31 +42,27 @@ def test_add_basis(tmp_path):
bw = 1 / mac_in['basis_dwell']
with pytest.raises(basis_tools.IncompatibleBasisError) as exc_info:
basis_tools.add_basis(fid, 'mac2', cf, bw, out_loc)
basis_tools.add_basis(fid, 'mac2', cf, bw, basis)
assert exc_info.type is basis_tools.IncompatibleBasisError
assert exc_info.value.args[0] == "The new basis FID covers too little time, try padding."
basis_tools.add_basis(fid, 'mac1', cf, bw, out_loc, pad=True)
new_basis = mrs_io.read_basis(out_loc)
new_basis = basis_tools.add_basis(fid, 'mac1', cf, bw, basis, pad=True)
index = new_basis.names.index('mac1')
assert 'mac1' in new_basis.names
fid_pad = np.pad(fid, (0, fid.size))
basis_tools.add_basis(fid_pad, 'mac2', cf, bw, out_loc)
new_basis = mrs_io.read_basis(out_loc)
new_basis = basis_tools.add_basis(fid_pad, 'mac2', cf, bw, basis)
index = new_basis.names.index('mac2')
assert 'mac2' in new_basis.names
assert np.allclose(new_basis._raw_fids[:, index], fid_pad[0::2])
basis_tools.add_basis(fid_pad, 'mac3', cf, bw, out_loc, scale=True, width=10)
new_basis = mrs_io.read_basis(out_loc)
new_basis = basis_tools.add_basis(fid_pad, 'mac3', cf, bw, basis, scale=True, width=10)
index = new_basis.names.index('mac3')
assert 'mac3' in new_basis.names
assert new_basis.basis_fwhm[index] == 10
basis_tools.add_basis(fid_pad, 'mac4', cf, bw, out_loc, width=10, conj=True)
new_basis = mrs_io.read_basis(out_loc)
new_basis = basis_tools.add_basis(fid_pad, 'mac4', cf, bw, basis, width=10, conj=True)
index = new_basis.names.index('mac4')
assert 'mac4' in new_basis.names
assert np.allclose(new_basis._raw_fids[:, index], fid_pad[0::2].conj())
......
......@@ -46,7 +46,7 @@ def convert_lcm_basis(path_to_basis, output_location=None):
basis.save(output_location, info_str=sim_info)
def add_basis(fid, name, cf, bw, target, scale=False, width=None, conj=False, pad=False, sim_info='Manually added'):
def add_basis(fid, name, cf, bw, target, scale=False, width=None, conj=False, pad=False):
"""Add an additional basis spectrum to an existing FSL formatted basis set.
Optionally rescale the norm of the new FID to the mean of the existing ones.
......@@ -60,7 +60,7 @@ def add_basis(fid, name, cf, bw, target, scale=False, width=None, conj=False, pa
:param bw: Bandwidth in Hz
:type bw: float
:param target: Target basis set
:type target: str or pathlib.Path
:type target: fsl_mrs.core.basis.Basis
:param scale: Rescale the fid so its norm is the mean of the norms of the
other basis spectra, defaults to False
:type scale: bool, optional
......@@ -70,60 +70,56 @@ def add_basis(fid, name, cf, bw, target, scale=False, width=None, conj=False, pa
:type conj: bool, optional
:param pad: Pad input FID to target length if required, defaults to False.
:type pad: bool, optional
:param sim_info: String added to the meta.SimVersion field, defaults to 'Manually added'
:type sim_info: str, optional
:return: Modified target basis
:rtype: fsl_mrs.core.basis.Basis
"""
# 1. Check that target exists
target = Path(target)
if not target.is_dir():
raise NotADirectoryError('Target must be a directory of FSL-MRS basis (json) files')
# 2. Resample new basis to the same raster as the target
# 1. Resample new basis to the same raster as the target
# Can't use the central frequency as a way to align as the absolute frequency is effectively arbitrary
target_basis = mrs_io.read_basis(target)
target_dt = target_basis.original_dwell
target_dt = target.original_dwell
try:
resampled_fid = ts_to_ts(fid,
1 / bw,
target_dt,
target_basis.original_points)
target.original_points)
except InsufficentTimeCoverageError:
if not pad:
raise IncompatibleBasisError('The new basis FID covers too little time, try padding.')
else:
# Pad fid to sufficent length
required_time = target_basis.original_points * target_dt
required_time = (target.original_points - 1) * target_dt
fid_dt = 1 / bw
required_points = int(np.ceil(required_time / fid_dt))
required_points = int(np.ceil(required_time / fid_dt)) + 1
fid = np.pad(fid, (0, required_points - fid.size), constant_values=complex(0.0))
resampled_fid = ts_to_ts(fid,
1 / bw,
target_dt,
target_basis.original_points)
target.original_points)
# 3. Scale if requested
# 2. Scale if requested
if scale:
norms = []
for b in target_basis.original_basis_array.T:
for b in target.original_basis_array.T:
norms.append(np.linalg.norm(b))
resampled_fid *= np.mean(norms) / np.linalg.norm(resampled_fid)
# 4. Width calculation if needed.
# 3. Width calculation if needed.
if width is None:
mrs = MRS(FID=resampled_fid, cf=cf, bw=bw)
mrs.check_FID(repair=True)
width, _, _ = idPeaksCalcFWHM(mrs)
# 5. Conjugate if requested
# 4. Conjugate if requested
if conj:
resampled_fid = resampled_fid.conj()
# 6. Add to existing basis
target_basis.add_fid_to_basis(resampled_fid, name, width=width)
# 5. Add to existing basis
target.add_fid_to_basis(resampled_fid, name, width=width)
# 7. Write to json without overwriting existing files
target_basis.save(target, info_str=sim_info)
# 6. Return modified basis
return target
def shift_basis(basis, name, amount):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment