Commit 77cb289d authored by William Clarke's avatar William Clarke
Browse files

MRSI tests, fixed bug with duplicate absis names, flake8 compatible in core,...

MRSI tests, fixed bug with duplicate absis names, flake8 compatible in core, flake8 options in setup.cfg.
parent 23d58004
......@@ -30,7 +30,7 @@ By default this option will add the following basis spectra (in separate metabol
Additional peaks may be added int he interactive environment by calling :code:`add_MM_peaks` with optional arguments to override the defaults.
References
==========
~~~~~~~~~~
.. [CUDA12] Cudalbu C, Mlynárik V, Gruetter R. Handling Macromolecule Signals in the Quantification of the Neurochemical Profile. Journal of Alzheimer’s Disease 2012;31:S101–S115 doi: 10.3233/JAD-2012-120100.
......
......@@ -13,8 +13,7 @@ import warnings
from fsl_mrs.utils import mrs_io as io
from fsl_mrs.utils import misc
from fsl_mrs.utils.constants import GYRO_MAG_RATIO,\
PPM_SHIFT, PPM_RANGE
from fsl_mrs.utils.constants import GYRO_MAG_RATIO, PPM_SHIFT, PPM_RANGE
import numpy as np
......@@ -79,8 +78,8 @@ class MRS(object):
self.__init__(FID=FID, H2O=H2O, **MRSArgs)
def __str__(self):
cf_MHz = self.centralFrequency/1e6
cf_T = self.centralFrequency/self.gyromagnetic_ratio/1e6
cf_MHz = self.centralFrequency / 1e6
cf_T = self.centralFrequency / self.gyromagnetic_ratio / 1e6
out = '------- MRS Object ---------\n'
out += f' FID.shape = {self.FID.shape}\n'
......@@ -116,7 +115,7 @@ class MRS(object):
# Store CF in Hz
self.centralFrequency = misc.checkCFUnits(centralFrequency)
self.bandwidth = bandwidth
self.dwellTime = 1/self.bandwidth
self.dwellTime = 1 / self.bandwidth
def set_acquisition_params_basis(self, dwelltime):
"""
......@@ -124,7 +123,7 @@ class MRS(object):
"""
# Basis has different dwelltime
self.basis_dwellTime = dwelltime
self.basis_bandwidth = 1/dwelltime
self.basis_bandwidth = 1 / dwelltime
axes = misc.calculateAxes(self.basis_bandwidth,
self.centralFrequency,
......@@ -156,11 +155,11 @@ class MRS(object):
@staticmethod
def infer_nucleus(cf):
cf_MHz = cf/1e6
cf_MHz = cf / 1e6
for key in GYRO_MAG_RATIO:
onefivet_range = GYRO_MAG_RATIO[key]*np.asarray([1.445, 1.505])
threet_range = GYRO_MAG_RATIO[key]*np.asarray([2.885, 3.005])
sevent_range = GYRO_MAG_RATIO[key]*np.asarray([6.975, 7.005])
onefivet_range = GYRO_MAG_RATIO[key] * np.asarray([1.445, 1.505])
threet_range = GYRO_MAG_RATIO[key] * np.asarray([2.885, 3.005])
sevent_range = GYRO_MAG_RATIO[key] * np.asarray([6.975, 7.005])
if (cf_MHz > onefivet_range[0] and cf_MHz < onefivet_range[1]) or \
(cf_MHz > threet_range[0] and cf_MHz < threet_range[1]) or \
(cf_MHz > sevent_range[0] and cf_MHz < sevent_range[1]):
......@@ -237,9 +236,9 @@ class MRS(object):
if ppmlim is not None:
def ppm2range(x, shift):
if shift:
return np.argmin(np.abs(self.ppmAxisShift-x))
return np.argmin(np.abs(self.ppmAxisShift - x))
else:
return np.argmin(np.abs(self.ppmAxis-x))
return np.argmin(np.abs(self.ppmAxis - x))
first = ppm2range(ppmlim[0], shift)
last = ppm2range(ppmlim[1], shift)
......@@ -261,7 +260,7 @@ class MRS(object):
self.dwellTime,
self.numPoints)
self.basis_dwellTime = self.dwellTime
self.basis_bandwidth = 1/self.dwellTime
self.basis_bandwidth = 1 / self.dwellTime
self.numBasisPoints = self.numPoints
def processForFitting(self, ppmlim=(.2, 4.2), ind_scaling=None):
......@@ -293,14 +292,14 @@ class MRS(object):
mask = np.zeros_like(self.names, dtype=bool)
mask[index] = True
self.basis[:, ~mask], scaling_basis = misc.rescale_FID(
self.basis[:, ~mask],
scale=scale)
self.basis[:, ~mask],
scale=scale)
scaling_basis = [scaling_basis]
# Then loop over basis spec to independently scale
for idx in index:
self.basis[:, idx], tmp = misc.rescale_FID(
self.basis[:, idx],
scale=scale)
self.basis[:, idx],
scale=scale)
scaling_basis.append(tmp)
else:
scaling_basis = None
......@@ -370,7 +369,7 @@ class MRS(object):
else:
conjOrNot.append(0.0)
if (sum(conjOrNot)/len(conjOrNot)) > 0.5:
if (sum(conjOrNot) / len(conjOrNot)) > 0.5:
if repair is False:
warnings.warn('YOU MAY NEED TO CONJUGATE YOUR BASIS!!!')
return -1
......@@ -492,7 +491,7 @@ class MRS(object):
basisFIDs, names = getMMBasis(self, lw=lw, shift=True)
for basis, n in zip(basisFIDs, names):
self.basis = np.append(self.basis, basis[:, np.newaxis], axis=1)
self.names.append('MM_'+n)
self.names.append('MM_' + n)
self.numBasis += 1
def set_FID(self, FID):
......@@ -524,6 +523,15 @@ class MRS(object):
def set_basis(self, basis, names, basis_hdr):
''' Set basis in MRS class object '''
if basis is not None:
# Check for duplicate names
for name in names:
dupes = [idx for idx, n in enumerate(names) if n == name]
if len(dupes) > 1:
for idx, ddx in enumerate(dupes[1:]):
names[ddx] = names[ddx] + f'_{idx+1}'
print(f'Found duplicate basis name "{name}", renaming to "{names[ddx]}".')
self.basis = basis.copy()
# Handle single basis spectra
if self.basis.ndim == 1:
......@@ -538,7 +546,7 @@ class MRS(object):
if (names is not None) and (basis_hdr is not None):
self.names = names.copy()
self.set_acquisition_params_basis(1/basis_hdr['bandwidth'])
self.set_acquisition_params_basis(1 / basis_hdr['bandwidth'])
else:
raise ValueError('Pass basis names and header with basis.')
......
......@@ -5,7 +5,7 @@
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
# Will Clarke <william.clarke@ndcn.ox.ac.uk>
#
# Copyright (C) 2020 University of Oxford
# Copyright (C) 2020 University of Oxford
# SHBASECOPYRIGHT
import numpy as np
......@@ -13,42 +13,49 @@ from fsl_mrs.core import MRS
from fsl_mrs.utils import mrs_io, misc
import matplotlib.pyplot as plt
import nibabel as nib
from fsl_mrs.utils.mrs_io.fsl_io import saveNIFTI
from fsl_mrs.utils.mrs_io.fsl_io import saveNIFTI, readNIFTI
class MRSI(object):
def __init__(self,FID,header,mask=None,basis=None,names=None,basis_hdr=None,H2O=None):
def __init__(self, FID, header,
mask=None, basis=None, names=None,
basis_hdr=None, H2O=None):
# process mask
if mask is None:
mask = np.full(FID.shape,True)
elif mask.shape[0:3]==FID.shape[0:3]:
mask = mask!=0.0
mask = np.full(FID.shape, True)
elif mask.shape[0:3] == FID.shape[0:3]:
mask = mask != 0.0
else:
raise ValueError(f'Mask must be None or numpy array of the same shape as FID. Mask {mask.shape[0:3]}, FID {FID.shape[0:3]}.')
raise ValueError(f'Mask must be None or numpy'
f' array of the same shape'
f' as FID. Mask {mask.shape[0:3]},'
f' FID {FID.shape[0:3]}.')
# process H2O
if H2O is None:
H2O = np.full(FID.shape,None)
elif H2O.shape!=FID.shape:
raise ValueError('H2O must be None or numpy array of the same shape as FID.')
H2O = np.full(FID.shape, None)
elif H2O.shape != FID.shape:
raise ValueError('H2O must be None or numpy array '
'of the same shape as FID.')
# Load into properties
self.data = FID
self.H2O = H2O
self.mask = mask
self.data = FID
self.H2O = H2O
self.mask = mask
self.header = header
# Basis
self.basis = basis
self.names = names
self.basis_hdr = basis_hdr
# Basis
self.basis = basis
self.names = names
self.basis_hdr = basis_hdr
# tissue segmentation
self.csf = None
self.wm = None
self.gm = None
self.tissue_seg_loaded = False
self.csf = None
self.wm = None
self.gm = None
self.tissue_seg_loaded = False
# Helpful properties
self.spatial_shape = self.data.shape[:3]
......@@ -69,31 +76,33 @@ class MRSI(object):
self.ind_scaling = None
self._store_scalings = None
def __iter__(self):
shape = self.data.shape
self._store_scalings = []
for idx in np.ndindex(shape[:3]):
if self.mask[idx]:
mrs_out = MRS(FID=self.data[idx],
header=self.header,
basis=self.basis,
names=self.names,
basis_hdr=self.basis_hdr,
H2O=self.H2O[idx])
header=self.header,
basis=self.basis,
names=self.names,
basis_hdr=self.basis_hdr,
H2O=self.H2O[idx])
self._process_mrs(mrs_out)
self._store_scalings.append(mrs_out.scaling)
if self.tissue_seg_loaded:
tissue_seg = {'CSF':self.csf[idx],'WM':self.wm[idx],'GM':self.gm[idx]}
tissue_seg = {'CSF': self.csf[idx],
'WM': self.wm[idx],
'GM': self.gm[idx]}
else:
tissue_seg = None
yield mrs_out,idx,tissue_seg
def get_indicies_in_order(self,mask=True):
"""Return a list of iteration indicies in order"""
yield mrs_out, idx, tissue_seg
def get_indicies_in_order(self, mask=True):
"""Return a list of iteration indicies in order"""
out = []
shape = self.data.shape
for idx in np.ndindex(shape[:3]):
......@@ -104,30 +113,34 @@ class MRSI(object):
out.append(idx)
return out
def get_scalings_in_order(self,mask=True):
"""Return a list of MRS object scalings in order"""
def get_scalings_in_order(self, mask=True):
"""Return a list of MRS object scalings in order"""
if self._store_scalings is None:
raise ValueError('Fetch mrs by iterable first.')
else:
return self._store_scalings
def mrs_by_index(self,index):
mrs_out = MRS(FID=self.data[index[0],index[1],index[2],:],
header=self.header,
basis=self.basis,
names=self.names,
basis_hdr=self.basis_hdr,
H2O=self.H2O[index[0],index[1],index[2],:])
def mrs_by_index(self, index):
''' Return MRS object by index (tuple - x,y,z).'''
mrs_out = MRS(FID=self.data[index[0], index[1], index[2], :],
header=self.header,
basis=self.basis,
names=self.names,
basis_hdr=self.basis_hdr,
H2O=self.H2O[index[0], index[1], index[2], :])
self._process_mrs(mrs_out)
return mrs_out
def mrs_from_average(self):
FID = misc.volume_to_list(self.data,self.mask)
H2O = misc.volume_to_list(self.H2O,self.mask)
FID = sum(FID)/len(FID)
H2O = sum(H2O)/len(H2O)
'''
Return average of all masked voxels
as a single MRS object.
'''
FID = misc.volume_to_list(self.data, self.mask)
H2O = misc.volume_to_list(self.H2O, self.mask)
FID = sum(FID) / len(FID)
H2O = sum(H2O) / len(H2O)
mrs_out = MRS(FID=FID,
header=self.header,
basis=self.basis,
......@@ -137,15 +150,20 @@ class MRSI(object):
self._process_mrs(mrs_out)
return mrs_out
def seg_by_index(self,index):
def seg_by_index(self, index):
'''Return segmentation information by index.'''
if self.tissue_seg_loaded:
return {'CSF':self.csf[index],'WM':self.wm[index],'GM':self.gm[index]}
return {'CSF': self.csf[index],
'WM': self.wm[index],
'GM': self.gm[index]}
else:
raise ValueError('Load tissue segmentation first.')
def _process_mrs(self,mrs):
def _process_mrs(self, mrs):
''' Process (conjugate, rescale)
basis and FID and apply basis operations
to all voxels.
'''
if self.basis is not None:
if self.conj_basis:
mrs.conj_Basis()
......@@ -153,7 +171,7 @@ class MRSI(object):
pass
else:
mrs.check_Basis(repair=True)
mrs.keep(self.keep)
mrs.ignore(self.ignore)
......@@ -165,124 +183,147 @@ class MRSI(object):
mrs.check_FID(repair=True)
if self.rescale:
mrs.rescaleForFitting(ind_scaling=self.ind_scaling)
def plot(self,mask=True,ppmlim=(0.2,4.2)):
mrs.rescaleForFitting(ind_scaling=self.ind_scaling)
def plot(self, mask=True, ppmlim=(0.2, 4.2)):
'''Plot (masked) grid of spectra.'''
if mask:
mask_indicies = np.where(self.mask)
else:
mask_indicies = np.where(np.full(self.mask.shape,True))
dim1 = np.asarray((np.min(mask_indicies[0]),np.max(mask_indicies[0])))
dim2 = np.asarray((np.min(mask_indicies[1]),np.max(mask_indicies[1])))
dim3 = np.asarray((np.min(mask_indicies[2]),np.max(mask_indicies[2])))
mask_indicies = np.where(np.full(self.mask.shape, True))
dim1 = np.asarray((np.min(mask_indicies[0]), np.max(mask_indicies[0])))
dim2 = np.asarray((np.min(mask_indicies[1]), np.max(mask_indicies[1])))
dim3 = np.asarray((np.min(mask_indicies[2]), np.max(mask_indicies[2])))
size1 = 1+ dim1[1]-dim1[0]
size2 = 1+ dim2[1]-dim2[0]
size3 = 1+ dim3[1]-dim3[0]
size1 = 1 + dim1[1] - dim1[0]
size2 = 1 + dim2[1] - dim2[0]
size3 = 1 + dim3[1] - dim3[0]
ar1 = size1/(size1+size2)
ar2 = size2/(size1+size2)
ar1 = size1 / (size1 + size2)
ar2 = size2 / (size1 + size2)
for sDx in range(size3):
fig,axes = plt.subplots(size1,size2,figsize=(20*ar2,20*ar1))
for i,j,k in zip(*mask_indicies):
if (not self.mask[i,j,k]) and mask:
fig, axes = plt.subplots(size1, size2, figsize=(20 * ar2, 20 * ar1))
for i, j, k in zip(*mask_indicies):
if (not self.mask[i, j, k]) and mask:
continue
ii = i - dim1[0]
jj = j - dim2[0]
ax = axes[ii,jj]
mrs = self.mrs_by_index([i,j,k])
ax.plot(mrs.getAxes(ppmlim=ppmlim),np.real(mrs.get_spec(ppmlim=ppmlim)))
ax = axes[ii, jj]
mrs = self.mrs_by_index([i, j, k])
ax.plot(mrs.getAxes(ppmlim=ppmlim), np.real(mrs.get_spec(ppmlim=ppmlim)))
ax.invert_xaxis()
ax.set_xticks([])
ax.set_yticks([])
plt.subplots_adjust(left = 0.03, # the left side of the subplots of the figure
right = 0.97, # the right side of the subplots of the figure
bottom = 0.01, # the bottom of the subplots of the figure
top = 0.95, # the top of the subplots of the figure
wspace = 0, # the amount of width reserved for space between subplots,
hspace = 0)
fig.suptitle(f'Slice {k}')
plt.show()
plt.subplots_adjust(left=0.03, # the left side of the subplots of the figure
right=0.97, # the right side of the subplots of the figure
bottom=0.01, # the bottom of the subplots of the figure
top=0.95, # the top of the subplots of the figure
wspace=0, # the amount of width reserved for space between subplots,
hspace=0)
fig.suptitle(f'Slice {sDx}')
plt.show()
def __str__(self):
return f'MRSI with shape {self.data.shape}\nNumber of voxels = {self.num_voxels}\nNumber of masked voxels = {self.num_masked_voxels}'
return f'MRSI with shape {self.data.shape}\n' \
f'Number of voxels = {self.num_voxels}\n' \
f'Number of masked voxels = {self.num_masked_voxels}'
def __repr__(self):
return str(self)
def set_mask(self,mask):
def set_mask(self, mask):
""" Load mask as numpy array."""
if mask is None:
mask = np.full(self.data.shape,True)
elif mask.shape[0:3]==self.data.shape[0:3]:
mask = mask!=0.0
mask = np.full(self.data.shape, True)
elif mask.shape[0:3] == self.data.shape[0:3]:
mask = mask != 0.0
else:
raise ValueError(f'Mask must be None or numpy array of the same shape as FID. Mask {mask.shape[0:3]}, FID {self.data.shape[0:3]}.')
raise ValueError(f'Mask must be None or numpy array of the same shape as FID.'
f' Mask {mask.shape[0:3]}, FID {self.data.shape[0:3]}.')
self.mask = mask
self.num_masked_voxels = np.sum(self.mask)
def set_tissue_seg(self,csf,wm,gm):
def set_tissue_seg(self, csf, wm, gm):
""" Load tissue segmentation as numpy arrays."""
if (csf.shape != self.spatial_shape) or (wm.shape != self.spatial_shape) or (gm.shape != self.spatial_shape):
raise ValueError(f'Tissue segmentation arrays have wrong shape (CSF:{csf.shape}, GM:{gm.shape}, WM:{wm.shape}). Must match FID ({self.spatial_shape}).')
raise ValueError(f'Tissue segmentation arrays have wrong shape '
f'(CSF:{csf.shape}, GM:{gm.shape}, WM:{wm.shape}).'
f' Must match FID ({self.spatial_shape}).')
self.csf = csf
self.wm = wm
self.gm = gm
self.tissue_seg_loaded = True
def write_output(self,data_list,file_path_name,indicies=None,cleanup=True,dtype=float):
if indicies==None:
def write_output(self, data_list, file_path_name, indicies=None, cleanup=True, dtype=float):
'''Write 3D or 4D array of data to nifti file with current orientation.'''
if indicies is None:
indicies = self.get_indicies_in_order()
nt = data_list[0].size
if nt>1:
data = np.zeros(self.spatial_shape+(nt,),dtype=dtype)
nt = data_list[0].size
if nt > 1:
data = np.zeros(self.spatial_shape + (nt,), dtype=dtype)
else:
data = np.zeros(self.spatial_shape,dtype=dtype)
data = np.zeros(self.spatial_shape, dtype=dtype)
for d,ind in zip(data_list,indicies):
for d, ind in zip(data_list, indicies):
data[ind] = d
if cleanup:
data[np.isnan(data)] = 0
data[np.isinf(data)] = 0
data[data<1e-10] = 0
data[data>1e10] = 0
data[data < 1e-10] = 0
data[data > 1e10] = 0
if nt == self.FID_points:
saveNIFTI(file_path_name, data, self.header)
else:
img = nib.Nifti1Image(data,self.header['nifti'].affine)
else:
img = nib.Nifti1Image(data, self.header['nifti'].affine)
nib.save(img, file_path_name)
@classmethod
def from_files(cls,data_file,mask_file=None,basis_file=None,H2O_file=None,csf_file=None,gm_file=None,wm_file=None):
data,hdr = mrs_io.read_FID(data_file)
def from_files(cls, data_file,
mask_file=None,
basis_file=None,
H2O_file=None,
csf_file=None,
gm_file=None,
wm_file=None):
""" Load MRSI data directly from files """
data, hdr = mrs_io.read_FID(data_file)
if mask_file is not None:
mask,_ = mrs_io.fsl_io.readNIFTI(mask_file)
mask, _ = readNIFTI(mask_file)
else:
mask = None
if basis_file is not None:
basis,names,basisHdr = mrs_io.read_basis(basis_file)
basis, names, basisHdr = mrs_io.read_basis(basis_file)
else:
basis,names,basisHdr = None,None,[None,]
basis, names, basisHdr = None, None, [None, ]
if H2O_file is not None:
data_w,hdr_w = mrs_io.read_FID(H2O_file)
data_w, hdr_w = mrs_io.read_FID(H2O_file)
else:
data_w = None
out = cls(data,hdr,mask=mask,basis=basis,names=names,basis_hdr=basisHdr[0],H2O=data_w)
out = cls(data, hdr,
mask=mask,
basis=basis,
names=names,
basis_hdr=basisHdr[0],
H2O=data_w)
def loadNii(f):
nii = np.asanyarray(nib.load(f).dataobj)
if nii.ndim == 2:
nii = np.expand_dims(nii, 2)
return nii
if (csf_file is not None) and (gm_file is not None) and (wm_file is not None):
csf,_ = mrs_io.fsl_io.readNIFTI(csf_file)
gm,_ = mrs_io.fsl_io.readNIFTI(gm_file)
wm,_ = mrs_io.fsl_io.readNIFTI(wm_file)
out.set_tissue_seg(csf,wm,gm)
csf = loadNii(csf_file)
gm = loadNii(gm_file)
wm = loadNii(wm_file)
out.set_tissue_seg(csf, wm, gm)
return out
from fsl_mrs.core.MRS import MRS
from fsl_mrs.core.MRSI import MRSI
\ No newline at end of file
from fsl_mrs.core.MRSI import MRSI
......@@ -37,13 +37,13 @@ def synth_data():
timeAxis = np.linspace(hdr['dwelltime'],
hdr['dwelltime'] * 2048,
2048)
frequencyAxis = np.linspace(-hdr['bandwidth']/2,
hdr['bandwidth']/2,
frequencyAxis = np.linspace(-hdr['bandwidth'] / 2,
hdr['bandwidth'] / 2,
2048)
ppmAxis = hz2ppm(hdr['centralFrequency']*1E6,
ppmAxis = hz2ppm(hdr['centralFrequency'] * 1E6,
frequencyAxis,
shift=False)
ppmAxisShift = hz2ppm(hdr['centralFrequency']*1E6,
ppmAxisShift = hz2ppm(hdr['centralFrequency'] * 1E6,