Commit 85ad437f authored by William Clarke's avatar William Clarke
Browse files

Basis tools for add/subtract.

parent aff5916a
......@@ -372,6 +372,18 @@ class Basis:
self._names.append(name)
self._widths.append(width)
def remove_fid_from_basis(self, name):
"""'Permanently' remove a fid from the core basis.
Typically use the keep/ignore syntax for this purpose.
:param name: Name of metabolite/fid to remove
:type name: str
"""
index = self.names.index(name)
self._raw_fids = np.delete(self._raw_fids, index, axis=1)
self._names.pop(index)
self._widths.pop(index)
def add_peak(self, ppm, amp, name, gamma=0.0, sigma=0.0):
"""Add Voigt peak to basis at specified ppm
......
......@@ -133,6 +133,17 @@ def test_add_fid():
assert original.n_metabs == 22
def test_remove_fid():
original = basis_mod.Basis.from_file(fsl_basis_path)
original.remove_fid_from_basis('NAA')
assert len(original.names) == 20
assert 'NAA' not in original.names
assert original.n_metabs == 20
assert len(original.basis_fwhm) == 20
def test_add_peak():
original = basis_mod.Basis.from_file(fsl_basis_path)
......
......@@ -91,3 +91,34 @@ def test_rescale():
indexed_fid = basis.original_basis_array[:, index]
new_scale = np.linalg.norm(indexed_fid)
assert np.isclose(new_scale, 1.0)
basis_on = testsPath / 'testdata' / 'basis_tools' / 'low_res_off'
basis_off = testsPath / 'testdata' / 'basis_tools' / 'low_res_on'
def test_add_sub():
basis_1 = mrs_io.read_basis(basis_off)
basis_2 = mrs_io.read_basis(basis_on)
# Test addition
new = basis_tools.difference_basis_sets(basis_1, basis_2)
assert np.allclose(
new.original_basis_array,
basis_1.original_basis_array + basis_2.original_basis_array)
# Test subtraction
new = basis_tools.difference_basis_sets(basis_1, basis_2, add_or_subtract='sub')
assert np.allclose(
new.original_basis_array,
basis_1.original_basis_array - basis_2.original_basis_array)
# Test missmatched
basis_1.remove_fid_from_basis('NAA')
new = basis_tools.difference_basis_sets(basis_1, basis_2, missing_metabs='ignore')
assert new.n_metabs == 2
with pytest.raises(basis_tools.IncompatibleBasisError) as exc_info:
new = basis_tools.difference_basis_sets(basis_1, basis_2, missing_metabs='raise')
assert exc_info.type is basis_tools.IncompatibleBasisError
assert exc_info.value.args[0] == "NAA does not occur in basis_1."
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
......@@ -11,6 +11,7 @@ import numpy as np
from fsl_mrs.utils import mrs_io
from fsl_mrs.core import MRS
from fsl_mrs.core.basis import Basis
from fsl_mrs.utils.qc import idPeaksCalcFWHM
from fsl_mrs.utils.misc import ts_to_ts, InsufficentTimeCoverageError
......@@ -153,3 +154,50 @@ def rescale_basis(basis, name, target_scale=None):
basis.update_fid(indexed_fid, name)
return basis
def difference_basis_sets(basis_1, basis_2, add_or_subtract='add', missing_metabs='raise'):
"""Add or subtract basis sets to form a set of difference spectra
:param basis_1: Basis set 1
:type basis_1: fsl_mrs.core.basis.Basis
:param basis_2: Basis set 2
:type basis_2: fsl_mrs.core.basis.Basis
:param add_or_subtract: Add ('add') or subtract ('sub') basis sets, defaults to 'add'
:type add_or_subtract: str, optional
:param missing_metabs: Behaviour when mismatched basis sets are found.
It 'raise' a IncompatibleBasisError is raised, if 'ignore' the mismatched
basis will be skipped. Defaults to 'raise'
:type missing_metabs: str, optional
:return: Difference basis spectra
:rtype: fsl_mrs.core.basis.Basis
"""
if missing_metabs == 'raise':
for name in basis_1.names:
if name not in basis_2.names:
raise IncompatibleBasisError(f'{name} does not occur in basis_2.')
for name in basis_2.names:
if name not in basis_1.names:
raise IncompatibleBasisError(f'{name} does not occur in basis_1.')
difference_spec = []
names = []
headers = []
for b1, name in zip(basis_1.original_basis_array.T, basis_1.names):
if name in basis_2.names:
index = basis_2.names.index(name)
if add_or_subtract == 'add':
diff = b1 + basis_2.original_basis_array[:, index]
elif add_or_subtract == 'sub':
diff = b1 - basis_2.original_basis_array[:, index]
difference_spec.append(diff)
names.append(name)
headers.append({'dwelltime': basis_2.original_dwell,
'bandwidth': basis_2.original_bw,
'centralFrequency': basis_2.cf,
'fwhm': basis_2.basis_fwhm[index]})
return Basis(np.asarray(difference_spec), names, headers)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment