Commit 6408e80a authored by William Clarke's avatar William Clarke
Browse files

Add merge functionality

parent e07fc8a1
......@@ -14,7 +14,9 @@ from fsl_mrs.utils.preproc import nifti_mrs_tools as nmrs_tools
testsPath = Path(__file__).parent
test_data_split = testsPath / 'testdata' / 'fsl_mrs_preproc' / 'metab_raw.nii.gz'
test_data_merge = testsPath / 'testdata' / 'fsl_mrs_preproc' / 'wref_raw.nii.gz'
test_data_merge_1 = testsPath / 'testdata' / 'fsl_mrs_preproc' / 'wref_raw.nii.gz'
test_data_merge_2 = testsPath / 'testdata' / 'fsl_mrs_preproc' / 'quant_raw.nii.gz'
test_data_other = testsPath / 'testdata' / 'fsl_mrs_preproc' / 'ecc.nii.gz'
def test_split():
......@@ -114,3 +116,73 @@ def test_split():
test_list = np.delete(test_list, [0, 32, 63])
assert np.allclose(out_1.data, nmrs.data[:, :, :, :, :, test_list])
assert np.allclose(out_2.data, nmrs.data[:, :, :, :, :, [0, 32, 63]])
def test_merge():
"""Test the merge functionality
"""
nmrs_1 = mrs_io.read_FID(test_data_merge_1)
nmrs_2 = mrs_io.read_FID(test_data_merge_2)
nmrs_bad_shape, _ = nmrs_tools.split(nmrs_2, 'DIM_COIL', 4)
nmrs_no_tag = mrs_io.read_FID(test_data_other)
# Error testing
# Wrong dim tag
with pytest.raises(ValueError) as exc_info:
nmrs_tools.merge((nmrs_1, nmrs_2), 'DIM_EDIT')
assert exc_info.type is ValueError
assert exc_info.value.args[0] == "DIM_EDIT not found as dimension tag."\
" This data contains ['DIM_COIL', 'DIM_DYN', None]."
# Wrong dim index (no dim in this data)
with pytest.raises(ValueError) as exc_info:
nmrs_tools.merge((nmrs_1, nmrs_2), 6)
assert exc_info.type is ValueError
assert exc_info.value.args[0] == "Dimension must be one of 4, 5, or 6 (or DIM_TAG string)."\
" This data has 6 dimensions,"\
" i.e. a maximum dimension value of 5."
# Wrong dim index (too low)
with pytest.raises(ValueError) as exc_info:
nmrs_tools.merge((nmrs_1, nmrs_2), 3)
assert exc_info.type is ValueError
assert exc_info.value.args[0] == "Dimension must be one of 4, 5, or 6 (or DIM_TAG string)."\
" This data has 6 dimensions,"\
" i.e. a maximum dimension value of 5."
# Wrong dim index type
with pytest.raises(TypeError) as exc_info:
nmrs_tools.merge((nmrs_1, nmrs_2), [3, ])
assert exc_info.type is TypeError
assert exc_info.value.args[0] == "Dimension must be an int (4, 5, or 6) or string (DIM_TAG string)."
# Incompatible shapes
with pytest.raises(nmrs_tools.NIfTI_MRSIncompatible) as exc_info:
nmrs_tools.merge((nmrs_1, nmrs_bad_shape), 'DIM_DYN')
assert exc_info.type is nmrs_tools.NIfTI_MRSIncompatible
assert exc_info.value.args[0] == "The shape of all concatentated objects must match. "\
"The shape ((1, 1, 1, 4096, 4, 2)) of the 1 object does "\
"not match that of the first ((1, 1, 1, 4096, 32, 2))."
# Incompatible tags
with pytest.raises(nmrs_tools.NIfTI_MRSIncompatible) as exc_info:
nmrs_tools.merge((nmrs_1, nmrs_no_tag), 'DIM_DYN')
assert exc_info.type is nmrs_tools.NIfTI_MRSIncompatible
assert exc_info.value.args[0] == "The tags of all concatentated objects must match. "\
"The tags (['DIM_COIL', None, None]) of the 1 object does "\
"not match that of the first (['DIM_COIL', 'DIM_DYN', None])."
# Functionality testing
out = nmrs_tools.merge((nmrs_1, nmrs_2), 'DIM_DYN')
assert out.data.shape == (1, 1, 1, 4096, 32, 4)
assert np.allclose(out.data[:, :, :, :, :, 0:2], nmrs_1.data)
assert np.allclose(out.data[:, :, :, :, :, 2:], nmrs_2.data)
assert out.hdr_ext == nmrs_1.hdr_ext
assert np.allclose(out.getAffine('voxel', 'world'), nmrs_1.getAffine('voxel', 'world'))
......@@ -14,7 +14,8 @@ def split(nmrs, dimension, index_or_indicies):
:param nmrs: Input nifti_mrs object to split
:type nmrs: fsl_mrs.core.nifti_mrs.NIFTI_MRS
:param dimension: Dimension tag or one of 4, 5, 6 (for 0-indexed 5th, 6th, and 7th)
:param dimension: Dimension along which to split.
Dimension tag or one of 4, 5, 6 (for 0-indexed 5th, 6th, and 7th)
:type dimension: str or int
:param index_or_indicies: Single integer index to split after,
or list of interger indices to insert into second array.
......@@ -63,3 +64,66 @@ def split(nmrs, dimension, index_or_indicies):
nmrs_2 = NIFTI_MRS(np.take(nmrs.data, index, axis=dim_index), header=nmrs.header)
return nmrs_1, nmrs_2
class NIfTI_MRSIncompatible(Exception):
pass
def merge(array_of_nmrs, dimension):
"""Concatenate NIfTI-MRS objects along specified higher dimension
:param array_of_nmrs: Array of NIFTI-MRS objects to concatenate
:type array_of_nmrs: tuple or list of fsl_mrs.core.nifti_mrs.NIFTI_MRS
:param dimension: Dimension along which to concatenate.
Dimension tag or one of 4, 5, 6 (for 0-indexed 5th, 6th, and 7th).
:type dimension: int or str
:return: Concatenated NIFTI-MRS object
:rtype: fsl_mrs.core.nifti_mrs.NIFTI_MRS
"""
if isinstance(dimension, str):
try:
dim_index = array_of_nmrs[0].dim_position(dimension)
except NIFTIMRS_DimDoesntExist:
raise ValueError(f'{dimension} not found as dimension tag. This data contains {array_of_nmrs[0].dim_tags}.')
elif isinstance(dimension, int):
if dimension > (array_of_nmrs[0].ndim - 1) or dimension < 4:
raise ValueError('Dimension must be one of 4, 5, or 6 (or DIM_TAG string).'
f' This data has {array_of_nmrs[0].ndim} dimensions,'
f' i.e. a maximum dimension value of {array_of_nmrs[0].ndim-1}.')
dim_index = dimension
else:
raise TypeError('Dimension must be an int (4, 5, or 6) or string (DIM_TAG string).')
# Check shapes and tags are compatible.
# If they are and enter the data into a tuple for concatenation
def check_shape(to_compare):
for dim in range(to_compare.ndim):
# Do not compare on selected dimension
if dim == dim_index:
continue
if to_compare.shape[dim] != array_of_nmrs[0].shape[dim]:
return False
return True
def check_tag(to_compare):
for tdx in range(3):
if array_of_nmrs[0].dim_tags[tdx] != to_compare.dim_tags[tdx]:
return False
return True
to_concat = []
for idx, nmrs in enumerate(array_of_nmrs):
# Check shape
if not check_shape(nmrs):
raise NIfTI_MRSIncompatible('The shape of all concatentated objects must match.'
f' The shape ({nmrs.shape}) of the {idx} object does'
f' not match that of the first ({array_of_nmrs[0].shape}).')
if not check_tag(nmrs):
raise NIfTI_MRSIncompatible('The tags of all concatentated objects must match.'
f' The tags ({nmrs.dim_tags}) of the {idx} object does'
f' not match that of the first ({array_of_nmrs[0].dim_tags}).')
# Check dim tags for compatibility
to_concat.append(nmrs.data)
return NIFTI_MRS(np.concatenate(to_concat, axis=dim_index), header=array_of_nmrs[0].header)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment