Commit f692164b authored by Saad Jbabdi's avatar Saad Jbabdi
Browse files

general updates

parent 6602fc20
# FSL-MRS # FSL-MRS
### Installation instructions
``` ### Description
git clone https://git.fmrib.ox.ac.uk/saad/fsl_mrs.git
cd fsl_mrs
pip install .
```
FSL-MRS is a collection of python modules and wrapper scripts for pre-processing and model fitting of Magnetic Resonance Spectroscopy (MRS) data.
---
### Installation
git clone https://git.fmrib.ox.ac.uk/saad/fsl_mrs.git
cd fsl_mrs
pip install .
---
### Content ### Content
#### Scripts:
- **fsl\_mrs**
: fit a single spectrum
- **fsl\_mrsi**
: fit a 4D volume of spectra
- **fsl\_mrs\_preproc**
: pre-processing (coil combination, averaging, eddy-current correction)
- **fsl\_mrs\_sim**
: simulate basis
- **mrs_vis**
: quick visualisation of the spectrum
---
#### Usage
For each of the wrapper scripts above, simply type `<name_of_script> --help` to get the usage.
#### File types
FSL-MRS accepts FID data in NIFTI format. It can also read .RAW format (like LCModel).
#### Working in python
If you don't want to use the wrapper scripts, you can use the python modules directly in your own python scripts/programs. Here are some examples below:
- Pre-processing
- Model fitting - single voxel
- Model fitting - MRSImaging
......
...@@ -36,20 +36,23 @@ class MRS(object): ...@@ -36,20 +36,23 @@ class MRS(object):
# (now copying the data - looks ugly but better than referencing. # (now copying the data - looks ugly but better than referencing.
# now I can run multiple times with different setups) # now I can run multiple times with different setups)
self.FID = FID.copy() self.FID = FID.copy()
self.basis = basis.copy() self.basis = basis
if basis is not None:
self.basis = basis.copy()
if H2O is not None: if H2O is not None:
self.H2O = H2O.copy() self.H2O = H2O.copy()
else: else:
self.H2O = None self.H2O = None
self.centralFrequency = cf # Hz self.centralFrequency = cf # Hz
self.bandwidth = bw # Hz self.bandwidth = bw # Hz
self.names = names.copy() self.names = names
if names is not None:
self.names = names.copy()
# Set remaining class attributes # Set remaining class attributes
self.Spec = None self.Spec = None
self.numPoints = None self.numPoints = None
self.Spec = None
self.numBasis = None self.numBasis = None
# Constants # Constants
...@@ -169,13 +172,13 @@ class MRS(object): ...@@ -169,13 +172,13 @@ class MRS(object):
return int(first),int(last) return int(first),int(last)
def resample_basis(self): def resample_basis(self,dwelltime):
""" """
Sometimes the basis is simulated using different timings (dwelltime) Sometimes the basis is simulated using different timings (dwelltime)
This interpolates the basis to match the FID This interpolates the basis to match the FID
""" """
# RESAMPLE BASIS FUNCTION # RESAMPLE BASIS FUNCTION
bdt = self.basis_dwellTime bdt = dwelltime
bbw = 1/bdt bbw = 1/bdt
bn = self.basis.shape[0] bn = self.basis.shape[0]
...@@ -205,10 +208,11 @@ class MRS(object): ...@@ -205,10 +208,11 @@ class MRS(object):
0 if check successful and -1 if not (also issues warning) 0 if check successful and -1 if not (also issues warning)
""" """
Spec1 = np.fft.fft(self.FID)
Spec2 = np.fft.fft(np.conj(self.FID))
first,last = self.ppmlim_to_range(ppmlim) first,last = self.ppmlim_to_range(ppmlim)
if np.linalg.norm(Spec1[first:last]) < np.linalg.norm(Spec2[first:last]): Spec1 = np.real(np.fft.fft(self.FID))[first:last]
Spec2 = np.real(np.fft.fft(np.conj(self.FID)))[first:last]
if np.linalg.norm(misc.detrend(Spec1,deg=4)) < np.linalg.norm(misc.detrend(Spec2,deg=4)):
if repare is False: if repare is False:
warnings.warn('YOU MAY NEED TO CONJUGATE YOU FID!!!') warnings.warn('YOU MAY NEED TO CONJUGATE YOU FID!!!')
return -1 return -1
...@@ -219,6 +223,9 @@ class MRS(object): ...@@ -219,6 +223,9 @@ class MRS(object):
return 0 return 0
def conj_FID(self): def conj_FID(self):
"""
Conjugate FID and recalculate spectrum
"""
self.FID = np.conj(self.FID) self.FID = np.conj(self.FID)
self.Spec = np.fft.fft(self.FID) self.Spec = np.fft.fft(self.FID)
...@@ -254,7 +261,6 @@ class MRS(object): ...@@ -254,7 +261,6 @@ class MRS(object):
""" """
if metabs is not None: if metabs is not None:
metabs = [m for m in self.names if m not in metabs] metabs = [m for m in self.names if m not in metabs]
#metabs = list(set(self.names)-set(metabs))
self.ignore(metabs) self.ignore(metabs)
...@@ -306,7 +312,7 @@ class MRS(object): ...@@ -306,7 +312,7 @@ class MRS(object):
return len(ppmlist) return len(ppmlist)
# I/O functions # I/O functions [NOW OBSOLETE?]
@staticmethod @staticmethod
def read(filename,TYPE='RAW'): def read(filename,TYPE='RAW'):
""" """
......
...@@ -58,7 +58,7 @@ def main(): ...@@ -58,7 +58,7 @@ def main():
help='combine certain metabolites [repeatable]') help='combine certain metabolites [repeatable]')
fitting_args.add_argument('--ppmlim',default=(.2,4.2),type=float,nargs=2,metavar=('LOW','HIGH'), fitting_args.add_argument('--ppmlim',default=(.2,4.2),type=float,nargs=2,metavar=('LOW','HIGH'),
help='limit the fit to a freq range (default=(.2,4.2))') help='limit the fit to a freq range (default=(.2,4.2))')
fitting_args.add_argument('--h2o',type=str,metavar='H2O', fitting_args.add_argument('--h2o',default=None,type=str,metavar='H2O',
help='input .H2O file for quantification') help='input .H2O file for quantification')
fitting_args.add_argument('--baseline_order',default=2,type=int,metavar=('ORDER'), fitting_args.add_argument('--baseline_order',default=2,type=int,metavar=('ORDER'),
help='order of baseline polynomial (default=2)') help='order of baseline polynomial (default=2)')
...@@ -68,8 +68,9 @@ def main(): ...@@ -68,8 +68,9 @@ def main():
help="include default macromolecule peaks") help="include default macromolecule peaks")
# ADDITONAL OPTIONAL ARGUMENTS # ADDITONAL OPTIONAL ARGUMENTS
optional.add_argument('--t1',type=str,default=None,metavar='IMAGE',
help='structural image (for report)')
optional.add_argument('--central_frequency',default=None,type=float, optional.add_argument('--central_frequency',default=None,type=float,
help='central frequency in Hz') help='central frequency in Hz')
optional.add_argument('--dwell_time',default=None,type=float, optional.add_argument('--dwell_time',default=None,type=float,
...@@ -78,6 +79,8 @@ def main(): ...@@ -78,6 +79,8 @@ def main():
help='output html report') help='output html report')
optional.add_argument('--verbose',action="store_true", optional.add_argument('--verbose',action="store_true",
help='spit out verbose info') help='spit out verbose info')
optional.add_argument('--phase_correct',action="store_true",
help='do phase correction')
optional.add_argument('--overwrite',action="store_true", optional.add_argument('--overwrite',action="store_true",
help='overwrite existing output folder') help='overwrite existing output folder')
optional.add('--config', required=False, is_config_file=True, help='configuration file') optional.add('--config', required=False, is_config_file=True, help='configuration file')
...@@ -106,6 +109,8 @@ def main(): ...@@ -106,6 +109,8 @@ def main():
import datetime import datetime
import plotly import plotly
# ###################################################### # ######################################################
if not args.verbose:
warnings.filterwarnings("ignore")
# Check if output folder exists # Check if output folder exists
...@@ -120,9 +125,9 @@ def main(): ...@@ -120,9 +125,9 @@ def main():
exit() exit()
else: else:
shutil.rmtree(args.output) shutil.rmtree(args.output)
os.mkdir(args.output) os.makedirs(args.output,exist_ok=True)
else: else:
os.mkdir(args.output) os.makedirs(args.output,exist_ok=True)
# Save chosen arguments # Save chosen arguments
...@@ -134,7 +139,13 @@ def main(): ...@@ -134,7 +139,13 @@ def main():
####### Do the work ####### ####### Do the work #######
# Read data/h2o/basis # Read data/h2o/basis
if args.verbose:
print('--->> Read input data and basis\n')
print(' {}'.format(args.data))
print(' {}\n'.format(args.basis))
FID,dataheader = mrs_io.read_FID(args.data) FID,dataheader = mrs_io.read_FID(args.data)
basis, names, basisheader = mrs_io.read_basis(args.basis) basis, names, basisheader = mrs_io.read_basis(args.basis)
if args.h2o is not None: if args.h2o is not None:
...@@ -142,27 +153,37 @@ def main(): ...@@ -142,27 +153,37 @@ def main():
else: else:
H2O = None H2O = None
# Squeeze the data/h2o (e.g. if from NIFTI single voxel)
FID = np.squeeze(FID)
if H2O is not None:
H2O = np.squeeze(H2O)
# Collect useful info # Collect useful info
if args.central_frequency is not None: if args.central_frequency is not None:
cf = args.central_frequency cf = args.central_frequency
elif dataheader['centralFrequency'] is not None: elif dataheader['centralFrequency'] is not None:
cf = dataheader['centralFrequency'] cf = dataheader['centralFrequency']
if args.verbose:
print(' Detected central frequency in header info cf = {} MHz'.format(cf*1E-6))
else: else:
raise(Exception('Cannot determine central frequency. Please either set it on include it in data header')) raise(Exception('Cannot determine central frequency. Please either set it or include it in data header'))
if args.dwell_time is not None: if args.dwell_time is not None:
bw = 1/args.dwell_time bw = 1/args.dwell_time
elif dataheader['bandwidth'] is not None: elif dataheader['bandwidth'] is not None:
bw = dataheader['bandwidth'] bw = dataheader['bandwidth']
if args.verbose:
print(' Detected bandwidth in header info bw = {} Hz'.format(bw))
else: else:
raise(Exception('Cannot determine central frequency. Please either set it on include it in data header')) raise(Exception('Cannot determine bandwidth. Please either set it or include it in data header'))
# Resample basis? # Resample basis?
if bw != basisheader['bandwidth']: if basisheader is not None:
dwell = 1/basisheader['bandwidth'] if bw != basisheader['bandwidth']:
new_dwell = 1/bw dwell = 1/basisheader['bandwidth']
basis = misc.resample_ts(basis,dwell,new_dwell) new_dwell = 1/bw
basis = misc.resample_ts(basis,dwell,new_dwell)
# Instantiate MRS object # Instantiate MRS object
...@@ -170,25 +191,24 @@ def main(): ...@@ -170,25 +191,24 @@ def main():
mrs = MRS(**MRSargs) mrs = MRS(**MRSargs)
# Check the FID # Check the FID
mrs.check_FID(repare=True) conjugated = mrs.check_FID(repare=True)
if args.verbose:
if conjugated == 1:
raise(Warning('Warning :: FID has been checked and conjugated. Please check!'))
# Do phase correction # Do phase correction
mrs.FID = misc.phase_correct(mrs.FID) if args.phase_correct:
mrs.Spec = np.fft.fft(mrs.FID) if args.verbose:
print('--->> Phase correction\n')
if args.verbose: mrs.FID = misc.phase_correct(mrs,mrs.FID)
print('--->> Read input data and basis\n') mrs.Spec = np.fft.fft(mrs.FID)
print(' {}\n'.format(args.data))
print(' {}\n'.format(args.basis))
# Keep/Ignore/Combine metabolites # Keep/Ignore/Combine metabolites
mrs.keep(args.keep) mrs.keep(args.keep)
mrs.ignore(args.ignore) mrs.ignore(args.ignore)
mrs.combine(args.combine) mrs.combine(args.combine)
# Do the fitting here # Do the fitting here
if args.verbose: if args.verbose:
print('--->> Start fitting\n\n') print('--->> Start fitting\n\n')
...@@ -197,12 +217,12 @@ def main(): ...@@ -197,12 +217,12 @@ def main():
# Do the fitting # Do the fitting
print(' --- Run fitting --- ') if args.verbose:
print(' --- Run fitting --- ')
ppmlim=args.ppmlim ppmlim=args.ppmlim
if ppmlim is not None: if ppmlim is not None:
ppmlim=(ppmlim[0],ppmlim[1]) ppmlim=(ppmlim[0],ppmlim[1])
#mrs.fit(model = args.model, method=args.algo,ppmlim=args.ppmlim)
# Parse metabolite groups # Parse metabolite groups
...@@ -214,16 +234,23 @@ def main(): ...@@ -214,16 +234,23 @@ def main():
# Include Macromolecules? These should have their own metab groups # Include Macromolecules? These should have their own metab groups
if args.add_MM is not None: if args.add_MM is not None:
if not args.verbose:
print('Adding macromolecules')
nMM = mrs.add_MM_peaks() nMM = mrs.add_MM_peaks()
G = [i+max(metab_groups)+1 for i in range(nMM)] G = [i+max(metab_groups)+1 for i in range(nMM)]
metab_groups += G metab_groups += G
Fitargs = {'mrs':mrs,'ppmlim':ppmlim, Fitargs = {'ppmlim':ppmlim,
'method':args.algo,'baseline_order':args.baseline_order, 'method':args.algo,'baseline_order':args.baseline_order,
'metab_groups':metab_groups} 'metab_groups':metab_groups}
res = fitting.fit_FSLModel(**Fitargs) if args.verbose:
print(mrs)
print('Fitting args:')
print(Fitargs)
res = fitting.fit_FSLModel(mrs,**Fitargs)
...@@ -239,9 +266,16 @@ def main(): ...@@ -239,9 +266,16 @@ def main():
print('--->> Saving output files to {}\n'.format(args.output)) print('--->> Saving output files to {}\n'.format(args.output))
res.to_file(mrs,filename=os.path.join(args.output,'results_table.csv')) res.to_file(filename=os.path.join(args.output,'results_table.csv'),mrs=mrs,what='concentrations')
res.to_file(filename=os.path.join(args.output,'qc.csv'),what='qc')
res.to_file(filename=os.path.join(args.output,'all_parameters.csv'),what='parameters')
# Save image of MRS voxel
if args.t1 is not None:
datatype = mrs_io.check_datatype(args.data)
if datatype == 'NIFTI':
fig = plotting.plot_world_orient(args.t1,args.data)
fig.savefig(os.path.join(args.output,'voxel_location.png'))
# Create short HTML report # Create short HTML report
#fig = plotting.plotly_fit(mrs,res,ppmlim=ppmlim,proj='abs') #fig = plotting.plotly_fit(mrs,res,ppmlim=ppmlim,proj='abs')
...@@ -250,13 +284,14 @@ def main(): ...@@ -250,13 +284,14 @@ def main():
# Creat HTML report # Creat HTML report
report.create_report(mrs,res, if args.report:
filename=os.path.join(args.output,'report.html'), report.create_report(mrs,res,
fidfile=args.data, filename=os.path.join(args.output,'report.html'),
basisfile=args.basis, fidfile=args.data,
h2ofile=args.h2o, basisfile=args.basis,
outdir=args.output, h2ofile=args.h2o,
date=datetime.datetime.now().strftime("%Y-%m-%d %H:%M")) outdir=args.output,
date=datetime.datetime.now().strftime("%Y-%m-%d %H:%M"))
report.fitting_summary_fig(mrs,res, report.fitting_summary_fig(mrs,res,
......
...@@ -60,6 +60,8 @@ fitting_args.add_argument('--add_MM',action="store_true", ...@@ -60,6 +60,8 @@ fitting_args.add_argument('--add_MM',action="store_true",
# ADDITONAL OPTIONAL ARGUMENTS # ADDITONAL OPTIONAL ARGUMENTS
optional.add_argument('--conjugate',action="store_true",
help='conjugate the data')
optional.add_argument('--single_proc',action="store_true", optional.add_argument('--single_proc',action="store_true",
help='do not run in parallel') help='do not run in parallel')
optional.add_argument('--report',action="store_true", optional.add_argument('--report',action="store_true",
...@@ -143,6 +145,8 @@ numBasis = basis.shape[1] ...@@ -143,6 +145,8 @@ numBasis = basis.shape[1]
# Get array data # Get array data
data = np.asanyarray(data_hdr.dataobj) data = np.asanyarray(data_hdr.dataobj)
if args.conjugate:
data = np.conj(data)
if args.h2o is not None: if args.h2o is not None:
h2o = np.asanyarray(h2o_hdr.dataobj) h2o = np.asanyarray(h2o_hdr.dataobj)
mask = np.asanyarray(mask_hdr.dataobj) mask = np.asanyarray(mask_hdr.dataobj)
...@@ -191,7 +195,7 @@ global_counter = mp.Value('L') ...@@ -191,7 +195,7 @@ global_counter = mp.Value('L')
# Define some ugly local functions for parallel processing # Define some ugly local functions for parallel processing
def runvoxel(FIDH2O,MRSargs,Fitargs): def runvoxel(FIDH2O,MRSargs,Fitargs):
mrs = MRS(FID=FIDH2O[0],H2O=FIDH2O[1],**MRSargs) mrs = MRS(FID=FIDH2O[0],H2O=FIDH2O[1],**MRSargs)
mrs.check_FID(repare=True) #mrs.check_FID(repare=True)
if args.add_MM: if args.add_MM:
n = mrs.add_MM_peaks() n = mrs.add_MM_peaks()
new_metab_groups = [i+max(metab_groups)+1 for i in range(n)] new_metab_groups = [i+max(metab_groups)+1 for i in range(n)]
...@@ -222,9 +226,9 @@ if args.single_proc: ...@@ -222,9 +226,9 @@ if args.single_proc:
results = [] results = []
for idx,FID in enumerate(fid_list): for idx,FID in enumerate(fid_list):
if args.verbose: if args.verbose:
print('{}/{} voxels fitted'.format(idx,len(fid_list)),end='\r') print('{}/{} voxels fitted'.format(idx,len(fid_list))) #,end='\r')
mrs = MRS(FID=FID,H2O=h2o_list[idx],**MRSargs) mrs = MRS(FID=FID,H2O=h2o_list[idx],**MRSargs)
mrs.check_FID(repare=True) #mrs.check_FID(repare=True)
n = mrs.add_MM_peaks() n = mrs.add_MM_peaks()
new_metab_groups = [i+max(metab_groups)+1 for i in range(n)] new_metab_groups = [i+max(metab_groups)+1 for i in range(n)]
new_metab_groups = metab_groups + new_metab_groups new_metab_groups = metab_groups + new_metab_groups
...@@ -306,16 +310,18 @@ if results[0].conc_h2o is not None: ...@@ -306,16 +310,18 @@ if results[0].conc_h2o is not None:
pred_list = [r.pred for r in results] pred_list = [r.pred for r in results]
pred_vol = misc.list_to_volume(pred_list,mask,dtype=np.complex) pred_vol = misc.list_to_volume(pred_list,mask,dtype=np.complex)
# check if data has been conjugated and if so conjugate the predictions # check if data has been conjugated and if so conjugate the predictions
FID = fid_list[0] if args.conjugate:
Spec1 = np.fft.fft(FID) pred_vol = np.conjugate(pred_vol)
Spec2 = np.fft.fft(np.conj(FID)) #FID = fid_list[0]
mrs = MRS(FID=fid_list[0],**MRSargs) #Spec1 = np.fft.fft(FID)
first,last = mrs.ppmlim_to_range(ppmlim) #Spec2 = np.fft.fft(np.conj(FID))
if np.linalg.norm(Spec1[first:last]) < np.linalg.norm(Spec2[first:last]): #mrs = MRS(FID=fid_list[0],**MRSargs)
print('Data has been conjugated') #first,last = mrs.ppmlim_to_range(ppmlim)
pred_vol = np.conj(pred_vol) #if np.linalg.norm(Spec1[first:last]) < np.linalg.norm(Spec2[first:last]):
else: # print('Data has been conjugated')
print('Data has NOT been conjugated') # pred_vol = np.conj(pred_vol)
#else:
# print('Data has NOT been conjugated')
img = nib.Nifti1Image(pred_vol,affine=data_hdr.affine,header=data_hdr.header) img = nib.Nifti1Image(pred_vol,affine=data_hdr.affine,header=data_hdr.header)
filename = os.path.join(os.path.join(args.output,'pred.nii.gz')) filename = os.path.join(os.path.join(args.output,'pred.nii.gz'))
nib.save(img, filename) nib.save(img, filename)
......
...@@ -14,7 +14,7 @@ from fsl_mrs.utils.constants import * ...@@ -14,7 +14,7 @@ from fsl_mrs.utils.constants import *
from fsl_mrs.utils import mh from fsl_mrs.utils import mh
from fsl_mrs.core import MRS from fsl_mrs.core import MRS
from scipy.optimize import minimize from scipy.optimize import minimize,nnls
class FitRes(object): class FitRes(object):
...@@ -83,20 +83,44 @@ class FitRes(object): ...@@ -83,20 +83,44 @@ class FitRes(object):
for i in range(g): for i in range(g):
self.params_names.extend(["eps_{}".format(i)]) self.params_names.extend(["eps_{}".format(i)])
self.params_names.extend(['Phi0','Phi1'])
for i in range(baseline_order+1):
self.params_names.extend(["B_real_{}".format(i)])