Commit c2b64fca authored by William Clarke's avatar William Clarke
Browse files

Changes committed on morning of review. Warning might need to sort out the...

Changes committed on morning of review. Warning might need to sort out the plotting code after messing around in dash.
parent ffef7d04
......@@ -186,18 +186,20 @@ class MRS(object):
This interpolates the basis to match the FID
"""
# RESAMPLE BASIS FUNCTION
bdt = self.basis_dwellTime
bbw = self.basis_bandwidth
bn = self.numBasisPoints
# bdt = self.basis_dwellTime
# bbw = self.basis_bandwidth
# bn = self.numBasisPoints
bt = np.linspace(bdt,bdt*bn,bn)-bdt
fidt = self.timeAxis.flatten()-self.dwellTime
# bt = np.linspace(bdt,bdt*bn,bn)-bdt
# fidt = self.timeAxis.flatten()-self.dwellTime
f = interp1d(bt,self.basis,axis=0)
newiFB = f(fidt)
self.basis = newiFB
# f = interp1d(bt,self.basis,axis=0)
# newiFB = f(fidt)
self.basis = misc.ts_to_ts(self.basis,self.basis_dwellTime,self.dwellTime,self.numPoints)
self.basis_dwellTime = self.dwellTime
self.basis_bandwidth = 1/self.dwellTime
self.numBasisPoints = self.numPoints
# Helper functions
......@@ -217,12 +219,12 @@ class MRS(object):
"""
first,last = self.ppmlim_to_range(ppmlim)
Spec1 = np.real(np.fft.fft(self.FID))[first:last]
Spec2 = np.real(np.fft.fft(np.conj(self.FID)))[first:last]
Spec1 = np.real(misc.FIDToSpec(self.FID))[first:last]
Spec2 = np.real(misc.FIDToSpec(np.conj(self.FID)))[first:last]
if np.linalg.norm(misc.detrend(Spec1,deg=4)) < np.linalg.norm(misc.detrend(Spec2,deg=4)):
if repare is False:
warnings.warn('YOU MAY NEED TO CONJUGATE YOU FID!!!')
warnings.warn('YOU MAY NEED TO CONJUGATE YOUR FID!!!')
return -1
else:
self.conj_FID()
......@@ -237,6 +239,46 @@ class MRS(object):
self.FID = np.conj(self.FID)
self.Spec = misc.FIDToSpec(self.FID)
def check_Basis(self,ppmlim=(.2,4.2),repare=False):
"""
Check if Basis needs to be conjugated
by looking at total power within ppmlim range
Parameters
----------
ppmlim : list
repare : if True applies conjugation to basis
Returns
-------
0 if check successful and -1 if not (also issues warning)
"""
first,last = self.ppmlim_to_range(ppmlim)
conjOrNot = []
for b in self.basis.T:
Spec1 = np.real(misc.FIDToSpec(b))[first:last]
Spec2 = np.real(misc.FIDToSpec(np.conj(b)))[first:last]
if np.linalg.norm(misc.detrend(Spec1,deg=4)) < np.linalg.norm(misc.detrend(Spec2,deg=4)):
conjOrNot.append(1.0)
if (sum(conjOrNot)/len(conjOrNot))>0.5:
if repare is False:
warnings.warn('YOU MAY NEED TO CONJUGATE YOUR BASIS!!!')
return -1
else:
self.conj_Basis()
return 1
return 0
def conj_Basis(self):
"""
Conjugate FID and recalculate spectrum
"""
self.basis = np.conj(self.basis)
def ignore(self,metabs):
"""
Ignore a subset of metabolites by removing them from the basis
......
......@@ -38,11 +38,11 @@ def main():
# REQUIRED ARGUMENTS
required.add_argument('--data',
required=True,type=str,metavar='<str>.RAW',
required=True,type=str,metavar='<str>',
help='input FID file')
required.add_argument('--basis',
required=True,type=str,metavar='<str>',
help='.BASIS file or folder containing basis spectra (will read all .RAW files within)')
help='.BASIS file or folder containing basis spectra (will read all files within)')
required.add_argument('--output',
required=True,type=str,metavar='<str>',
help='output folder')
......@@ -62,8 +62,8 @@ def main():
help='input .H2O file for quantification')
fitting_args.add_argument('--baseline_order',default=2,type=int,metavar=('ORDER'),
help='order of baseline polynomial (default=2)')
fitting_args.add_argument('--metab_groups',default=0,nargs='+',type=int,
help="metabolite groups")
fitting_args.add_argument('--metab_groups',default=0,nargs='+',type=str_or_int_arg,
help="metabolite groups: list of groups or list of names for indept groups.")
fitting_args.add_argument('--add_MM',type=bool,
help="include default macromolecule peaks")
......@@ -83,6 +83,11 @@ def main():
help='do phase correction')
optional.add_argument('--overwrite',action="store_true",
help='overwrite existing output folder')
optional.add_argument('--conj_fid',dest='conjfid',action="store_true",help='Force conjugation of FID')
optional.add_argument('--no_conj_fid',dest='conjfid',action="store_false",help='Forbid automatic conjugation of FID')
optional.add_argument('--conj_basis',dest='conjbasis',action="store_true",help='Force conjugation of basis')
optional.add_argument('--no_conj_basis',dest='conjbasis',action="store_false",help='Forbid automatic conjugation of basis')
optional.set_defaults(conjfid=None,conjbasis=None)
optional.add('--config', required=False, is_config_file=True, help='configuration file')
......@@ -152,19 +157,14 @@ def main():
H2O,_ = mrs_io.read_FID(args.h2o)
else:
H2O = None
# Squeeze the data/h2o (e.g. if from NIFTI single voxel)
FID = np.squeeze(FID)
if H2O is not None:
H2O = np.squeeze(H2O)
# Collect useful info
if args.central_frequency is not None:
cf = args.central_frequency
elif dataheader['centralFrequency'] is not None:
cf = dataheader['centralFrequency']
if args.verbose:
print(' Detected central frequency in header info cf = {} MHz'.format(cf*1E-6))
print(f' Detected central frequency in header info cf = {cf:0.6f} MHz')
else:
raise(Exception('Cannot determine central frequency. Please either set it or include it in data header'))
......@@ -173,35 +173,39 @@ def main():
elif dataheader['bandwidth'] is not None:
bw = dataheader['bandwidth']
if args.verbose:
print(' Detected bandwidth in header info bw = {} Hz'.format(bw))
print(f' Detected bandwidth in header info bw = {bw:0.1f} Hz')
else:
raise(Exception('Cannot determine bandwidth. Please either set it or include it in data header'))
# Resample basis?
if basisheader is not None:
if bw != basisheader['bandwidth']:
dwell = 1/basisheader['bandwidth']
new_dwell = 1/bw
basis = misc.resample_ts(basis,dwell,new_dwell)
# Instantiate MRS object
MRSargs = {'FID':FID,'basis':basis,'names':names,'H2O':H2O,'cf':cf,'bw':bw}
MRSargs = {'FID':FID,'basis':basis,'basis_hdr':basisheader[0],'names':names,'H2O':H2O,'cf':cf,'bw':bw}
mrs = MRS(**MRSargs)
# Check the FID
conjugated = mrs.check_FID(repare=True)
if args.verbose:
if conjugated == 1:
raise(Warning('Warning :: FID has been checked and conjugated. Please check!'))
# Check the FID and basis / conjugate
if args.conjfid is not None:
if args.conjfid:
mrs.conj_FID()
else:
conjugated = mrs.check_FID(repare=True)
if args.verbose:
if conjugated == 1:
warnings.warn('FID has been checked and conjugated. Please check!',UserWarning)
if args.conjbasis is not None:
if args.conjbasis:
mrs.conj_Basis()
else:
conjugated = mrs.check_Basis(repare=True)
if args.verbose:
if conjugated == 1:
warnings.warn('Basis has been checked and conjugated. Please check!',UserWarning)
# Do phase correction
if args.phase_correct:
if args.verbose:
print('--->> Phase correction\n')
mrs.FID = misc.phase_correct(mrs,mrs.FID)
mrs.Spec = np.fft.fft(mrs.FID)
mrs.Spec = misc.FIDToSpec(mrs.FID)
# Keep/Ignore/Combine metabolites
......@@ -227,10 +231,19 @@ def main():
# Parse metabolite groups
metab_groups = args.metab_groups
if metab_groups == 0:
metab_groups = [0]*mrs.numBasis
if len(metab_groups) != mrs.numBasis:
raise(Exception('Found {} metab_groups but there are {} basis functions'.format(len(metab_groups),mrs.numBasis)))
if isinstance(metab_groups,list):
if isinstance(metab_groups[0],str):
tmp = [0]*mrs.numBasis
grpcounter = 0
for n in metab_groups:
grpcounter += 1
tmp[mrs.names.index(n)] = grpcounter
metab_groups = tmp
elif isinstance(metab_groups[0],int):
if metab_groups == [0]:
metab_groups = [0]*mrs.numBasis
elif len(metab_groups) != mrs.numBasis:
raise(Exception('Found {} metab_groups but there are {} basis functions'.format(len(metab_groups),mrs.numBasis)))
# Include Macromolecules? These should have their own metab groups
if args.add_MM is not None:
......@@ -300,6 +313,11 @@ def main():
if args.verbose:
print('\n\n\nDone.')
def str_or_int_arg(x):
try:
return int(x)
except:
return x
if __name__ == '__main__':
main()
......@@ -17,7 +17,7 @@ import nibabel as nib
import scipy.ndimage as ndimage
import itertools as it
from fsl_mrs.utils.misc import hz2ppm,FIDToSpec
from fsl_mrs.utils.misc import hz2ppm,FIDToSpec,SpecToFID
def FID2Spec(x):
"""
......@@ -143,7 +143,7 @@ def plot_fit_new(mrs,ppmlim=(0.40,4.2)):
"""
axis = np.flipud(mrs.ppmAxisFlip)
spec = np.flipud(np.fft.fftshift(mrs.Spec))
pred = np.fft.fft(mrs.pred)
pred = FIDToSpec(mrs.pred)
pred = np.flipud(np.fft.fftshift(pred))
if mrs.baseline is not None:
......@@ -378,11 +378,13 @@ def plot_fit_pretty(mrs,pred=None,ppmlim=(0.40,4.2),proj='real'):
return fig
# plotly imports
import pandas as pd
import plotly.graph_objects as go
import plotly.figure_factory as ff
from plotly.subplots import make_subplots
def plotly_fit(mrs,res,ppmlim=(.2,4.2),proj='real'):
def plotly_fit(mrs,res,ppmlim=(.2,4.2),proj='real',metabs = None,phs=(0,0)):
"""
plot model fitting plus baseline
......@@ -390,6 +392,8 @@ def plotly_fit(mrs,res,ppmlim=(.2,4.2),proj='real'):
mrs : MRS object
res : ResFit Object
ppmlim : tuple
metabs : list of metabolite to include in pred
phs : display phasing in degrees and seconds
Returns
fig
......@@ -405,20 +409,40 @@ def plotly_fit(mrs,res,ppmlim=(.2,4.2),proj='real'):
return np.abs(x)
# Prepare the data
base = FID2Spec(res.baseline)
axis = np.flipud(mrs.ppmAxisFlip)
data = project(FID2Spec(mrs.FID),proj)
pred = project(FID2Spec(res.pred),proj)
base = project(FID2Spec(res.baseline),proj)
resid = project(FID2Spec(res.residuals),proj)
data = FID2Spec(mrs.FID)
if metabs is not None:
preds = []
for m in metabs:
preds.append(FID2Spec(pred(mrs,res,m,add_baseline=False)))
preds = sum(preds)
preds += FID2Spec(res.baseline)
resid = data-preds
else:
preds = FID2Spec(res.pred)
resid = FID2Spec(res.residuals)
# phasing
faxis = np.squeeze(mrs.frequencyAxis)
phaseTerm = np.exp(1j*(phs[0]*np.pi/180)) * np.exp(1j*2*np.pi*phs[1]*faxis)
base *= phaseTerm
data *= phaseTerm
preds *= phaseTerm
resid *= phaseTerm
base = project(base,proj)
data = project(data,proj)
preds = project(preds,proj)
resid = project(resid,proj)
# y-axis range
ymin = np.min(data)-np.min(data)/10
ymax = np.max(data)-np.max(data)/30
# Build the plot
import plotly.graph_objects as go
import plotly.figure_factory as ff
import pandas as pd
# Table
......@@ -426,7 +450,12 @@ def plotly_fit(mrs,res,ppmlim=(.2,4.2),proj='real'):
df['Metab'] = mrs.names
df['mMol/kg'] = np.round(res.conc_h2o,decimals=2)
df['%CRLB'] = np.round(res.perc_SD[:mrs.numBasis],decimals=1)
df['/Cr'] = np.round(res.conc_cr,decimals=2)
if res.conc_cr_pcr is not None:
df['/tCr'] = np.round(res.conc_cr_pcr,decimals=2)
elif res.conc_cr is not None:
df['/Cr'] = np.round(res.conc_cr,decimals=2)
else:
df['unscaled'] = np.round(res.conc,decimals=2)
fig = ff.create_table(df, height_constant=50)
......@@ -446,7 +475,7 @@ def plotly_fit(mrs,res,ppmlim=(.2,4.2),proj='real'):
name='data',
line=dict(color=colors['data'],width=line_size['data']),
xaxis='x2', yaxis='y2')
trace2 = go.Scatter(x=axis, y=pred,
trace2 = go.Scatter(x=axis, y=preds,
mode='lines',
name='model',
line=dict(color=colors['pred'],width=line_size['pred']),
......@@ -490,8 +519,6 @@ def plotly_fit(mrs,res,ppmlim=(.2,4.2),proj='real'):
def plot_dist_approx(mrs,res,refname='Cr'):
import plotly.graph_objects as go
from plotly.subplots import make_subplots
n = int(np.ceil(np.sqrt(mrs.numBasis)))
fig = make_subplots(rows=n, cols=n,subplot_titles=mrs.names)
......@@ -520,8 +547,6 @@ def plot_dist_approx(mrs,res,refname='Cr'):
def plot_mcmc_corr(mrs,res):
import plotly.graph_objects as go
from plotly.subplots import make_subplots
#Greys,YlGnBu,Greens,YlOrRd,Bluered,RdBu,Reds,Blues,
#Picnic,Rainbow,Portland,Jet,Hot,Blackbody,Earth,
......@@ -530,6 +555,8 @@ def plot_mcmc_corr(mrs,res):
fig = go.Figure()
corr = np.ma.corrcoef(res.mcmc_samples.T)
np.fill_diagonal(corr,np.nan)
corrabs = np.abs(corr)
fig.add_trace(go.Heatmap(z=corr,
x=mrs.names,y=mrs.names,colorscale='Picnic'))
......@@ -541,7 +568,31 @@ def plot_mcmc_corr(mrs,res):
yaxis = dict(
scaleanchor = "x",
scaleratio = 1,
))
),
updatemenus=[
dict(
type = "buttons",
direction = "left",
buttons=list([
dict(
args=[{"z":[corr],"colorscale":'Picnic'}],
label="Real",
method="restyle"
),
dict(
args=[{"z":[corrabs],"colorscale":'Reds'}],
label="Abs",
method="restyle"
)
]),
pad={"r": 10, "t": 10},
showactive=True,
x=0.11,
xanchor="left",
y=1.1,
yanchor="top"
),
])
return fig
......@@ -576,7 +627,6 @@ def plot_dist_mcmc(mrs,res,refname='Cr'):
return fig
def plot_real_imag(mrs,res,ppmlim=(.2,4.2)):
"""
plot model fitting plus baseline
......@@ -609,9 +659,6 @@ def plot_real_imag(mrs,res,ppmlim=(.2,4.2)):
# Build the plot
import plotly.graph_objects as go
from plotly.subplots import make_subplots
fig = make_subplots(rows=1, cols=2,subplot_titles=['Real','Imag'])
......@@ -666,12 +713,12 @@ def plot_real_imag(mrs,res,ppmlim=(.2,4.2)):
# fig.layout.margin.update({'t':50, 'b':100})
fig.layout.update({'title': 'Fitting summary Real/Imag'})
fig.update_layout(template = 'plotly_white')
fig.layout.update({'height':800,'width':1000})
# fig.layout.update({'height':800,'width':1000})
return fig
def pred(mrs,res,metab):
def pred(mrs,res,metab,add_baseline=True):
from fsl_mrs.utils import models
if res.model == 'lorentzian':
......@@ -694,16 +741,59 @@ def pred(mrs,res,metab):
else:
raise Exception('Unknown model.')
pred = forward(x,mrs.frequencyAxis,
mrs.timeAxis,
mrs.basis,res.base_poly,res.metab_groups,res.g)
pred = np.fft.ifft(pred) # predict FID not Spec
if add_baseline:
pred = forward(x,mrs.frequencyAxis,
mrs.timeAxis,
mrs.basis,res.base_poly,res.metab_groups,res.g)
else:
pred = forward(x,mrs.frequencyAxis,
mrs.timeAxis,
mrs.basis,np.zeros(res.base_poly.shape),res.metab_groups,res.g)
pred = SpecToFID(pred) # predict FID not Spec
return pred
def plot_indiv(mrs,res,ppmlim=(.2,4.2)):
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def plot_indiv_stacked(mrs,res,ppmlim=(.2,4.2)):
colors = dict(data='rgb(67,67,67)',
indiv='rgb(253,59,59)')
line_size = dict(data=.5,
indiv=2)
fig = go.Figure()
axis = np.flipud(mrs.ppmAxisFlip)
y_data = np.real(FID2Spec(mrs.FID))
trace1 = go.Scatter(x=axis, y=y_data,
mode='lines',
name='data',
line=dict(color=colors['data'],width=line_size['data']))
fig.add_trace(trace1)
for i,metab in enumerate(mrs.names):
y_fit = np.real(FID2Spec(pred(mrs,res,metab)))
trace2 = go.Scatter(x=axis, y=y_fit,
mode='lines',
name=metab,
line=dict(color=colors['indiv'],width=line_size['indiv']))
fig.add_trace(trace2)
fig.layout.xaxis.update(title_text='Chemical shift (ppm)',
tick0=2, dtick=.5,
range=[ppmlim[1],ppmlim[0]])
fig.layout.yaxis.update(zeroline=True,
zerolinewidth=1,
zerolinecolor='Gray',
showgrid=False,showticklabels=False)
# Update the margins to add a title and see graph x-labels.
# fig.layout.margin.update({'t':50, 'b':100})
fig.layout.update({'title': 'Individual Fitting summary'})
fig.update_layout(template = 'plotly_white')
# fig.layout.update({'height':800,'width':1000})
return fig
def plot_indiv(mrs,res,ppmlim=(.2,4.2)):
colors = dict(data='rgb(67,67,67)',
pred='rgb(253,59,59)')
line_size = dict(data=.5,
......@@ -751,11 +841,7 @@ def plot_indiv(mrs,res,ppmlim=(.2,4.2)):
showgrid=False,showticklabels=False)
return fig
# def plot_table_extra(mrs,res):
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
# def plot_table_extra(mrs,res):
def plot_table_qc(mrs,res):
# QC measures
header=["S/N","Static phase (deg)", "Linear phase (deg/ppm)"]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment