Commit 659ec713 authored by William Clarke's avatar William Clarke
Browse files

All tests passing.

parent 96bc72d4
from fsl_mrs.core.MRS import MRS
from fsl_mrs.core.MRSI import MRSI
from fsl_mrs.core.NIFTI_MRS import NIFTI_MRS, is_nifti_mrs
from .mrs import MRS
from .mrsi import MRSI
from .nifti_mrs import NIFTI_MRS
from .utility import mrs_from_files, is_nifti_mrs, mrsi_from_files
......@@ -9,7 +9,6 @@
import warnings
from fsl_mrs.utils import mrs_io as io
from fsl_mrs.utils import misc
from fsl_mrs.utils.constants import GYRO_MAG_RATIO, PPM_SHIFT, PPM_RANGE
......@@ -59,22 +58,6 @@ class MRS(object):
self.metab_groups = None
self.scaling = {'FID': 1.0, 'basis': 1.0}
def from_files(self, FID_file, Basis_file, H2O_file=None):
'''Load data from files into empty MRS class object'''
FID, FIDheader = io.read_FID(FID_file)
basis, names, Bheader = io.read_basis(Basis_file)
if H2O_file is not None:
H2O, _ = io.read_FID(H2O_file)
else:
H2O = None
MRSArgs = {'header': FIDheader,
'basis': basis, 'basis_hdr': Bheader[0],
'names': names}
self.__init__(FID=FID, H2O=H2O, **MRSArgs)
def __str__(self):
cf_MHz = self.centralFrequency / 1e6
cf_T = self.centralFrequency / self.gyromagnetic_ratio / 1e6
......@@ -167,7 +150,7 @@ class MRS(object):
(cf_MHz > sevent_range[0] and cf_MHz < sevent_range[1]) or \
(cf_MHz > ninefourt_range[0] and cf_MHz < ninefourt_range[1]) or \
(cf_MHz > elevensevent_range[0] and cf_MHz < elevensevent_range[1]):
#print(f'Identified as {key} nucleus data.'
# print(f'Identified as {key} nucleus data.'
# f' Esitmated field: {cf_MHz/GYRO_MAG_RATIO[key]} T.')
return key
......
......@@ -9,11 +9,10 @@
# SHBASECOPYRIGHT
import numpy as np
from fsl_mrs.core import MRS
from fsl_mrs.utils import mrs_io, misc
import matplotlib.pyplot as plt
import nibabel as nib
from fsl_mrs.utils.mrs_io import fsl_io
from fsl_mrs.core import MRS
from fsl_mrs.utils import misc
class MRSI(object):
......@@ -267,8 +266,9 @@ class MRSI(object):
self.gm = gm
self.tissue_seg_loaded = True
def write_output(self, data_list, file_path_name, indicies=None, cleanup=True, dtype=float):
'''Write 3D or 4D array of data to nifti file with current orientation.'''
def list_to_matched_array(self, data_list, indicies=None, cleanup=True, dtype=float):
'''Convert 3D or 4D array of data indexed from an mrsi object
to a numpy array matching the shape of the mrsi data.'''
if indicies is None:
indicies = self.get_indicies_in_order()
......@@ -287,54 +287,4 @@ class MRSI(object):
data[data < 1e-10] = 0
data[data > 1e10] = 0
if nt == self.FID_points:
fsl_io.saveNIFTI(file_path_name, data, self.header)
else:
img = nib.Nifti1Image(data, self.header['nifti'].affine)
nib.save(img, file_path_name)
@classmethod
def from_files(cls, data_file,
mask_file=None,
basis_file=None,
H2O_file=None,
csf_file=None,
gm_file=None,
wm_file=None):
""" Load MRSI data directly from files """
data, hdr = mrs_io.read_FID(data_file)
if mask_file is not None:
nib_img = nib.load(mask_file)
mask = np.asanyarray(nib_img.dataobj)
else:
mask = None
if basis_file is not None:
basis, names, basisHdr = mrs_io.read_basis(basis_file)
else:
basis, names, basisHdr = None, None, [None, ]
if H2O_file is not None:
data_w, hdr_w = mrs_io.read_FID(H2O_file)
else:
data_w = None
out = cls(data, hdr,
mask=mask,
basis=basis,
names=names,
basis_hdr=basisHdr[0],
H2O=data_w)
def loadNii(f):
nii = np.asanyarray(nib.load(f).dataobj)
if nii.ndim == 2:
nii = np.expand_dims(nii, 2)
return nii
if (csf_file is not None) and (gm_file is not None) and (wm_file is not None):
csf = loadNii(csf_file)
gm = loadNii(gm_file)
wm = loadNii(wm_file)
out.set_tissue_seg(csf, wm, gm)
return out
return data
......@@ -10,27 +10,10 @@ import numpy as np
import json
from fsl.data.image import Image
from fsl_mrs.core import MRS, MRSI
import fsl.utils.path as fslpath
from nibabel.nifti2 import Nifti2Header
from nibabel.nifti1 import Nifti1Extension
from fsl_mrs.utils.misc import checkCFUnits
class NIFTIMRS_DimDoesntExist(Exception):
pass
class NotNIFTI_MRS(Exception):
pass
def is_nifti_mrs(file_path):
'''Check that a file is of the NIFTI-MRS format type.'''
try:
NIFTI_MRS(file_path)
return True
except fslpath.PathError:
raise NotNIFTI_MRS("File isn't NIFTI-MRS, wrong extension type.")
from nibabel.nifti2 import Nifti2Header
from fsl_mrs.utils.misc import checkCFUnits
def gen_new_nifti_mrs(data, dwelltime, spec_freq, nucleus='1H', affine=None, dim_tags=[None, None, None]):
......@@ -46,7 +29,7 @@ def gen_new_nifti_mrs(data, dwelltime, spec_freq, nucleus='1H', affine=None, dim
:return: NIFTI_MRS object
'''
if not np.iscomplex(data).all():
if not np.iscomplexobj(data):
raise ValueError('data must be complex')
if data.ndim < 4 or data.ndim > 7:
raise ValueError(f'data must between 4 and 7 dimensions, currently has {data.ndim}')
......@@ -74,6 +57,14 @@ def gen_new_nifti_mrs(data, dwelltime, spec_freq, nucleus='1H', affine=None, dim
return NIFTI_MRS(data, header=header)
class NIFTIMRS_DimDoesntExist(Exception):
pass
class NotNIFTI_MRS(Exception):
pass
class NIFTI_MRS(Image):
"""Load NIFTI MRS format data. Derived from nibabel's Nifti2Image."""
def __init__(self, *args, **kwargs):
......@@ -304,7 +295,15 @@ class NIFTI_MRS(Image):
basis_hdr = basis_hdr[0]
if ref_data is not None:
ref_data = ref_data.data
if isinstance(ref_data, str):
ref_data = NIFTI_MRS(ref_data).data
elif isinstance(ref_data, NIFTI_MRS):
ref_data = ref_data.data
elif isinstance(ref_data, np.ndarray):
pass
else:
raise TypeError('ref_data must be a path to a NIFTI-MRS file,'
'a NIFTI_MRS object, or a numpy array.')
for data, _ in self.iterate_over_dims(dim=dim):
if np.prod(data.shape[:3]) > 1:
......@@ -332,6 +331,9 @@ class NIFTI_MRS(Image):
basis_hdr=basis_hdr,
H2O=ref_data)
else:
if ref_data is not None:
ref_data = ref_data.squeeze()
# Generate MRS objects
if data.ndim > 4:
out = []
......
'''utility.py -Module containing utility functions
for creating MRS, MRSI and NIFTI_MRS objects
Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
Will Clarke <william.clarke@ndcn.ox.ac.uk>
Copyright (C) 2021 University of Oxford
# SHBASECOPYRIGHT
'''
import nibabel as nib
import numpy as np
import fsl.utils.path as fslpath
from fsl_mrs.utils import mrs_io
from fsl_mrs.core.nifti_mrs import NIFTI_MRS, NotNIFTI_MRS
def mrs_from_files(FID_file, Basis_file, H2O_file=None):
'''Construct an MRS object from FID, basis, and
(optionally) a reference file
:param FID_file: path to data file
:param Basis_file: path to basis file
:param H2O_file: Optional path to reference file
:return mrs: MRS object
'''
FID = mrs_io.read_FID(FID_file)
basis, names, Bheader = mrs_io.read_basis(Basis_file)
if H2O_file is not None:
H2O = mrs_io.read_FID(H2O_file).data
else:
H2O = None
return FID.mrs(basis=basis, names=names, basis_hdr=Bheader[0], ref_data=H2O)
def mrsi_from_files(data_file,
mask_file=None,
basis_file=None,
H2O_file=None,
csf_file=None,
gm_file=None,
wm_file=None):
'''Construct an MRS object from data, and
(optionally) basis, mask, reference and segmentation files
:param FID_file: path to data file
:param mask_file: Optional path to basis file
:param basis_file: Optional path to reference file
:param H2O_file: Optional path to reference file
:param csf_file: Optional path to reference file
:param gm_file: Optional path to reference file
:param wm_file: Optional path to reference file
:return mrs: MRSI object
'''
data = mrs_io.read_FID(data_file)
if mask_file is not None:
nib_img = nib.load(mask_file)
mask = np.asanyarray(nib_img.dataobj)
else:
mask = None
if basis_file is not None:
basis, names, basisHdr = mrs_io.read_basis(basis_file)
else:
basis, names, basisHdr = None, None, [None, ]
if H2O_file is not None:
data_w = mrs_io.read_FID(H2O_file)
else:
data_w = None
out = data.mrs(basis=basis, names=names, basis_hdr=basisHdr[0], ref_data=data_w)
out.set_mask(mask)
def loadNii(f):
nii = np.asanyarray(nib.load(f).dataobj)
if nii.ndim == 2:
nii = np.expand_dims(nii, 2)
return nii
if (csf_file is not None) and (gm_file is not None) and (wm_file is not None):
csf = loadNii(csf_file)
gm = loadNii(gm_file)
wm = loadNii(wm_file)
out.set_tissue_seg(csf, wm, gm)
return out
def is_nifti_mrs(file_path):
'''Check that a file is of the NIFTI-MRS format type.'''
try:
NIFTI_MRS(file_path)
return True
except fslpath.PathError:
raise NotNIFTI_MRS("File isn't NIFTI-MRS, wrong extension type.")
......@@ -101,10 +101,6 @@ def main():
' Defaults to tCr (Cr+PCr).')
optional.add_argument('--h2o_scale', type=float, default=1.0,
help='Additional scaling modifier for external water referencing.')
optional.add_argument('--central_frequency', default=None, type=float,
help='central frequency in Hz')
optional.add_argument('--dwell_time', default=None, type=float,
help='dwell time in seconds')
optional.add_argument('--report', action="store_true",
help='output html report')
optional.add_argument('--verbose', action="store_true",
......@@ -143,7 +139,6 @@ def main():
import os
import shutil
import warnings
from fsl_mrs.core import MRS
from fsl_mrs.utils import mrs_io
from fsl_mrs.utils import report
from fsl_mrs.utils import fitting
......@@ -185,49 +180,23 @@ def main():
print(f' {args.data}')
print(f' {args.basis}\n')
FID, dataheader = mrs_io.read_FID(args.data)
FID = mrs_io.read_FID(args.data)
basis, names, basisheader = mrs_io.read_basis(args.basis)
if args.h2o is not None:
H2O, _ = mrs_io.read_FID(args.h2o)
H2O = mrs_io.read_FID(args.h2o)
else:
H2O = None
# Collect useful info
if args.central_frequency is not None:
cf = args.central_frequency
elif dataheader['centralFrequency'] is not None:
cf = dataheader['centralFrequency']
if args.verbose:
print(f' Detected central frequency'
f' in header info cf = {cf:0.6f} MHz')
else:
raise(Exception('Cannot determine central frequency.'
'Please either set it or include it in data header'))
if args.dwell_time is not None:
bw = 1 / args.dwell_time
elif dataheader['bandwidth'] is not None:
bw = dataheader['bandwidth']
if args.verbose:
print(f' Detected bandwidth in header info bw = {bw:0.1f} Hz')
else:
raise(Exception('Cannot determine bandwidth.'
'Please either set it or include it in data header'))
# Fix case where basis file contains no header info (e.g. .RAW)
if basisheader is None:
basisheader = {'bandwidth': bw,
'dwelltime': 1 / bw,
'centralFrequency': cf}
else:
basisheader = basisheader[0]
# Instantiate MRS object
MRSargs = {'FID': FID, 'basis': basis,
'basis_hdr': basisheader, 'names': names,
'H2O': H2O, 'cf': cf, 'bw': bw}
mrs = MRS(**MRSargs)
mrs = FID.mrs(basis=basis,
names=names,
basis_hdr=basisheader[0],
ref_data=H2O)
if isinstance(mrs, list):
raise ValueError('fsl_mrs only handles a single FID at a time.'
' Please preprocess data.')
# Check the FID and basis / conjugate
if args.conjfid is not None:
......@@ -268,7 +237,7 @@ def main():
if args.verbose:
print('--->> Start fitting\n\n')
print(' Algorithm = [{}]\n'.format(args.algo))
start = time.time()
start = time.time()
ppmlim = args.ppmlim
if ppmlim is not None:
......@@ -312,15 +281,14 @@ def main():
# Echo time
if args.TE is not None:
echotime = args.TE * 1E-3
elif 'meta' in basisheader:
if 'TE' in basisheader['meta']:
echotime = basisheader['meta']['TE']
if echotime > 1.0: # Assume in ms.
echotime *= 1E-3
else:
echotime = None
elif 'TE' in dataheader:
echotime = dataheader['TE']
elif 'meta' in basisheader and 'TE' in basisheader['meta']:
echotime = basisheader['meta']['TE']
if echotime > 1.0: # Assume in ms.
echotime *= 1E-3
else:
echotime = None
elif 'EchoTime' in FID.hdr_ext:
echotime = FID.hdr_ext['TE']
else:
echotime = None
......
......@@ -62,8 +62,6 @@ def main():
help='spit out verbose info')
optional.add_argument('--conjugate', action="store_true",
help='apply conjugate to FID')
optional.add_argument('--no_conjugate', action="store_true",
help='Forbid automatic conjugation')
optional.add_argument('--overwrite', action="store_true",
help='overwrite existing output folder')
optional.add_argument('--report', action="store_true",
......
......@@ -11,7 +11,7 @@ from fsl_mrs.auxiliary import configargparse
from fsl_mrs import __version__
from fsl_mrs.utils.splash import splash
from fsl_mrs.utils import fitting, misc
from fsl_mrs.utils import fitting, misc, mrs_io
import time
# NOTE!!!! THERE ARE MORE IMPORTS IN THE CODE BELOW (AFTER ARGPARSING)
......@@ -124,8 +124,8 @@ def main():
import shutil
import warnings
import numpy as np
from fsl_mrs import core as fslmrs
from fsl_mrs.utils import report
from fsl_mrs.core import NIFTI_MRS
import datetime
import nibabel as nib
from functools import partial
......@@ -158,9 +158,14 @@ def main():
# ###### Do the work #######
# Read files
mrsi = fslmrs.MRSI.from_files(args.data,
basis_file=args.basis,
H2O_file=args.h2o)
mrsi_data = mrs_io.read_FID(args.data)
if args.h2o is not None:
H2O = mrs_io.read_FID(args.h2o)
else:
H2O = None
mrsi = mrsi_data.mrs(basis_file=args.basis,
ref_data=H2O)
def loadNii(f):
nii = np.asanyarray(nib.load(f).dataobj)
......@@ -207,13 +212,12 @@ def main():
# Echo time
if args.TE is not None:
echotime = args.TE * 1E-3
elif 'meta' in mrsi.basis_hdr:
if 'TE' in mrsi.basis_hdr['meta']:
echotime = mrsi.basis_hdr['meta']['TE']
if echotime > 1.0: # Assume in ms.
echotime *= 1E-3
else:
echotime = None
elif 'meta' in mrsi.basis_hdr and 'TE' in mrsi.basis_hdr['meta']:
echotime = mrsi.basis_hdr['meta']['TE']
if echotime > 1.0: # Assume in ms.
echotime *= 1E-3
else:
echotime = None
elif 'TE' in mrsi.header:
echotime = mrsi.header['TE']
else:
......@@ -306,6 +310,13 @@ def main():
if results[0][0].concScalings['molality'] is not None:
scalings.append('molality')
def save_img_output(fname, data):
if data.ndim > 3 and data.shape[3] == mrsi.FID_points:
NIFTI_MRS(data, header=mrsi_data.header).save(fname)
else:
img = nib.Nifti1Image(data, mrsi_data.voxToWorldMat)
nib.save(img, fname)
metabs = results[0][0].metabs
for scale in scalings:
cur_fldr = os.path.join(concs_folder, scale)
......@@ -314,42 +325,46 @@ def main():
metab_conc_list = [res[0].getConc(scaling=scale, metab=metab)
for res in results]
file_nm = os.path.join(cur_fldr, metab + '.nii.gz')
mrsi.write_output(metab_conc_list,
file_nm,
indicies=indicies,
cleanup=True,
dtype=float)
save_img_output(file_nm,
mrsi.list_to_matched_array(
metab_conc_list,
indicies=indicies,
cleanup=True,
dtype=float))
# Uncertainties
for metab in results[0][0].metabs:
metab_sd_list = [res[0].getUncertainties(metab=metab)
for res in results]
file_nm = os.path.join(uncer_folder, metab + '_sd.nii.gz')
mrsi.write_output(metab_sd_list,
file_nm,
indicies=indicies,
cleanup=True,
dtype=float)
save_img_output(file_nm,
mrsi.list_to_matched_array(
metab_sd_list,
indicies=indicies,
cleanup=True,
dtype=float))
# qc - SNR & FWHM
for metab in results[0][0].original_metabs:
metab_fwhm_list = [res[0].getQCParams(metab=metab)[1]
for res in results]
file_nm = os.path.join(qc_folder, metab + '_fwhm.nii.gz')
mrsi.write_output(metab_fwhm_list,
file_nm,
indicies=indicies,
cleanup=True,
dtype=float)
save_img_output(file_nm,
mrsi.list_to_matched_array(
metab_fwhm_list,
indicies=indicies,
cleanup=True,
dtype=float))
metab_snr_list = [res[0].getQCParams(metab=metab)[0]
for res in results]
file_nm = os.path.join(qc_folder, metab + '_snr.nii.gz')
mrsi.write_output(metab_snr_list,
file_nm,
indicies=indicies,
cleanup=True,
dtype=float)
save_img_output(file_nm,
mrsi.list_to_matched_array(
metab_snr_list,
indicies=indicies,
cleanup=True,
dtype=float))
# fit
# TODO: check if data has been conjugated, if so conjugate the predictions
......@@ -358,31 +373,34 @@ def main():
for res, scale in zip(results, mrs_scale):
pred_list.append(res[0].pred / scale['FID'])
file_nm = os.path.join(fit_folder, 'fit.nii.gz')
mrsi.write_output(pred_list,
file_nm,
indicies=indicies,
cleanup=False,
dtype=np.complex128)
save_img_output(file_nm,
mrsi.list_to_matched_array(
pred_list,
indicies=indicies,
cleanup=False,
dtype=np.complex64))
res_list = []
for res, scale in zip(results, mrs_scale):
res_list.append(res[0].residuals / scale['FID'])
file_nm = os.path.join(fit_folder, 'residual.nii.gz')
mrsi.write_output(res_list,
file_nm,
indicies=indicies,
cleanup=False,
dtype=np.complex128)
save_img_output(file_nm,
mrsi.list_to_matched_array(
res_list,
indicies=indicies,
cleanup=False,
dtype=np.complex64))
baseline_list = []
for res, scale in zip(results, mrs_scale):
baseline_list.append(res[0].baseline / scale['FID'])
file_nm = os.path.join(fit_folder, 'baseline.nii.gz')
mrsi.write_output(baseline_list,
file_nm,
indicies=indicies,
cleanup=False,
dtype=np.complex128)
save_img_output(file_nm,
mrsi.list_to_matched_array(
baseline_list,