Commit 218a87ae authored by Saad Jbabdi's avatar Saad Jbabdi
Browse files

lastest with new FSLModel and new interactive reporting

parent a6a27bd0
......@@ -6,3 +6,5 @@ __pycache__
*~
/build
/dist
test
TODO
from .core import *
__version__ = '1.0.0'
......@@ -14,7 +14,7 @@ from fsl_mrs.utils.mh import MH
from fsl_mrs.utils import mrs_io as io
from fsl_mrs.utils import models, misc
from fsl_mrs.utils import plotting
from fsl_mrs.utils.constants import *
import numpy as np
import time
......@@ -30,12 +30,6 @@ import matplotlib.pyplot as plt
## ASK MICHIEL/PAUL HOW BEST TO SET UP GLOBAL VARIABLES
H2O_MOLECULAR_MASS = 18.01528 # g/mol
H2O_Conc = 55.51E3 # mmol/kg
H2O_PPM_TO_TMS = 4.65 # Shift of water to Tetramethylsilane
H2O_to_Cr = 0.4 # Proton ratio
H1_gamma = 42.576 # MHz/tesla
class MRS(object):
"""
......@@ -91,6 +85,7 @@ class MRS(object):
out += ' FID.bandwidth (Hz) = {}\n'.format(self.bandwidth)
out += ' FID.dwelltime (s) = {}\n'.format(self.dwellTime)
out += ' FID.echotime (s) = {}\n'.format(self.echotime)
out += ' Metabolites = {}\n'.format(self.names)
return out
......@@ -125,6 +120,7 @@ class MRS(object):
# turn into column vectors
self.timeAxis = self.timeAxis[:,None]
self.frequencyAxis = self.frequencyAxis[:,None]
self.ppmAxisShift = self.ppmAxisShift[:,None]
# by default, basis setup like data
self.set_acquisition_params_basis(self.dwellTime)
......@@ -229,7 +225,21 @@ class MRS(object):
self.names.pop(idx)
self.basis = np.delete(self.basis,idx,axis=1)
self.numBasis = len(self.names)
def keep(self,metabs):
"""
Keep a subset of metabolites by removing all others from basis
Parameters
----------
metabs: list
"""
if metabs is not None:
metabs = list(set(self.names)-set(metabs))
self.ignore(metabs)
def combine(self,metabs):
"""
......@@ -284,6 +294,7 @@ class MRS(object):
TYPE : string
"""
self.datafile = filename
self.FID, header = self.read(filename,TYPE)
self.numPoints = self.FID.size
......@@ -317,7 +328,7 @@ class MRS(object):
self.basis.append(data)
self.numBasis +=1
self.basis = np.asarray(self.basis).astype(np.complex).T
self.basis = self.basis - self.basis.mean(axis=0)
#self.basis = self.basis - self.basis.mean(axis=0)
def read_basis_from_folder(self,folder,TYPE='RAW',ignore=[]):
"""
......@@ -391,8 +402,7 @@ class MRS(object):
if real:
data = np.append(np.real(self.FID),np.imag(self.FID),axis=0)
desmat = np.append(np.real(self.basis),np.imag(self.basis),axis=0)
beta = np.real(np.linalg.pinv(desmat)@data)
print(beta)
beta = np.real(np.linalg.pinv(desmat)@data)
else:
beta = np.linalg.pinv(self.basis)@self.FID
......@@ -478,60 +488,6 @@ class MRS(object):
return x
def calc_baseline(self,spec=None,ppmlim=(0,4.6),order=10):
"""
Estimate baseline
parameters
----------
spec : array-like
spectrum to use for estimating baseline. default: uses self.Spec
ppmlim : tuple
upper and lower limit over which spectrum is calculated
order : integer
order of polynomial used to estimate baseline
"""
# Get axes
axis = np.flipud(self.ppmAxisFlip)
first = np.argmin(np.abs(axis-ppmlim[0]))
last = np.argmin(np.abs(axis-ppmlim[1]))
if first>last:
first,last = last,first
freq = axis[first:last]
# Build design matrix
desmat = []
for i in range(order+1):
regressor = freq**i # power
if i>0:
regressor -= np.mean(regressor) # demean
regressor /= np.linalg.norm(regressor) # normalise
desmat.append(regressor.flatten())
desmat = np.asarray(desmat).T
# Append basis to design matrix so it doesn't
# model out good signal
# First, do a quick nonlinear fit:
self.fit_LCModel(method='Newton',ppmlim=ppmlim)
basis = np.exp(-1j*(self.phi0+self.phi1*self.frequencyAxis))*np.fft.fft(self.basis*np.exp(-(self.gamma+1j*self.eps)*self.timeAxis),axis=0)
basis = np.flipud(np.fft.fftshift(basis))
basis = basis[first:last,:]
desmat = np.concatenate((desmat,basis),axis=1)
if spec is None:
spec = self.Spec
spec = np.flipud(np.fft.fftshift(spec))
beta = np.matmul(np.linalg.pinv(desmat),spec[first:last])
# Model is:
# data = [nuisance basis]*beta
# so baseline = nuisance*beta[:order+1]
baseline = np.zeros(self.numPoints,dtype='complex')
baseline[first:last] = np.matmul(desmat[:,:order+1],beta[:order+1])
baseline = np.flipud(baseline)
baseline = np.fft.fftshift(baseline)
return baseline
def reset_params(self,x):
"""
Set params and recalculate model prediction
......@@ -641,7 +597,7 @@ class MRS(object):
def save_results_to_file(self,filename):
"""
Write concentrations (abs and relative) to text file
Write concentrations to text file
"""
header = 'metabolite,Conc,/Cr+PCr\n'
with open(filename,'w') as f:
......@@ -651,42 +607,6 @@ class MRS(object):
f.write('{},{},{}\n'.format(x,y,z))
def save_fit_to_figure(self,filename,ppmlim=(.4,4.2)):
"""
Save fit to figure
"""
if self.pred is None:
raise Exception('Cannot plot fit before fitting')
axis = np.flipud(self.ppmAxisFlip)
spec = np.flipud(np.fft.fftshift(self.Spec))
pred = np.fft.fft(self.pred)
pred = np.flipud(np.fft.fftshift(pred))
if self.baseline is not None:
B = np.flipud(np.fft.fftshift(self.baseline))
first = np.argmin(np.abs(axis-ppmlim[0]))
last = np.argmin(np.abs(axis-ppmlim[1]))
if first>last:
first,last = last,first
freq = axis[first:last]
plt.figure(figsize=(9,10))
plt.plot(axis[first:last],spec[first:last])
plt.gca().invert_xaxis()
plt.plot(axis[first:last],pred[first:last],'r')
if self.baseline is not None:
plt.plot(axis[first:last],B[first:last],'k')
# style stuff
plt.minorticks_on()
plt.grid(b=True, axis='x', which='major',color='k', linestyle='--', linewidth=.3)
plt.grid(b=True, axis='x', which='minor', color='k', linestyle=':',linewidth=.3)
# Save to file
plt.savefig(filename)
return plt.gcf()
......@@ -719,7 +639,7 @@ class MRS_quantif(object):
return con
else:
warnings.warn("[{}]=0!! Something went wrong somewhere. Can't rescale concentrations...".format(metab))
return self.con
return None
def rescale_to_metab_grp(self,metab_list,scale=1.0):
......@@ -768,11 +688,11 @@ class MRS_quantif(object):
# Use Cr+PCr
interval = np.ones(self.Cr.size)
if ppmaxisshift is not None:
interval[ppmaxisshift<2.5] = 0
interval[ppmaxisshift>4.5] = 0
self.Cr = self.Cr*interval
self.Pcr = self.PCr*interval
#if ppmaxisshift is not None:
# interval[ppmaxisshift<2.5] = 0
# interval[ppmaxisshift>4.5] = 0
# self.Cr = self.Cr*interval
# self.Pcr = self.PCr*interval
Cr_area = np.sum(np.abs(self.con_names['Cr']*self.Cr))/interval.sum()
PCr_area = np.sum(np.abs(self.con_names['PCr']*self.PCr))/interval.sum()
......
#!/usr/bin/env python
# newcore.py - main MRS classes / functions definition
#
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
#
# Copyright (C) 2019 University of Oxford
# SHBASECOPYRIGHT
import numpy as np
from fsl_mrs.utils import misc
from fsl_mrs.utils import mrs_io
from scipy.interpolate import interp1d
def resample(basis,fid):
'''
Resample basis signal to match fid sampling rate
'''
bdt = basis._dwellTime
bbw = 1/bdt
bn = basis._numPoints
bt = np.linspace(bdt,bdt*bn,bn)-bdt
fidt = fid._timeAxis.flatten()-fid._dwellTime
f = interp1d(bt,basis._FID,axis=0)
newiFB = f(fidt)
new_basis = basis
new_basis._FID = newiFB
new_basis._Spec = np.fft.fft(new_basis._FID)
return new_basis
class FID(object):
# Data
_FID = None
_Spec = None
# Properties
_dwellTime = None
_numPoints = None
_centralFrequency = None
_bandwidth = None
def __init__(object,filename):
# Read data
_FID,header = mrs_io.readLCModelRaw(filename)
# Set internal parameters based on the header information
_numPoints = _FID.size
if header['centralFrequency'] is None:
self._centralFrequency = 123.2E6
warnings.warn('Cannot determine central Frequency from input. Setting to default of 123.2E6 Hz (3T)')
if header['bandwidth'] is None:
self._bandwidth = 4000
warnings.warn('Cannot determine bandwidth. Setting to default of 4000Hz.')
if header['echotime'] is None:
self._echotime = 30e-3
warnings.warn('Cannot determine echo time. Setting to default of 30ms.')
self._dwellTime = 1/self._bandwidth;
self._timeAxis = np.linspace(self.dwellTime,
self.dwellTime*self.numPoints,
self.numPoints)
self._frequencyAxis = np.linspace(-self.bandwidth/2,
self.bandwidth/2,
self.numPoints)
self._ppmAxis = misc.hz2ppm(self.centralFrequency,
self.frequencyAxis,shift=False)
self._ppmAxisShift = misc.hz2ppm(self.centralFrequency,
self.frequencyAxis,shift=True)
self._ppmAxisFlip = np.flipud(self.ppmAxisShift)
# turn into column vectors
self._timeAxis = self._timeAxis[:,None]
self._frequencyAxis = self._frequencyAxis[:,None]
class Basis(object,FID):
def __init__(object,filename):
_FID,header = readLCModelBasis(filename)
class H2O(object,FID):
def __init__(object,filename):
_FID,header = readLCModelRaw(filename)
......@@ -7,14 +7,12 @@
# Copyright (C) 2019 University of Oxford
# SHBASECOPYRIGHT
# Quick imports
#import argparse
import configargparse
import argparse
import time, os, sys, shutil, warnings
import numpy as np
from fsl_mrs import MRS
from fsl_mrs.utils import mrs_io
from fsl_mrs.utils import report
import datetime
from fsl_mrs import __version__
from fsl_mrs.utils.splash import splash
# TODO:
......@@ -24,42 +22,55 @@ import datetime
def main():
# Parse command-line arguments
p = argparse.ArgumentParser(description='FSL Magnetic Resonance Spectroscopy Tool')
p = configargparse.ArgParser(add_config_file_help=False,description="FSL Magnetic Resonance Spectroscopy Wrapper Script")
# utility for hiding certain arguments
def hide_args(arglist):
for action in arglist:
action.help=p.SUPPRESS
#p = argparse.ArgumentParser(description='FSL Magnetic Resonance Spectroscopy Tool')
p.add_argument('-v','--version', action='version', version=__version__)
required = p.add_argument_group('required arguments')
fitting_args = p.add_argument_group('fitting options')
optional = p.add_argument_group('additional options')
# REQUIRED ARGUMENTS
required.add_argument('-d','--data',
required.add_argument('--data',
required=True,type=str,metavar='<str>.RAW',
help='input FID file')
required.add_argument('-b','--basis',
required.add_argument('--basis',
required=True,type=str,metavar='<str>',
help='.BASIS file or folder containing basis spectra (will read all .RAW files within)')
required.add_argument('-o','--output',
required.add_argument('--output',
required=True,type=str,metavar='<str>',
help='output folder')
# FITTING ARGUMENTS
fitting_args.add_argument('--model',default='LCModel',type=str,
help='model [default=LCModel]')
fitting_args.add_argument('--algo',default='Newton',type=str,
help='algorithm [Newton or Powell or MH]')
help='algorithm [Newton (fast) or MH (slow)]')
fitting_args.add_argument('--ignore',type=str,nargs='+',metavar='METAB',
help='ignore certain metabolites [repeatable]')
fitting_args.add_argument('--keep',type=str,nargs='+',metavar='METAB',
help='only keep these metabolites')
fitting_args.add_argument('--combine',type=str,nargs='+',action='append',metavar='METAB',
help='combine certain metabolites')
fitting_args.add_argument('--ppmlim',default=None,type=float,nargs=2,metavar=('LOW','HIGH'),
help='limit the fit to a freq range')
help='combine certain metabolites [repeatable]')
fitting_args.add_argument('--ppmlim',default=(.2,4.2),type=float,nargs=2,metavar=('LOW','HIGH'),
help='limit the fit to a freq range (default=(.2,4.2))')
fitting_args.add_argument('--h2o',type=str,metavar='H2O',
help='input .H2O file for quantification')
fitting_args.add_argument('--T2s',type=float,nargs=4,metavar=('T2metab','T2GM','T2WM','T2CSF'),
help='T2 values (in milliseconds) for metabolites, and water GM, WM, and CSF (default=3T ref values)')
help='T2 values (in milliseconds) for metabolites, and water GM, WM, and CSF')
fitting_args.add_argument('--volfrac',type=float,nargs=3,metavar=('GM','WM','CSF'),
help='volume fractions of GM, WM, and CSF (in most cases these should add up to one)')
fitting_args.add_argument('--baseline_order',default=2,type=int,metavar=('ORDER'),
help='order of baseline polynomial (default=2)')
fitting_args.add_argument('--metab_groups',default=0,nargs='+',type=int,
help="metabolite groups.")
# ADDITONAL OPTIONAL ARGUMENTS
optional.add_argument('--report',action="store_true",
......@@ -68,16 +79,34 @@ def main():
help='spit out verbose info')
optional.add_argument('--overwrite',action="store_true",
help='overwrite existing output folder')
optional.add('--config', required=False, is_config_file=True, help='configuration file')
# Output kickass splash screen
mrs_io.splash(logo='mrs')
#hide_args([h1,h2])
# Parse command-line arguments
args = p.parse_args()
# Output kickass splash screen
if args.verbose:
splash(logo='mrs')
# ######################################################
# DO THE IMPORTS AFTER PARSING TO SPEED UP HELP DISPLAY
import time, os, sys, shutil, warnings
import numpy as np
from fsl_mrs.core import MRS
from fsl_mrs.utils import mrs_io
from fsl_mrs.utils import report
from fsl_mrs.utils import fitting
from fsl_mrs.utils import plotting
from fsl_mrs.utils import misc
import datetime
import plotly
# ######################################################
# Check if output folder exists
overwrite = args.overwrite
if os.path.exists(args.output):
......@@ -94,7 +123,14 @@ def main():
else:
os.mkdir(args.output)
# Save chosen arguments
with open(os.path.join(args.output,"options.txt"),"w") as f:
f.write(str(args))
f.write("\n--------\n")
f.write(p.format_values())
####### Do the work #######
# Instantiate MRS object
......@@ -130,41 +166,50 @@ def main():
T2s = args.T2s
# Ignore/Combine metabolites
# Keep/Ignore/Combine metabolites
mrs.keep(args.keep)
mrs.ignore(args.ignore)
mrs.combine(args.combine)
# Do the fitting here
if args.verbose:
print('--->> Start fitting\n\n')
print(' Model = [{}] and Algorithm = [{}]\n'.format(args.model,args.algo))
print(' Algorithm = [{}]\n'.format(args.algo))
start = time.time()
# BASELINE REMOVAL
# Remove baseline before fitting
#print(' --- Remove baseline --- ')
#B = mrs.calc_baseline(ppmlim=(0.2,4.4),order=10)
#mrs.Spec = mrs.Spec - B
#mrs.FID = np.fft.ifft(mrs.Spec)
# ENDOF BASELINE REMOVAL
# Do the fitting
print(' --- Run fitting --- ')
ppmlim=args.ppmlim
if ppmlim is not None:
ppmlim=(ppmlim[0],ppmlim[1])
print('mrs.fit(model ={}, method={},ppmlim={})'.format(args.model,args.algo,ppmlim))
mrs.fit(model = args.model, method=args.algo,ppmlim=args.ppmlim)
#mrs.fit(model = args.model, method=args.algo,ppmlim=args.ppmlim)
# Do phase correction
mrs.FID = misc.phase_correct(mrs.FID)
mrs.Spec = np.fft.fft(mrs.FID)
# Parse metabolite groups
metab_groups = args.metab_groups
if metab_groups == 0:
metab_groups = [0]*mrs.numBasis
if len(metab_groups) != mrs.numBasis:
raise(Exception('Found {} metab_groups but there are {} basis functions'.format(len(metab_groups),mrs.numBasis)))
res = fitting.fit_FSLModel(mrs,method=args.algo,
ppmlim=ppmlim,
baseline_order=args.baseline_order,
metab_groups=metab_groups)
stop = time.time()
# Do some extra bits for quantification
mrs.post_process(metab=['Cr','PCr'],scale=8.0,T2s=T2s,volfrac=args.volfrac)
mrs.reference_matebolite = 'Cr+PCr'
print(mrs.all_con_names_h2o)
#mrs.post_process(metab=['Cr','PCr'],scale=8.0,T2s=T2s,volfrac=args.volfrac)
#mrs.reference_matebolite = 'Cr+PCr'
#print(mrs.all_con_names_h2o)
# Report on the fitting
if args.verbose:
......@@ -174,33 +219,27 @@ def main():
if args.verbose:
print('--->> Saving output files to {}\n'.format(args.output))
mrs.save_results_to_file(os.path.join(args.output,'results_table.csv'))
mrs.save_fit_to_figure(os.path.join(args.output,'spectrum_fit.jpg'))
res.to_file(mrs,filename=os.path.join(args.output,'results_table.csv'))
if args.report:
if args.verbose:
print('--->> Creating html report\n')
mrs.fidfile = args.data
mrs.basisfile = args.basis
mrs.h2ofile = args.h2o
mrs.outdir = args.output
mrs.provenance = 'FMRIB Centre'
mrs.date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
report_file = os.path.join(args.output,'report.html')
r = report.MRS_Report(mrs)
r.parse(report_file)
if args.verbose:
print(' {}'.format(report_file))
# Finish
# TMP: quick plot
#for proj in ['real','imag','abs']:
# mrs.plot_fit(out=os.path.join(args.output,'fitted_spectrum_{}.png'.format(proj)),proj=proj)
# Create short HTML report
#fig = plotting.plotly_fit(mrs,res,ppmlim=ppmlim,proj='abs')
#plotly.io.write_html(fig, file=os.path.join(args.output,'short_report.html'))
# Creat HTML report
report.create_report(mrs,res,
filename=os.path.join(args.output,'report.html'),
fidfile=args.data,
basisfile=args.basis,
h2ofile=args.h2o,
outdir=args.output,
date=datetime.datetime.now().strftime("%Y-%m-%d %H:%M"))
if args.verbose:
print('\n\n\nDone.')
......
......@@ -8,16 +8,15 @@
# Copyright (C) 2019 University of Oxford
# SHBASECOPYRIGHT
import os,sys,shutil,time
# Quick imports
import argparse
from fsl_mrs.utils import simu
from fsl_mrs.utils import mrs_io
from fsl_mrs import __version__
from fsl_mrs.utils.splash import splash
def main():
p = argparse.ArgumentParser(description='FSL Magnetic Resonance Spectroscopy Tools')
p.add_argument('-v','--version', action='version', version=__version__)
required = p.add_argument_group('required arguments')