Commit c47449b4 authored by William Clarke's avatar William Clarke
Browse files

Finished port of fsl_mrs_proc. Tests needed.

parent c4d41793
......@@ -11,12 +11,28 @@ import numpy as np
import json
from fsl.data.image import Image
from fsl_mrs.core import MRS, MRSI
import fsl.utils.path as fslpath
class NIFTIMRS_DimDoesntExist(Exception):
pass
class NotNIFTI_MRS(Exception):
pass
def is_nifti_mrs(file_path):
'''Check that a file is of the NIFTI-MRS format type.'''
try:
obj = NIFTI_MRS(file_path)
if float(obj.mrs_nifti_version) < 0.2:
raise NotNIFTI_MRS('NIFTI-MRS > v0.2 required.')
return True
except fslpath.PathError:
raise NotNIFTI_MRS("File isn't NIFTI-MRS, wrong extension type.")
class NIFTI_MRS(Image):
"""Load NIFTI MRS format data. Derived from nibabel's Nifti2Image."""
def __init__(self, *args, **kwargs):
......
from fsl_mrs.core.MRS import MRS
from fsl_mrs.core.MRSI import MRSI
from fsl_mrs.core.NIFTI_MRS import NIFTI_MRS
from fsl_mrs.core.NIFTI_MRS import NIFTI_MRS, is_nifti_mrs
This diff is collapsed.
......@@ -32,6 +32,7 @@ def align_FID(mrs, src_FID, tgt_FID, ppmlim=None, shift=True):
-------
array-like
"""
normalisation = np.linalg.norm(tgt_FID)
# Internal functions so they can see globals
def shift_phase_freq(FID, phi, eps, extract=True):
......@@ -45,7 +46,7 @@ def align_FID(mrs, src_FID, tgt_FID, ppmlim=None, shift=True):
eps = p[1] # freq shift
FID = shift_phase_freq(src_FID, phi, eps)
target = extract_spectrum(mrs, tgt_FID, ppmlim=ppmlim, shift=shift)
xx = np.linalg.norm(FID - target)
xx = np.linalg.norm((FID - target) / normalisation)
return xx
x0 = np.array([0, 0])
res = minimize(cf, x0, method='Powell')
......@@ -71,6 +72,8 @@ def align_FID_diff(mrs, src_FID0, src_FID1, tgt_FID, diffType='add', ppmlim=None
-------
array-like
"""
normalisation = np.linalg.norm(tgt_FID)
# Internal functions so they can see globals
def shift_phase_freq(FID0, FID1, phi, eps, extract=True):
sFID = np.exp(-1j * phi) * shift_FID(mrs, FID0, eps)
......@@ -92,9 +95,10 @@ def align_FID_diff(mrs, src_FID0, src_FID1, tgt_FID, diffType='add', ppmlim=None
eps = p[1] # freq shift
FID = shift_phase_freq(src_FID0, src_FID1, phi, eps)
target = extract_spectrum(mrs, tgt_FID, ppmlim=ppmlim, shift=shift)
xx = np.linalg.norm(FID - target)
xx = np.linalg.norm((FID - target) / normalisation)
return xx
x0 = np.array([0, 0])
x0 = np.array([0.0, 0.0])
res = minimize(cf, x0)
phi = res.x[0]
eps = res.x[1]
......@@ -424,8 +428,8 @@ def phase_freq_align_diff_report(inFIDs0,
else:
raise ValueError('diffType must be add or sub.')
meanIn = combine_FIDs(diffFIDListIn.T, 'mean')
meanOut = combine_FIDs(diffFIDListOut.T, 'mean')
meanIn = combine_FIDs(diffFIDListIn, 'mean')
meanOut = combine_FIDs(diffFIDListOut, 'mean')
def toMRSobj(fid):
return MRS(FID=fid, cf=cf, bw=bw, nucleus=nucleus)
......
......@@ -7,18 +7,6 @@
# SHBASECOPYRIGHT
import numpy as np
from dataclasses import dataclass
@dataclass
class datacontainer:
'''Class for keeping track of data and reference data together.'''
data: np.array
dataheader: dict
datafilename: str
reference: np.array = None
refheader: dict = None
reffilename: str = None
def get_target_FID(FIDlist, target='mean'):
......
......@@ -152,8 +152,8 @@ def align(data, dim, target=None, ppmlim=None, niter=2, apodize=10, report=None,
def aligndiff(data,
reference,
dim,
dim_align,
dim_diff,
diff_type,
target=None,
ppmlim=None,
......@@ -162,9 +162,9 @@ def aligndiff(data,
'''Align frequencies of difference spectra across a dimension
specified by a tag.
:param NIFTI_MRS data: Data to align - data modified by alignment
:param NIFTI_MRS reference: Data to align - data not modified
:param str dim: NIFTI-MRS dimension tag
:param NIFTI_MRS data: Data to align
:param str dim_align: NIFTI-MRS dimension tag to align along
:param str dim_diff: NIFTI-MRS dimension across which diffrence is taken.
:param str diff_type: Either 'add' or 'sub'
:param target: Optional target FID
:param ppmlim: ppm search limits.
......@@ -173,17 +173,27 @@ def aligndiff(data,
:return: Combined data in NIFTI_MRS format.
'''
if data.shape[data.dim_position(dim)] != reference.shape[reference.dim_position(dim)]:
raise DimensionsDoNotMatch('Reference and data selected dimension does not match.')
if data.shape[data.dim_position(dim_diff)] != 2:
raise DimensionsDoNotMatch('Diff dimension must be of length 2')
aligned_obj = data.copy()
for dd, idx in data.iterate_over_dims(dim=dim,
diff_index = data.dim_position(dim_diff)
data_0 = []
data_1 = []
index_0 = []
for dd, idx in data.iterate_over_dims(dim=dim_align,
iterate_over_space=True,
reduce_dim_index=True):
reduce_dim_index=False):
if idx[diff_index] == 0:
data_0.append(dd)
index_0.append(idx)
else:
data_1.append(dd)
for d0, d1, idx in zip(data_0, data_1, index_0):
out = preproc.phase_freq_align_diff(
dd.T,
reference[idx].T,
d0.T,
d1.T,
data.bandwidth,
data.spectrometer_frequency[0],
nucleus=data.nucleus[0],
......@@ -191,14 +201,14 @@ def aligndiff(data,
ppmlim=ppmlim,
target=target)
aligned_obj[idx], _, phi, eps = out[0].T, out[1], out[2], out[3]
aligned_obj[idx], _, phi, eps = np.asarray(out[0]).T, out[1], out[2], out[3]
if report and (report_all or first_index(idx)):
from fsl_mrs.utils.preproc.align import phase_freq_align_diff_report
phase_freq_align_diff_report(dd.T,
reference[idx].T,
phase_freq_align_diff_report(d0.T,
d1.T,
aligned_obj[idx].T,
reference[idx].T,
d1.T,
phi,
eps,
data.bandwidth,
......@@ -219,19 +229,26 @@ def ecc(data, reference, report=None, report_all=False):
:return: Corrected data in NIFTI_MRS format.
'''
if data.shape != reference.shape:
raise DimensionsDoNotMatch('Reference and data shape must match.')
if data.shape != reference.shape\
and reference.ndim > 4:
raise DimensionsDoNotMatch('Reference and data shape must match'
' or reference must be single FID.')
corrected_obj = data.copy()
for dd, idx in data.iterate_over_dims(iterate_over_space=True):
corrected_obj[idx] = preproc.eddy_correct(dd, reference[idx])
if data.shape == reference.shape:
ref = reference[idx]
else:
ref = reference[idx[0], idx[1], idx[2], :]
corrected_obj[idx] = preproc.eddy_correct(dd, ref)
if report and (report_all or first_index(idx)):
from fsl_mrs.utils.preproc.eddycorrect import eddy_correct_report
eddy_correct_report(dd,
corrected_obj[idx],
reference[idx],
ref,
data.bandwidth,
data.spectrometer_frequency[0],
nucleus=data.nucleus[0],
......@@ -480,7 +497,7 @@ def remove_unlike(data, ppmlim=None, sdlimit=1.96, niter=2, report=None):
if data.ndim > 5:
raise ValueError('remove_unlike only makes sense for a single dynamic dimension. Combined coils etc. first')
elif data.ndim < 4:
elif data.ndim < 5:
raise ValueError('remove_unlike only makes sense for data with a dynamic dimension')
goodFIDs, badFIDs, gIndicies, bIndicies, metric = \
......@@ -617,9 +634,9 @@ def subtract(data0, data1=None, dim=None, report=None, report_all=False):
if dim is not None:
# Check dim is of correct size
if data0.shape[data0.dim_tags.index(dim)] != 2:
if data0.shape[data0.dim_position(dim)] != 2:
raise DimensionsDoNotMatch('Subtraction dimension must be of length 2.'
f' Currently {data0.shape[data0.dim_tags.index(dim)]}')
f' Currently {data0.shape[data0.dim_position(dim)]}')
sub_ob = data0.copy(remove_dim=dim)
for dd, idx in data0.iterate_over_dims(dim=dim,
......@@ -628,17 +645,16 @@ def subtract(data0, data1=None, dim=None, report=None, report_all=False):
sub_ob[idx] = preproc.subtract(dd.T[0], dd.T[1])
if report and (report_all or first_index(idx)):
from fsl_mrs.utils.preproc.general import generic_report
original_hdr = {'bandwidth': data0.bandwidth,
'centralFrequency': data0.spectrometer_frequency[0],
'ResonantNucleus': data0.nucleus[0]}
generic_report(dd,
sub_ob[idx],
original_hdr,
original_hdr,
ppmlim=(0.2, 4.2),
html=report,
function='subtract')
from fsl_mrs.utils.preproc.general import add_subtract_report
add_subtract_report(dd.T[0],
dd.T[1],
sub_ob[idx],
data0.bandwidth,
data0.spectrometer_frequency[0],
nucleus=data0.nucleus[0],
ppmlim=(0.2, 4.2),
html=report,
function='subtract')
elif data1 is not None:
......@@ -666,9 +682,9 @@ def add(data0, data1=None, dim=None, report=None, report_all=False):
if dim is not None:
# Check dim is of correct size
if data0.shape[data0.dim_tags.index(dim)] != 2:
if data0.shape[data0.dim_position(dim)] != 2:
raise DimensionsDoNotMatch('Addition dimension must be of length 2.'
f' Currently {data0.shape[data0.dim_tags.index(dim)]}')
f' Currently {data0.shape[data0.dim_position(dim)]}')
add_ob = data0.copy(remove_dim=dim)
for dd, idx in data0.iterate_over_dims(dim=dim,
......@@ -677,17 +693,16 @@ def add(data0, data1=None, dim=None, report=None, report_all=False):
add_ob[idx] = preproc.add(dd.T[0], dd.T[1])
if report and (report_all or first_index(idx)):
from fsl_mrs.utils.preproc.general import generic_report
original_hdr = {'bandwidth': data0.bandwidth,
'centralFrequency': data0.spectrometer_frequency[0],
'ResonantNucleus': data0.nucleus[0]}
generic_report(dd,
add_ob[idx],
original_hdr,
original_hdr,
ppmlim=(0.2, 4.2),
html=report,
function='add')
from fsl_mrs.utils.preproc.general import add_subtract_report
add_subtract_report(dd.T[0],
dd.T[1],
add_ob[idx],
data0.bandwidth,
data0.spectrometer_frequency[0],
nucleus=data0.nucleus[0],
ppmlim=(0.2, 4.2),
html=report,
function='add')
elif data1 is not None:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment