Commit ceb31132 authored by William Clarke's avatar William Clarke
Browse files

MRSI update. Need to check new core modules.

parent 54de824d
......@@ -90,6 +90,7 @@ class MRS(object):
# Other properties
self.metab_groups = None
self.scaling = {'FID':1.0,'basis':1.0}
def from_files(self,FID_file,Basis_file):
......@@ -261,7 +262,12 @@ class MRS(object):
self.H2O *= scaling
if self.basis is not None:
self.basis,_ = misc.rescale_FID(self.basis,scale=scale)
self.basis,scaling_basis = misc.rescale_FID(self.basis,scale=scale)
else:
scaling_basis = None
self.scaling = {'FID':scaling,'basis':scaling_basis}
def check_FID(self,ppmlim=(.2,4.2),repair=False):
"""
......@@ -356,9 +362,10 @@ class MRS(object):
if metabs is not None:
for m in metabs:
idx = self.names.index(m)
self.names.pop(idx)
self.basis = np.delete(self.basis,idx,axis=1)
names = np.asarray(self.names)
index = names ==m
self.names = names[~index].tolist()
self.basis = self.basis[:,~index]
self.numBasis = len(self.names)
def keep(self,metabs):
......@@ -376,33 +383,41 @@ class MRS(object):
self.ignore(metabs)
def add_peak(self,ppm,name,gamma=0,sigma=0):
def add_peak(self,ppm,amp,name,gamma=0,sigma=0):
"""
Add peak to basis
"""
peak = misc.create_peak(self,ppm,gamma,sigma)[:,None]
peak = misc.create_peak(self,ppm,amp,gamma,sigma)[:,None]
self.basis = np.append(self.basis,peak,axis=1)
self.names.append(name)
self.numBasis += 1
def add_MM_peaks(self,ppmlist=None,gamma=0,sigma=0):
def add_MM_peaks(self,ppmlist=None,amplist=None,gamma=0,sigma=0):
"""
Add macromolecule list
Parameters
----------
ppmlist : default is [1.7,1.4,1.2,2.0,0.9]
ppmlist : default is [0.9,1.7,1.4,1.2,2.0]
gamma,sigma : float parameters of Voigt blurring
"""
if ppmlist is None:
ppmlist = [1.7,1.4,1.2,2.0,0.9]
names = ['MM'+'{:.0f}'.format(i*10).zfill(2) for i in ppmlist]
ppmlist = [0.9,1.2,1.4,1.7,[2.08,2.25,1.95,3.0]]
amplist = [3.0,2.0,2.0,2.0,[1.33,0.33,0.33,0.4]]
for idx,_ in enumerate(ppmlist):
if isinstance(ppmlist[idx],(float,int)):
ppmlist[idx] = [float(ppmlist[idx]),]
if isinstance(amplist[idx],(float,int)):
amplist[idx] = [float(amplist[idx]),]
names = [f'MM{i[0]*10:02.0f}' for i in ppmlist]
for name,ppm in zip(names,ppmlist):
self.add_peak(ppm,name,gamma,sigma)
for name,ppm,amp in zip(names,ppmlist,amplist):
self.add_peak(ppm,amp,name,gamma,sigma)
return len(ppmlist)
......
......@@ -12,6 +12,9 @@ import numpy as np
from fsl_mrs.core import MRS
from fsl_mrs.utils import mrs_io,plotting,fitting
import matplotlib.pyplot as plt
import nibabel as nib
import os.path as op
from fsl_mrs.utils.mrs_io.fsl_io import saveNIFTI
class MRSI(object):
......@@ -42,6 +45,12 @@ class MRSI(object):
self.names = names
self.basis_hdr = basis_hdr
# tissue segmentation
self.csf = None
self.wm = None
self.gm = None
self.tissue_seg_loaded = False
# Helpful properties
self.spatial_shape = self.data.shape[:3]
self.FID_points = self.data.shape[3]
......@@ -49,9 +58,21 @@ class MRSI(object):
self.num_masked_voxels = np.sum(self.mask)
if self.names is not None:
self.num_basis = len(names)
# MRS output options
self.conj_basis = False
self.no_conj_basis = False
self.conj_FID = False
self.no_conj_FID = False
self.rescale = False
self.keep = None
self.ignore = None
self._store_scalings = None
def __iter__(self):
shape = self.data.shape
self._store_scalings = []
for idx in np.ndindex(shape[:3]):
if self.mask[idx]:
mrs_out = MRS(FID=self.data[idx],
......@@ -60,20 +81,74 @@ class MRSI(object):
names=self.names,
basis_hdr=self.basis_hdr,
H2O=self.H2O[idx])
mrs_out.check_FID(repair=True)
mrs_out.check_Basis(repair=True)
yield mrs_out,idx
self._process_mrs(mrs_out)
self._store_scalings.append(mrs_out.scaling)
if self.tissue_seg_loaded:
tissue_seg = [self.csf[idx],self.wm[idx],self.gm[idx]]
else:
tissue_seg = None
yield mrs_out,idx,tissue_seg
def mrsByIndex(self,index):
def get_indicies_in_order(self,mask=True):
"""Return a list of iteration indicies in order"""
out = []
shape = self.data.shape
for idx in np.ndindex(shape[:3]):
if mask:
if self.mask[idx]:
out.append(idx)
else:
out.append(idx)
return out
def get_scalings_in_order(self,mask=True):
"""Return a list of MRS object scalings in order"""
if self._store_scalings is None:
raise ValueError('Fetch mrs by iterable first.')
else:
return self._store_scalings
def mrs_by_index(self,index):
mrs_out = MRS(FID=self.data[index[0],index[1],index[2],:],
header=self.header,
basis=self.basis,
names=self.names,
basis_hdr=self.basis_hdr,
H2O=self.H2O[index[0],index[1],index[2],:])
mrs_out.check_FID(repair=True)
mrs_out.check_Basis(repair=True)
self._process_mrs(mrs_out)
return mrs_out
def seg_by_index(self,index):
if self.tissue_seg_loaded:
return [self.csf[index],self.wm[index],self.gm[index]]
else:
raise ValueError('Load tissue segmentation first.')
def _process_mrs(self,mrs):
if self.basis is not None:
if self.conj_basis:
mrs.conj_Basis()
elif self.no_conj_basis:
pass
else:
mrs.check_Basis(repair=True)
mrs.keep(self.keep)
mrs.ignore(self.ignore)
if self.conj_FID:
mrs.conj_FID()
elif self.no_conj_FID:
pass
else:
mrs.check_FID(repair=True)
if self.rescale:
mrs.rescaleForFitting()
def plot(self,mask=True,ppmlim=(0.2,4.2)):
if mask:
......@@ -99,7 +174,7 @@ class MRSI(object):
ii = i - dim1[0]
jj = j - dim2[0]
ax = axes[ii,jj]
mrs = self.mrsByIndex([i,j,k])
mrs = self.mrs_by_index([i,j,k])
ax.plot(mrs.getAxes(ppmlim=ppmlim),np.real(mrs.getSpectrum(ppmlim=ppmlim)))
ax.invert_xaxis()
ax.set_xticks([])
......@@ -110,9 +185,87 @@ class MRSI(object):
top = 0.95, # the top of the subplots of the figure
wspace = 0, # the amount of width reserved for space between subplots,
hspace = 0)
fig.suptitle(f'Slice {k}')
fig.suptitle(f'Slice {k}')
plt.show()
def __str__(self):
return f'MRSI with shape {self.data.shape}\nNumber of voxels = {self.num_voxels}\nNumber of masked voxels = {self.num_masked_voxels}'
def __repr__(self):
return str(self)
\ No newline at end of file
return str(self)
def set_mask(self,mask):
""" Load mask as numpy array."""
if mask is None:
mask = np.full(self.data.shape,True)
elif mask.shape[0:3]==self.data.shape[0:3]:
mask = mask!=0.0
else:
raise ValueError(f'Mask must be None or numpy array of the same shape as FID. Mask {mask.shape[0:3]}, FID {self.data.shape[0:3]}.')
self.mask = mask
self.num_masked_voxels = np.sum(self.mask)
def set_tissue_seg(self,csf,wm,gm):
""" Load tissue segmentation as numpy arrays."""
if (csf.shape != self.spatial_shape) or (wm.shape != self.spatial_shape) or (gm.shape != self.spatial_shape):
raise ValueError(f'Tissue segmentation arrays have wrong shape (CSF:{csf.shape}, GM:{gm.shape}, WM:{wm.shape}). Must match FID ({self.spatial_shape}).')
self.csf = csf
self.wm = wm
self.gm = gm
self.tissue_seg_loaded = True
def write_output(self,data_list,file_path_name,indicies=None,cleanup=True,dtype=float):
if indicies==None:
indicies = self.get_indicies_in_order()
nt = data_list[0].size
if nt>1:
data = np.zeros(self.spatial_shape+(nt,),dtype=dtype)
else:
data = np.zeros(self.spatial_shape,dtype=dtype)
for d,ind in zip(data_list,indicies):
data[ind] = d
if cleanup:
data[np.isnan(data)] = 0
data[np.isinf(data)] = 0
data[data<1e-10] = 0
data[data>1e10] = 0
if nt == self.FID_points:
saveNIFTI(file_path_name, data, self.header)
else:
img = nib.Nifti1Image(data,self.header['nifti'].affine)
nib.save(img, file_path_name)
@classmethod
def from_files(cls,data_file,mask_file=None,basis_file=None,H2O_file=None,csf_file=None,gm_file=None,wm_file=None):
data,hdr = mrs_io.read_FID(data_file)
if mask_file is not None:
mask,_ = mrs_io.fsl_io.readNIFTI(mask_file)
else:
mask = None
if basis_file is not None:
basis,names,basisHdr = mrs_io.read_basis(basis_file)
else:
basis,names,basisHdr = None,None,[None,]
if H2O_file is not None:
data_w,hdr_w = mrs_io.read_FID(H2O_file)
else:
data_w = None
out = cls(data,hdr,mask=mask,basis=basis,names=names,basis_hdr=basisHdr[0],H2O=data_w)
if (csf_file is not None) and (gm_file is not None) and (wm_file is not None):
csf,_ = mrs_io.fsl_io.readNIFTI(csf_file)
gm,_ = mrs_io.fsl_io.readNIFTI(gm_file)
wm,_ = mrs_io.fsl_io.readNIFTI(wm_file)
out.set_tissue_seg(csf,wm,gm)
return out
\ No newline at end of file
......@@ -77,7 +77,7 @@ def main():
help='structural image (for report)')
optional.add_argument('--TE',type=float,default=None,metavar='TE',
help='Echo time for relaxation correction (ms)')
optional.add_argument('--tissue_frac',type=tissue_frac_arg,action=TissueFracAction,nargs='+',default=None,metavar='GM WM CSF OR json',
optional.add_argument('--tissue_frac',type=tissue_frac_arg,action=TissueFracAction,nargs='+',default=None,metavar='WM GM CSF OR json',
help='Fractional tissue volumes for WM, GM, CSF or json segmentation file. Defaults to pure water scaling.')
optional.add_argument('--internal_ref',type=str,default=['Cr','PCr'],nargs='+',
help='Metabolite(s) used as an internal reference. Defaults to tCr (Cr+PCr).')
......
This diff is collapsed.
......@@ -17,16 +17,18 @@ def main():
help='NIfTI file or directory of basis sets')
p.add_argument('--ppmlim',default=(.2,4.2),type=float,nargs=2,metavar=('LOW','HIGH'),
help='limit the fit to a freq range (default=(.2,4.2))')
p.add_argument('--mask',default=None,type=str,help='Mask for MRSI')
args = p.parse_args()
from fsl_mrs.utils.plotting import plot_spectrum,FID2Spec,plot_spectra
from fsl_mrs.utils.mrs_io import read_FID,read_basis
import matplotlib.pyplot as plt
import os.path as op
from fsl_mrs.core import MRS
from fsl_mrs.core import MRS,MRSI
import numpy as np
from fsl_mrs.utils.preproc.combine import combine_FIDs
import nibabel as nib
# breakpoint()
if op.isdir(args.file) or op.splitext(op.basename(args.file))[1].upper()=='.BASIS':
......@@ -44,23 +46,36 @@ def main():
plt.legend()
plt.show()
elif op.isfile(args.file):
fid,header = read_FID(args.file)
if fid.ndim==1:
mrs = MRS(FID=fid,header=header)
mrs.check_FID(repair=True)
fig = plot_spectrum(mrs,ppmlim=args.ppmlim)
plt.show()
fid,header = read_FID(args.file,squeezeSVS=False)
if np.prod(fid.shape[:3])==1:
# SVS
fid = np.squeeze(fid)
if fid.ndim==1:
mrs = MRS(FID=fid,header=header)
mrs.check_FID(repair=True)
fig = plot_spectrum(mrs,ppmlim=args.ppmlim)
plt.show()
else:
mrs_list = []
for f in fid.T:
tmpmrs = MRS(FID=f,header=header)
tmpmrs.check_FID(repair=True)
mrs_list.append(tmpmrs)
fidList = [m.FID for m in mrs_list]
combfid = combine_FIDs(fidList,'svd',do_prewhiten=True)
single_mrs = MRS(FID=combfid,header=header)
fig = plot_spectra(mrs_list,ppmlim=args.ppmlim,single_FID=single_mrs)
plt.show()
else:
mrs_list = []
for f in fid.T:
tmpmrs = MRS(FID=f,header=header)
tmpmrs.check_FID(repair=True)
mrs_list.append(tmpmrs)
fidList = [m.FID for m in mrs_list]
combfid = combine_FIDs(fidList,'svd',do_prewhiten=True)
single_mrs = MRS(FID=combfid,header=header)
fig = plot_spectra(mrs_list,ppmlim=args.ppmlim,single_FID=single_mrs)
plt.show()
if args.mask is not None:
mask_hdr = nib.load(args.mask)
mask = np.asanyarray(mask_hdr.dataobj)
if mask.ndim == 2:
mask = np.expand_dims(mask,2)
mrsi = MRSI(fid,header,mask = mask)
mrsi.plot()
if __name__ == '__main__':
main()
#!/usr/bin/env python
# mrsi_segment - use fsl to segmenta T1 and register it to an mrsi scan
#
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
# William Clarke <william.clarke@ndcn.ox.ac.uk>
#
# Copyright (C) 2020 University of Oxford
# SHBASECOPYRIGHT
# Quick imports
import argparse
import os.path as op
from os import remove
import nibabel as nib
import numpy as np
from fsl.wrappers import flirt,fsl_anat,fslroi
from subprocess import call
import json
import warnings
def main():
# Parse command-line arguments
parser = argparse.ArgumentParser(description="FSL Magnetic Resonance Spectroscopy - register fast segmentation to mrsi.")
parser.add_argument('mrsi', type=str,metavar='MRSI',
help='MRSI nifti file')
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('-t','--t1', type=str,metavar='T1',help='T1 nifti file')
group.add_argument('-a','--anat', type=str,help='FSL anat directory for tissue segmentation output.')
parser.add_argument('-o','--output', type=str,help='Output directory',default='.')
parser.add_argument('-f','--filename', type=str,help='Output file name',default='mrsi_seg')
args = parser.parse_args()
# If not prevented run fsl_anat for fast segmentation
if (args.anat is None) and (not args.mask_only):
anat = op.join(args.output,'fsl_anat')
fsl_anat(args.t1,out = anat, nosubcortseg=True)
anat += '.anat'
else:
anat = args.anat
# Make dummy nifti as nothing works with complex data
call(['fslcomplex','-realabs',args.mrsi,op.join(args.output,'tmp.nii.gz')])
call(['fslcpgeom',args.mrsi,op.join(args.output,'tmp.nii.gz')])
fslroi(op.join(args.output,'tmp.nii.gz'),op.join(args.output,'tmp.nii.gz'),0,1)
# Register the pvseg to the MRSI data using flirt
flirt_func = lambda i,o : flirt(i,op.join(args.output,'tmp.nii.gz'),out=o,
usesqform=True,
applyxfm=True,
noresampblur=True,
interp='nearestneighbour',
setbackground=0,
paddingsize=1)
# T1_fast_pve_0, T1_fast_pve_1, T1_fast_pve_2 - partial volume segmentations (CSF, GM, WM respectively)
flirt_func(op.join(anat,'T1_fast_pve_0.nii.gz'),op.join(args.output,args.filename+'_csf.nii.gz'))
flirt_func(op.join(anat,'T1_fast_pve_1.nii.gz'),op.join(args.output,args.filename+'_gm.nii.gz'))
flirt_func(op.join(anat,'T1_fast_pve_2.nii.gz'),op.join(args.output,args.filename+'_wm.nii.gz'))
remove(op.join(args.output,'tmp.nii.gz'))
if __name__ == '__main__':
main()
\ No newline at end of file
#!/usr/bin/env python
# make_mrs_mask - use fsl to make a mask from a svs voxel and T1 nifti
# svs_segment - use fsl to make a mask from a svs voxel and T1 nifti, then produce tissue segmentation file.
#
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
# William Clarke <william.clarke@ndcn.ox.ac.uk>
......@@ -20,7 +20,7 @@ import warnings
def main():
# Parse command-line arguments
parser = argparse.ArgumentParser(description="FSL Magnetic Resonance Spectroscopy - Merge HTML reports based on filename in directory.")
parser = argparse.ArgumentParser(description="FSL Magnetic Resonance Spectroscopy - Construct mask in T1 space of an SVS voxel and generate a tissue segmentation file.")
parser.add_argument('svs', type=str,metavar='SVS',
help='SVS nifti file')
......@@ -28,7 +28,7 @@ def main():
group.add_argument('-t','--t1', type=str,metavar='T1',help='T1 nifti file')
group.add_argument('-a','--anat', type=str,help='FSL anat directory for tissue segmentation output.')
parser.add_argument('-o','--output', type=str,help='Output directory',default='.')
parser.add_argument('-f','--filename', type=str,help='Output file name directory',default='mask.nii.gz')
parser.add_argument('-f','--filename', type=str,help='Output file name',default='mask.nii.gz')
parser.add_argument('-m','--mask_only', action="store_true",help='Only perform masking stage, do not run fsl_anat if only T1 passed.')
args = parser.parse_args()
......
......@@ -19,7 +19,7 @@ def test_fsl_mrs(tmp_path):
'--output', tmp_path,
'--h2o', data['water'],
'--TE', '11',
'--tissue_fractions', '0.45', '0.45', '0.1',
'--tissue_frac', '0.45', '0.45', '0.1',
'--overwrite',
'--combine', 'Cr', 'PCr',
'--report'])
......
......@@ -457,14 +457,15 @@ def rescale_FID(x,scale=100):
return y,1/factor * scale
def create_peak(mrs,ppm,gamma=0,sigma=0):
def create_peak(mrs,ppm,amp,gamma=0,sigma=0):
"""
creates FID for peak at specific ppm
Parameters
----------
mrs : MRS object (contains time information)
ppm : float
ppm : list of floats
amp : list of floats
gamma : float
Peak Lorentzian dispersion
sigma : float
......@@ -475,17 +476,27 @@ def create_peak(mrs,ppm,gamma=0,sigma=0):
array-like FID
"""
freq = ppm2hz(mrs.centralFrequency,ppm)
t = mrs.timeAxis
x = np.exp(1j*2*np.pi*freq*t).flatten()
if gamma>0 or sigma>0:
x = blur_FID_Voigt(mrs,x,gamma,sigma)
if isinstance(ppm,(float,int)):
ppm = [float(ppm),]
if isinstance(amp,(float,int)):
amp = [float(amp),]
t = mrs.timeAxis
out = np.zeros(t.shape[0],dtype=np.complex128)
for p,a in zip(ppm,amp):
freq = ppm2hz(mrs.centralFrequency,p)
x = a*np.exp(1j*2*np.pi*freq*t).flatten()
if gamma>0 or sigma>0:
x = blur_FID_Voigt(mrs,x,gamma,sigma)
# dephase
x = x*np.exp(-1j*np.angle(x[0]))
return x
# dephase
x = x*np.exp(-1j*np.angle(x[0]))
out+= x
return out
def extract_spectrum(mrs,FID,ppmlim=(0.2,4.2),shift=True):
"""
......
......@@ -459,8 +459,10 @@ def plotly_fit(mrs,res,ppmlim=None,proj='real',metabs = None,phs=(0,0)):
resid = project(resid,proj)
# y-axis range
ymin = np.min(data)-np.min(data)/10
ymax = np.max(data)-np.max(data)/30
minval = min(np.min(base),np.min(data),np.min(preds),np.min(resid))
maxval = max(np.max(base),np.max(data),np.max(preds),np.max(resid))
ymin = minval-minval/2
ymax = maxval+maxval/30
# Build the plot
......@@ -585,7 +587,7 @@ def plot_mcmc_corr(res,corr=None):
corrabs = np.abs(corr)
fig.add_trace(go.Heatmap(z=corr,
x=res.original_metabs,y=res.original_metabs,colorscale='Picnic'))
x=res.original_metabs,y=res.original_metabs,colorscale='Picnic',zmid=0))
fig.update_layout(template = 'plotly_white',
font=dict(size=10),
......
......@@ -7,6 +7,8 @@ from numpy.lib.stride_tricks import as_strided
from collections import namedtuple
import pandas as pd
SNR = namedtuple('SNR',['spectrum','peaks','residual'])
class NoiseNotFoundError(ValueError):
pass
......@@ -51,8 +53,7 @@ def calcQC(mrs,res,ppmlim=(0.2,4.2)):
snrResidual = snrResidual_height/rmse
# Assemble outputs
# SNR output
SNR = namedtuple('SNR',['spectrum','peaks','residual'])
# SNR output
snrdf = pd.DataFrame()
for m,snr in zip(res.metabs,snrPeaks):
snrdf[f'SNR_{m}'] = pd.Series(snr)
......
......@@ -431,24 +431,32 @@ class FitRes(object):