Commit 23d58004 authored by William Clarke's avatar William Clarke
Browse files

Rewrite of MRS class, tests for MRS and introduction of different nuclei.

parent 80fdf678
This diff is collapsed.
......@@ -192,7 +192,7 @@ class MRSI(object):
jj = j - dim2[0]
ax = axes[ii,jj]
mrs = self.mrs_by_index([i,j,k])
ax.plot(mrs.getAxes(ppmlim=ppmlim),np.real(mrs.getSpectrum(ppmlim=ppmlim)))
ax.plot(mrs.getAxes(ppmlim=ppmlim),np.real(mrs.get_spec(ppmlim=ppmlim)))
ax.invert_xaxis()
ax.set_xticks([])
ax.set_yticks([])
......
......@@ -257,7 +257,6 @@ def main():
if args.verbose:
print('--->> Phase correction\n')
mrs.FID = misc.phase_correct(mrs, mrs.FID)
mrs.Spec = misc.FIDToSpec(mrs.FID)
# Keep/Ignore metabolites
mrs.keep(args.keep)
......
from pathlib import Path
from fsl_mrs.core import MRS
import pytest
from fsl_mrs.utils import synthetic as syn
import numpy as np
from fsl_mrs.utils.misc import FIDToSpec, hz2ppm
# Files
testsPath = Path(__file__).parent
svs_metab = testsPath / 'testdata/fsl_mrs/metab.nii'
svs_water = testsPath / 'testdata/fsl_mrs/water.nii'
svs_basis = testsPath / 'testdata/fsl_mrs/steam_basis'
@pytest.fixture
def synth_data():
fid, hdr = syn.syntheticFID()
hdr['json'] = {'ResonantNucleus': '1H'}
basis_1, bhdr_1 = syn.syntheticFID(noisecovariance=[[0.0]],
chemicalshift=[-2, ],
amplitude=[0.1, ],
damping=[5, ])
basis_2, bhdr_2 = syn.syntheticFID(noisecovariance=[[0.0]],
chemicalshift=[3, ],
amplitude=[0.1, ],
damping=[5, ])
basis = np.concatenate((basis_1, basis_2))
bheader = [bhdr_1, bhdr_2]
names = ['ppm_2', 'ppm3']
timeAxis = np.linspace(hdr['dwelltime'],
hdr['dwelltime'] * 2048,
2048)
frequencyAxis = np.linspace(-hdr['bandwidth']/2,
hdr['bandwidth']/2,
2048)
ppmAxis = hz2ppm(hdr['centralFrequency']*1E6,
frequencyAxis,
shift=False)
ppmAxisShift = hz2ppm(hdr['centralFrequency']*1E6,
frequencyAxis,
shift=True)
axes = {'time': timeAxis,
'freq': frequencyAxis,
'ppm': ppmAxis,
'ppm_shift': ppmAxisShift}
return fid[0], hdr, basis, names, bheader, axes
def test_load_from_file():
mrs = MRS()
mrs.from_files(str(svs_metab),
str(svs_basis),
H2O_file=str(svs_water))
assert mrs.FID.shape == (4096,)
assert mrs.basis.shape == (4096, 20)
assert mrs.H2O.shape == (4096,)
def test_load(synth_data):
fid, hdr, basis, names, bheader, axes = synth_data
mrs = MRS(FID=fid,
header=hdr,
basis=basis,
names=names,
basis_hdr=bheader[0])
assert mrs.FID.shape == (2048,)
assert mrs.basis.shape == (2048, 2)
assert mrs.numBasis == 2
assert mrs.dwellTime == 1/4000
assert mrs.centralFrequency == 123E6
assert mrs.nucleus == '1H'
def test_access(synth_data):
fid, hdr, basis, names, bheader, axes = synth_data
mrs = MRS(FID=fid,
header=hdr,
basis=basis,
names=names,
basis_hdr=bheader[0])
assert np.allclose(mrs.FID, fid)
assert np.allclose(mrs.get_spec(), FIDToSpec(fid))
assert np.allclose(mrs.basis.T, basis)
assert np.allclose(mrs.getAxes(axis='ppmshift'), axes['ppm_shift'])
assert np.allclose(mrs.getAxes(axis='ppm'), axes['ppm'])
assert np.allclose(mrs.getAxes(axis='freq'), axes['freq'])
assert np.allclose(mrs.getAxes(axis='time'), axes['time'])
mrs.rescaleForFitting()
assert np.allclose(mrs.get_spec()/mrs.scaling['FID'], FIDToSpec(fid))
assert np.allclose(mrs.basis.T/mrs.scaling['basis'], basis)
mrs.conj_Basis()
mrs.conj_FID()
assert np.allclose(mrs.get_spec()/mrs.scaling['FID'],
FIDToSpec(fid.conj()))
assert np.allclose(mrs.basis.T/mrs.scaling['basis'], basis.conj())
def test_basis_manipulations(synth_data):
fid, hdr, basis, names, bheader, axes = synth_data
mrs = MRS(FID=fid,
header=hdr,
basis=basis,
names=names,
basis_hdr=bheader[0])
assert mrs.basis.shape == (2048, 2)
assert mrs.numBasis == 2
mrs.keep(['ppm_2'])
assert mrs.basis.shape == (2048, 1)
assert mrs.numBasis == 1
mrs.add_peak(0, 1, 'test', gamma=10, sigma=10)
assert mrs.basis.shape == (2048, 2)
assert mrs.numBasis == 2
assert mrs.names == ['ppm_2', 'test']
mrs.ignore(['test'])
assert mrs.basis.shape == (2048, 1)
assert mrs.numBasis == 1
mrs.add_MM_peaks(gamma=10, sigma=10)
assert mrs.basis.shape == (2048, 6)
assert mrs.numBasis == 6
......@@ -207,7 +207,7 @@ def test_shiftToRef():
mrs = MRS(FID=shiftFID,header=testHdrs)
maxindex = np.argmax(mrs.getSpectrum(shift=False))
maxindex = np.argmax(mrs.get_spec(shift=False))
position = mrs.getAxes(axis='ppm')[maxindex]
assert np.isclose(position,-2.0,atol=1E-1)
......@@ -16,13 +16,13 @@ def test_calcQC():
synMRS_basis = MRS(FID =synFID[0],header=synHdr,basis =basisFID[0] ,basis_hdr=basisHdr,names=['Peak1'])
truenoiseSD = np.sqrt(synHdrNoise['cov'][0,0])
pureNoiseMeasured = np.std(synMRSNoise.getSpectrum())
realnoise = np.std(np.real(synMRSNoise.getSpectrum()))
imagNoise = np.std(np.imag(synMRSNoise.getSpectrum()))
pureNoiseMeasured = np.std(synMRSNoise.get_spec())
realnoise = np.std(np.real(synMRSNoise.get_spec()))
imagNoise = np.std(np.imag(synMRSNoise.get_spec()))
print(f'True cmplx noise = {truenoiseSD:0.3f}, pure noise measured = {pureNoiseMeasured:0.3f} (real/imag = {realnoise:0.3f}/{imagNoise:0.3f})')
# Calc SNR without apodisation from the no noise and pure noise spectra
truePeakHeight = np.max(np.real(synMRSNoNoise.getSpectrum()))
truePeakHeight = np.max(np.real(synMRSNoNoise.get_spec()))
SNR_noApod = truePeakHeight/pureNoiseMeasured
print(f'SNR no apod: {SNR_noApod:0.1f} ({truePeakHeight:0.2e}/{pureNoiseMeasured:0.2e})')
......
......@@ -8,8 +8,22 @@
# Copyright (C) 2019 University of Oxford
# SHBASECOPYRIGHT
# From https://en.wikipedia.org/wiki/Gyromagnetic_ratio
# except for 1H https://physics.nist.gov/cgi-bin/cuu/Value?gammappbar
# MHz/tesla
H1_gamma = 42.576
GYRO_MAG_RATIO = {'1H': H1_gamma,
'13C': 10.7084,
'31P': 17.235}
H2O_PPM_TO_TMS = 4.65 # Shift of water to Tetramethylsilane
H1_gamma = 42.576 # MHz/tesla
PPM_SHIFT = {'1H': H2O_PPM_TO_TMS,
'13C': 0.0,
'31P': 0.0}
PPM_RANGE = {'1H': (0.2, 4.2),
'13C': (10, 100),
'31P': (-20, 10)}
# Concentration scaling parameters
TISSUE_WATER_DENSITY = {'GM': 0.78, 'WM': 0.65, 'CSF': 0.97}
......
......@@ -38,7 +38,7 @@ def print_params(x,mrs,metab_groups,ref_metab='Cr',scale_factor=1):
# New strategy for init
def init_params(mrs,baseline,ppmlim):
first,last = mrs.ppmlim_to_range(ppmlim)
y = mrs.getSpectrum(ppmlim=ppmlim)
y = mrs.get_spec(ppmlim=ppmlim)
y = np.concatenate((np.real(y),np.imag(y)),axis=0).flatten()
B = baseline[first:last,:].copy()
B = np.concatenate((np.real(B),np.imag(B)),axis=0)
......@@ -94,7 +94,7 @@ def init_FSLModel(mrs,metab_groups,baseline,ppmlim):
def init_params_voigt(mrs,baseline,ppmlim):
first,last = mrs.ppmlim_to_range(ppmlim)
y = mrs.getSpectrum(ppmlim=ppmlim)
y = mrs.get_spec(ppmlim=ppmlim)
y = np.concatenate((np.real(y),np.imag(y)),axis=0).flatten()
B = baseline[first:last,:].copy()
B = np.concatenate((np.real(B),np.imag(B)),axis=0)
......@@ -286,7 +286,7 @@ def fit_FSLModel(mrs,
elif model.lower() == 'voigt':
init_func = init_FSLModel_Voigt # initialisation of params
data = mrs.Spec.copy() # data copied to keep it safe
data = mrs.get_spec().copy() # data copied to keep it safe
first,last = mrs.ppmlim_to_range(ppmlim) # data range
if metab_groups is None:
......
......@@ -5,7 +5,7 @@
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
# William Clarke <william.clarke@ndcn.ox.ac.uk>
#
# Copyright (C) 2019 University of Oxford
# Copyright (C) 2019 University of Oxford
# SHBASECOPYRIGHT
import numpy as np
......@@ -13,32 +13,27 @@ import scipy.fft
from scipy.signal import butter, lfilter
from scipy.interpolate import interp1d
import itertools as it
from copy import deepcopy
from .constants import H2O_PPM_TO_TMS
from .constants import H2O_PPM_TO_TMS
# Convention:
# freq in Hz
# ppm = freq/1e6
# ppm_shift = ppm - 4.65
# why is there a minus sign here?
def ppm2hz(cf,ppm,shift=True):
def ppm2hz(cf, ppm, shift=True, shift_amount=H2O_PPM_TO_TMS):
if shift:
return (ppm-H2O_PPM_TO_TMS)*cf*1E-6
return (ppm-shift_amount)*cf*1E-6
else:
return (ppm)*cf*1E-6
def hz2ppm(cf,hz,shift=True):
def hz2ppm(cf, hz, shift=True, shift_amount=H2O_PPM_TO_TMS):
if shift:
return 1E6 *hz/cf + H2O_PPM_TO_TMS
return 1E6 * hz/cf + shift_amount
else:
return 1E6 *hz/cf
return 1E6 * hz/cf
def FIDToSpec(FID,axis=0):
def FIDToSpec(FID, axis=0):
""" Convert FID to spectrum
Performs fft along indicated axis
Args:
FID (np.array) : array of FIDs
......@@ -49,16 +44,20 @@ def FIDToSpec(FID,axis=0):
"""
# By convention the first point of the fid is special cased
ss = [slice(None) for i in range(FID.ndim)]
ss[axis] = slice(0,1)
ss = tuple(ss)
FID[ss] *=0.5
out = scipy.fft.fftshift(scipy.fft.fft(FID,axis=axis,norm='ortho'),axes=axis)
FID[ss] *=2
ss[axis] = slice(0, 1)
ss = tuple(ss)
FID[ss] *= 0.5
out = scipy.fft.fftshift(scipy.fft.fft(FID,
axis=axis,
norm='ortho'),
axes=axis)
FID[ss] *= 2
return out
def SpecToFID(spec,axis=0):
def SpecToFID(spec, axis=0):
""" Convert spectrum to FID
Performs fft along indicated axis
Args:
spec (np.array) : array of spectra
......@@ -66,30 +65,40 @@ def SpecToFID(spec,axis=0):
Returns:
x (np.array) : array of FIDs
"""
fid = scipy.fft.ifft(scipy.fft.ifftshift(spec,axes=axis),axis=axis,norm='ortho')
"""
fid = scipy.fft.ifft(scipy.fft.ifftshift(spec,
axes=axis),
axis=axis, norm='ortho')
ss = [slice(None) for i in range(fid.ndim)]
ss[axis] = slice(0,1)
ss = tuple(ss)
ss[axis] = slice(0, 1)
ss = tuple(ss)
fid[ss] *= 2
return fid
def calculateAxes(bandwidth,centralFrequency,points):
def calculateAxes(bandwidth, centralFrequency, points, shift):
dwellTime = 1/bandwidth
timeAxis = np.linspace(dwellTime,
dwellTime*points,
points)
frequencyAxis = np.linspace(-bandwidth/2,
bandwidth/2,
points)
ppmAxis = hz2ppm(centralFrequency,
frequencyAxis,shift=False)
ppmAxisShift = hz2ppm(centralFrequency,
frequencyAxis,shift=True)
return {'time':timeAxis,'freq':frequencyAxis,'ppm':ppmAxis,'ppmshift':ppmAxisShift}
def checkCFUnits(cf,units='Hz'):
timeAxis = np.linspace(dwellTime,
dwellTime * points,
points)
frequencyAxis = np.linspace(-bandwidth/2,
bandwidth/2,
points)
ppmAxis = hz2ppm(centralFrequency,
frequencyAxis,
shift=False)
ppmAxisShift = hz2ppm(centralFrequency,
frequencyAxis,
shift=True,
shift_amount=shift)
return {'time': timeAxis,
'freq': frequencyAxis,
'ppm': ppmAxis,
'ppmshift': ppmAxisShift}
def checkCFUnits(cf, units='Hz'):
""" Check the units of central frequency and adjust if required."""
# Assume cf in Hz > 1E5, if it isn't assume that user has passed in MHz
if cf<1E5:
......
......@@ -144,8 +144,8 @@ def plot_fit_new(mrs,ppmlim=(0.40,4.2)):
mrs : MRS object
ppmlim : tuple
"""
axis = np.flipud(mrs.ppmAxisFlip)
spec = np.flipud(np.fft.fftshift(mrs.Spec))
axis = mrs.ppmAxisShift
spec = np.flipud(np.fft.fftshift(mrs.get_spec()))
pred = FIDToSpec(mrs.pred)
pred = np.flipud(np.fft.fftshift(pred))
......@@ -243,7 +243,7 @@ def plot_spectrum(mrs,ppmlim=(0.0,4.5),FID=None,proj='real',c='k'):
f,l = mrs.ppmlim_to_range(ppmlim)
data = FIDToSpec(FID)[f:l]
else:
data = mrs.getSpectrum(ppmlim=ppmlim)
data = mrs.get_spec(ppmlim=ppmlim)
#m = min(np.real(data))
......@@ -265,6 +265,49 @@ def plot_spectrum(mrs,ppmlim=(0.0,4.5),FID=None,proj='real',c='k'):
return plt.gcf()
def plot_fid(mrs, tlim=None, FID=None, proj='real', c='k'):
''' Plot time domain FID'''
time_axis = mrs.getAxes(axis='time')
if FID is not None:
data = FID
else:
data = mrs.FID
data = getattr(np, proj)(data)
plt.plot(time_axis, data, color=c, linewidth=2)
if tlim is not None:
plt.xlim(tlim)
plt.xlabel('Time (s)')
plt.minorticks_on()
plt.grid(b=True, axis='x', which='major',color='k', linestyle='--', linewidth=.3)
plt.grid(b=True, axis='x', which='minor', color='k', linestyle=':',linewidth=.3)
plt.tight_layout()
return plt.gcf()
def plot_basis(mrs,plot_spec=False,ppmlim=(0.0, 4.5)):
first, last = mrs.ppmlim_to_range(ppmlim=ppmlim)
for idx, n in enumerate(mrs.names):
plt.plot(mrs.getAxes(ppmlim=ppmlim),
np.real(FID2Spec(mrs.basis[:, idx]))[first:last],
label=n)
if plot_spec:
plt.plot(mrs.getAxes(ppmlim=ppmlim),
np.real(mrs.get_spec(ppmlim=ppmlim)),
'k',label='Data')
plt.gca().invert_xaxis()
plt.xlabel('Chemical shift (ppm)')
plt.legend()
return plt.gcf()
def plot_spectra(MRSList,ppmlim=(0,4.5),single_FID=None,plot_avg=True):
plt.figure(figsize=(10,10))
......@@ -278,12 +321,12 @@ def plot_spectra(MRSList,ppmlim=(0,4.5),single_FID=None,plot_avg=True):
avg=0
for mrs in MRSList:
data = np.real(mrs.getSpectrum(ppmlim=ppmlim))
data = np.real(mrs.get_spec(ppmlim=ppmlim))
ppmAxisShift = mrs.getAxes(ppmlim=ppmlim)
avg += data
plt.plot(ppmAxisShift,data,color='k',linewidth=.5,linestyle='-')
if single_FID is not None:
data = np.real(single_FID.getSpectrum(ppmlim=ppmlim))
data = np.real(single_FID.get_spec(ppmlim=ppmlim))
plt.plot(ppmAxisShift,data,color='r',linewidth=2,linestyle='-')
if plot_avg:
avg /= len(MRSList)
......@@ -407,7 +450,7 @@ def plotly_fit(mrs,res,ppmlim=(.2,4.2),proj='real',metabs = None,phs=(0,0)):
# Prepare the data
base = FID2Spec(res.baseline)
axis = np.flipud(mrs.ppmAxisFlip)
axis = mrs.ppmAxisShift
data = FID2Spec(mrs.FID)
if ppmlim is None:
......@@ -667,7 +710,7 @@ def plot_real_imag(mrs,res,ppmlim=(.2,4.2)):
return np.abs(x)
# Prepare the data
axis = np.flipud(mrs.ppmAxisFlip)
axis = mrs.ppmAxisShift
data_real = project(FID2Spec(mrs.FID),'real')
pred_real = project(FID2Spec(res.pred),'real')
data_imag = project(FID2Spec(mrs.FID),'imag')
......@@ -777,7 +820,7 @@ def plot_indiv_stacked(mrs,res,ppmlim=(.2,4.2)):
line_size = dict(data=.5,
indiv=2)
fig = go.Figure()
axis = np.flipud(mrs.ppmAxisFlip)
axis = mrs.ppmAxisShift
y_data = np.real(FID2Spec(mrs.FID))
trace1 = go.Scatter(x=axis, y=y_data,
mode='lines',
......@@ -821,7 +864,7 @@ def plot_indiv(mrs,res,ppmlim=(.2,4.2)):
fig = make_subplots(rows=nrows, cols=ncols,subplot_titles=mrs.names)
traces = []
axis = np.flipud(mrs.ppmAxisFlip)
axis = mrs.ppmAxisShift
for i,metab in enumerate(mrs.names):
c,r = i%ncols,i//ncols
#r = i//ncols
......
......@@ -273,7 +273,7 @@ def phase_freq_align_report(inFIDs,outFIDs,hdr,phi,eps,ppmlim=None,shift=True,ht
def addline(fig,mrs,lim,name,linestyle):
trace = go.Scatter(x=mrs.getAxes(ppmlim=lim, axis=axis),
y=np.real(mrs.getSpectrum(ppmlim=lim, shift=shift)),
y=np.real(mrs.get_spec(ppmlim=lim, shift=shift)),
mode='lines',
name=name,
line=linestyle)
......@@ -399,7 +399,7 @@ def phase_freq_align_diff_report(inFIDs0,inFIDs1,outFIDs0,outFIDs1,hdr,eps,phi,p
def addline(fig,mrs,lim,name,linestyle):
trace = go.Scatter(x=mrs.getAxes(ppmlim=lim, axis=axis),
y=np.real(mrs.getSpectrum(ppmlim=lim, shift=shift)),
y=np.real(mrs.get_spec(ppmlim=lim, shift=shift)),
mode='lines',
name=name,
line=linestyle)
......
......@@ -205,7 +205,7 @@ def combine_FIDs_report(inFIDs,outFID,hdr,ncha=2,ppmlim = (0.0,6.0),method='not
def addline(fig,mrs,lim,name,linestyle):
trace = go.Scatter(x=mrs.getAxes(ppmlim=lim),
y=np.real(mrs.getSpectrum(ppmlim=lim)),
y=np.real(mrs.get_spec(ppmlim=lim)),
mode='lines',
name=name,
line=linestyle)
......@@ -293,8 +293,8 @@ def combine_FIDs_report(inFIDs,outFID,hdr,ncha=2,ppmlim = (0.0,6.0),method='not
# style = ['--']*len(colors)
# ax.set_prop_cycle(color =colors,linestyle=style)
# for fid in toPlotIn:
# ax.plot(fid.getAxes(ppmlim=ppmlim),np.real(fid.getSpectrum(ppmlim=ppmlim)))
# ax.plot(fid.getAxes(ppmlim=ppmlim),np.real(fid.get_spec(ppmlim=ppmlim)))
# for fid in toPlotOut:
# ax.plot(fid.getAxes(ppmlim=ppmlim),np.real(fid.getSpectrum(ppmlim=ppmlim)),'k-')
# ax.plot(fid.getAxes(ppmlim=ppmlim),np.real(fid.get_spec(ppmlim=ppmlim)),'k-')
# styleSpectrumAxes(ax)
# plt.show()
......@@ -49,7 +49,7 @@ def eddy_correct_report(inFID,outFID,phsRef,hdr,ppmlim = (0.2,4.2),html=None):
# Add lines to figure
def addline(fig,mrs,lim,name,linestyle):
y = np.real(mrs.getSpectrum(ppmlim=lim))
y = np.real(mrs.get_spec(ppmlim=lim))
trace = go.Scatter(x=mrs.getAxes(ppmlim=lim),
y=y,
mode='lines',
......
......@@ -56,7 +56,7 @@ def apodize_report(inFID,outFID,hdr,plotlim = (0.2,6),html=None):
# Add lines to figure
def addline(fig,mrs,lim,name,linestyle):
trace = go.Scatter(x=mrs.getAxes(ppmlim=lim),
y=np.real(mrs.getSpectrum(ppmlim=lim)),
y=np.real(mrs.get_spec(ppmlim=lim)),
mode='lines',
name=name,
line=linestyle)
......
......@@ -71,7 +71,7 @@ def add_subtract_report(inFID,inFID2,outFID,hdr,ppmlim=(0.2,4.2),function='Not s
# Add lines to figure
def addline(fig,mrs,lim,name,linestyle):
trace = go.Scatter(x=mrs.getAxes(ppmlim=lim),
y=np.real(mrs.getSpectrum(ppmlim=lim)),
y=np.real(mrs.get_spec(ppmlim=lim)),
mode='lines',
name=name,
line=linestyle)
......@@ -137,12 +137,12 @@ def generic_report(inFID,outFID,inHdr,outHdr,ppmlim = (0.2,4.2),html=None,functi
# Add lines to figure
trace1 = go.Scatter(x=plotIn.getAxes(ppmlim=ppmlim),
y=np.real(plotIn.getSpectrum(ppmlim=ppmlim)),
y=np.real(plotIn.get_spec(ppmlim=ppmlim)),
mode='lines',
name='Original',
line=lines['in'])
trace2 = go.Scatter(x=plotOut.getAxes(ppmlim=ppmlim),
y=np.real(plotOut.getSpectrum(ppmlim=ppmlim)),
y=np.real(plotOut.get_spec(ppmlim=ppmlim)),
mode='lines',
name='Shifted',
line=lines['out'])
......