Commit 1d703575 authored by William Clarke's avatar William Clarke
Browse files

Add apodisation, fixed phase and conjugation.

parent 1eaae69b
...@@ -75,6 +75,8 @@ def main(): ...@@ -75,6 +75,8 @@ def main():
' (default=0.2->4.2)') ' (default=0.2->4.2)')
align_group.add_argument('--reference', type=str, required=False, align_group.add_argument('--reference', type=str, required=False,
help='Align to this reference data.') help='Align to this reference data.')
align_group.add_argument('--apod', type=float, default=10,
help='Apodise data to reduce noise (Hz).')
alignparser.set_defaults(func=align) alignparser.set_defaults(func=align)
add_common_args(alignparser) add_common_args(alignparser)
...@@ -228,6 +230,21 @@ def main(): ...@@ -228,6 +230,21 @@ def main():
phaseparser.set_defaults(func=phase) phaseparser.set_defaults(func=phase)
add_common_args(phaseparser) add_common_args(phaseparser)
fixphaseparser = sp.add_parser('fixed_phase', add_help=False,
help='Apply fixed phase to spectrum')
fphase_group = fixphaseparser.add_argument_group('Phase arguments')
fphase_group.add_argument('--file', type=str, nargs='+', required=True,
help='Data file(s) to shift')
fphase_group.add_argument('--p0', type=float,
metavar='<degrees>',
help='Zero order phase (degrees)')
fphase_group.add_argument('--p1', type=float,
default=0.0,
metavar='<seconds>',
help='First order phase (seconds)')
fixphaseparser.set_defaults(func=fixed_phase)
add_common_args(fixphaseparser)
# subtraction - subtraction of FIDs # subtraction - subtraction of FIDs
subtractparser = sp.add_parser('subtract', add_help=False, subtractparser = sp.add_parser('subtract', add_help=False,
help='Subtract two FIDs') help='Subtract two FIDs')
...@@ -246,6 +263,14 @@ def main(): ...@@ -246,6 +263,14 @@ def main():
addparser.set_defaults(func=add) addparser.set_defaults(func=add)
add_common_args(addparser) add_common_args(addparser)
# conj - conjugation
conjparser = sp.add_parser('conj', add_help=False, help='Conjugate fids')
conj_group = conjparser.add_argument_group('Conjugation arguments')
conj_group.add_argument('--file', type=str, nargs='+', required=True,
help='Data file(s) to conjugate')
conj_group.set_defaults(func=conj)
add_common_args(conj_group)
# Parse command-line arguments # Parse command-line arguments
args = p.parse_args() args = p.parse_args()
...@@ -552,6 +577,7 @@ def align(dataobj, args): ...@@ -552,6 +577,7 @@ def align(dataobj, args):
centralFrequency, centralFrequency,
ppmlim=args['ppm'], ppmlim=args['ppm'],
niter=2, niter=2,
apodize=args['apod'],
verbose=False, verbose=False,
target=dataobj[0].reference) target=dataobj[0].reference)
...@@ -881,6 +907,38 @@ def phase(dataobj, args): ...@@ -881,6 +907,38 @@ def phase(dataobj, args):
return dataout return dataout
def fixed_phase(dataobj, args):
dataout = []
for idx, d in enumerate(dataobj):
phased = np.zeros(d.data.shape, dtype=np.complex128)
for ijk in np.ndindex(d.data.shape[:3]):
phased[ijk] = preproc.applyPhase(d.data[ijk],
(np.pi/180.0)*args['p0'])
if args['p1'] != 0.0:
phased[ijk], newDT = preproc.timeshift(
phased[ijk],
d.dataheader['dwelltime'],
args['p1'],
args['p1'])
if args['generateReports'] and \
np.prod(d.data.shape[:3]) == 1 and \
((idx in args['reportIndicies']) or args['allreports']):
from fsl_mrs.utils.preproc.general import generic_report
generic_report(d.data[ijk],
phased[ijk],
d.dataheader,
d.dataheader,
ppmlim=(0.2, 4.2),
html=args['output'],
function='fixed phase')
dataout.append(datacontainer(phased, d.dataheader, d.datafilename))
return dataout
def subtract(dataobj, args): def subtract(dataobj, args):
dataout = [] dataout = []
subtracted = preproc.subtract(dataobj[0].data, dataobj[1].data) subtracted = preproc.subtract(dataobj[0].data, dataobj[1].data)
...@@ -920,5 +978,29 @@ def add(dataobj, args): ...@@ -920,5 +978,29 @@ def add(dataobj, args):
return dataout return dataout
def conj(dataobj, args):
dataout = []
for idx, d in enumerate(dataobj):
conj = np.zeros(d.data.shape, dtype=np.complex128)
for ijk in np.ndindex(d.data.shape[:3]):
conj[ijk] = np.conj(d.data[ijk])
if args['generateReports'] and \
np.prod(d.data.shape[:3]) == 1 and \
((idx in args['reportIndicies']) or args['allreports']):
from fsl_mrs.utils.preproc.general import generic_report
generic_report(d.data[ijk],
conj[ijk],
d.dataheader,
d.dataheader,
ppmlim=(0.2, 4.2),
html=args['output'],
function='conj')
dataout.append(datacontainer(conj, d.dataheader, d.datafilename))
return dataout
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -5,17 +5,19 @@ ...@@ -5,17 +5,19 @@
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk> # Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
# William Clarke <william.clarke@ndcn.ox.ac.uk> # William Clarke <william.clarke@ndcn.ox.ac.uk>
# #
# Copyright (C) 2019 University of Oxford # Copyright (C) 2019 University of Oxford
# SHBASECOPYRIGHT # SHBASECOPYRIGHT
from fsl_mrs.utils.preproc.general import get_target_FID,add,subtract from fsl_mrs.utils.preproc.general import get_target_FID, add, subtract
from fsl_mrs.utils.preproc.filtering import apodize as apod
from fsl_mrs.core import MRS from fsl_mrs.core import MRS
from fsl_mrs.utils.misc import extract_spectrum,shift_FID from fsl_mrs.utils.misc import extract_spectrum, shift_FID
from scipy.optimize import minimize from scipy.optimize import minimize
import numpy as np import numpy as np
# Phase-Freq alignment functions # Phase-Freq alignment functions
def align_FID(mrs,src_FID,tgt_FID,ppmlim=None,shift=True): def align_FID(mrs, src_FID, tgt_FID, ppmlim=None, shift=True):
""" """
Phase and frequency alignment Phase and frequency alignment
...@@ -32,24 +34,26 @@ def align_FID(mrs,src_FID,tgt_FID,ppmlim=None,shift=True): ...@@ -32,24 +34,26 @@ def align_FID(mrs,src_FID,tgt_FID,ppmlim=None,shift=True):
""" """
# Internal functions so they can see globals # Internal functions so they can see globals
def shift_phase_freq(FID,phi,eps,extract=True): def shift_phase_freq(FID, phi, eps, extract=True):
sFID = np.exp(-1j*phi)*shift_FID(mrs,FID,eps) sFID = np.exp(-1j*phi)*shift_FID(mrs, FID, eps)
if extract: if extract:
sFID = extract_spectrum(mrs,sFID,ppmlim=ppmlim,shift=shift) sFID = extract_spectrum(mrs, sFID, ppmlim=ppmlim, shift=shift)
return sFID return sFID
def cf(p): def cf(p):
phi = p[0] #phase shift phi = p[0] # phase shift
eps = p[1] #freq shift eps = p[1] # freq shift
FID = shift_phase_freq(src_FID,phi,eps) FID = shift_phase_freq(src_FID, phi, eps)
target = extract_spectrum(mrs,tgt_FID,ppmlim=ppmlim,shift=shift) target = extract_spectrum(mrs, tgt_FID, ppmlim=ppmlim, shift=shift)
xx = np.linalg.norm(FID-target) xx = np.linalg.norm(FID-target)
return xx return xx
x0 = np.array([0,0]) x0 = np.array([0, 0])
res = minimize(cf, x0, method='Powell') res = minimize(cf, x0, method='Powell')
phi = res.x[0] phi = res.x[0]
eps = res.x[1] eps = res.x[1]
return shift_phase_freq(src_FID,phi,eps,extract=False),phi,eps return phi, eps
def align_FID_diff(mrs,src_FID0,src_FID1,tgt_FID,diffType = 'add',ppmlim=None,shift=True): def align_FID_diff(mrs,src_FID0,src_FID1,tgt_FID,diffType = 'add',ppmlim=None,shift=True):
""" """
...@@ -98,9 +102,10 @@ def align_FID_diff(mrs,src_FID0,src_FID1,tgt_FID,diffType = 'add',ppmlim=None,sh ...@@ -98,9 +102,10 @@ def align_FID_diff(mrs,src_FID0,src_FID1,tgt_FID,diffType = 'add',ppmlim=None,sh
return alignedFID0,phi,eps return alignedFID0,phi,eps
# The functions to call # The functions to call
# 1) For normal FIDs # 1) For normal FIDs
def phase_freq_align(FIDlist,bandwidth,centralFrequency,ppmlim=None,niter=2,verbose=False,shift=True,target=None): def phase_freq_align(FIDlist,bandwidth,centralFrequency,ppmlim=None,niter=2,apodize=10,verbose=False,shift=True,target=None):
""" """
Algorithm: Algorithm:
Average spectra Average spectra
...@@ -118,6 +123,7 @@ def phase_freq_align(FIDlist,bandwidth,centralFrequency,ppmlim=None,niter=2,verb ...@@ -118,6 +123,7 @@ def phase_freq_align(FIDlist,bandwidth,centralFrequency,ppmlim=None,niter=2,verb
centralFrequency : float (unit=Hz) centralFrequency : float (unit=Hz)
ppmlim : tuple ppmlim : tuple
niter : int niter : int
apodize : float (unit=Hz)
verbose : bool verbose : bool
shift : apply H20 shift to ppm limit shift : apply H20 shift to ppm limit
ref : reference data to align to ref : reference data to align to
...@@ -128,26 +134,41 @@ def phase_freq_align(FIDlist,bandwidth,centralFrequency,ppmlim=None,niter=2,verb ...@@ -128,26 +134,41 @@ def phase_freq_align(FIDlist,bandwidth,centralFrequency,ppmlim=None,niter=2,verb
""" """
all_FIDs = FIDlist.copy() all_FIDs = FIDlist.copy()
phiOut,epsOut = np.zeros(len(FIDlist)),np.zeros(len(FIDlist)) phiOut, epsOut = np.zeros(len(FIDlist)), np.zeros(len(FIDlist))
for iter in range(niter): for iter in range(niter):
if verbose: if verbose:
print(' ---- iteration {} ----\n'.format(iter)) print(' ---- iteration {} ----\n'.format(iter))
if target is None: if target is None:
target = get_target_FID(FIDlist,target='nearest_to_mean') target = get_target_FID(all_FIDs, target='nearest_to_mean')
MRSargs = {'FID':target,'bw':bandwidth,'cf':centralFrequency} MRSargs = {'FID': target, 'bw': bandwidth, 'cf': centralFrequency}
mrs = MRS(**MRSargs) mrs = MRS(**MRSargs)
for idx,FID in enumerate(all_FIDs): if apodize > 0:
target = apod(target, mrs.dwellTime, [apodize])
for idx, FID in enumerate(all_FIDs):
if verbose: if verbose:
print('... aligning FID number {}'.format(idx),end='\r') print(f'... aligning FID number {idx}\r')
all_FIDs[idx],phi,eps = align_FID(mrs,FID,target,ppmlim=ppmlim,shift=shift)
if apodize > 0:
FID_apod = apod(FID.copy(), mrs.dwellTime, [apodize])
else:
FID_apod = FID
phi, eps = align_FID(mrs,
FID_apod,
target,
ppmlim=ppmlim,
shift=shift)
all_FIDs[idx] = np.exp(-1j*phi) * shift_FID(mrs, FID, eps)
phiOut[idx] += phi phiOut[idx] += phi
epsOut[idx] += eps epsOut[idx] += eps
if verbose: if verbose:
print('\n') print('\n')
return all_FIDs,phiOut,epsOut return all_FIDs, phiOut, epsOut
# 2) To align spectra from different groups with optional processing applied. # 2) To align spectra from different groups with optional processing applied.
def phase_freq_align_diff(FIDlist0,FIDlist1,bandwidth,centralFrequency,diffType = 'add',ppmlim=None,shift=True,target=None): def phase_freq_align_diff(FIDlist0,FIDlist1,bandwidth,centralFrequency,diffType = 'add',ppmlim=None,shift=True,target=None):
......
...@@ -16,7 +16,7 @@ def apodize(FID,dwelltime,broadening,filter='exp'): ...@@ -16,7 +16,7 @@ def apodize(FID,dwelltime,broadening,filter='exp'):
Args: Args:
FID (ndarray): Time domain data FID (ndarray): Time domain data
dwelltime (float): dwelltime in seconds dwelltime (float): dwelltime in seconds
broadening (tuple,float): shift in Hz broadening (tuple,float): apodisation in Hz
filter (str,optional):'exp','l2g' filter (str,optional):'exp','l2g'
Returns: Returns:
......
...@@ -115,3 +115,90 @@ def add_subtract_report(inFID,inFID2,outFID,hdr,ppmlim=(0.2,4.2),function='Not s ...@@ -115,3 +115,90 @@ def add_subtract_report(inFID,inFID2,outFID,hdr,ppmlim=(0.2,4.2),function='Not s
return fig return fig
else: else:
return fig return fig
def generic_report(inFID,outFID,inHdr,outHdr,ppmlim = (0.2,4.2),html=None,function=''):
"""
Generate generic report
"""
from fsl_mrs.core import MRS
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from fsl_mrs.utils.preproc.reporting import plotStyles, plotAxesStyle
plotIn = MRS(FID=inFID, header=inHdr)
plotOut = MRS(FID=outFID, header=outHdr)
# Fetch line styles
lines, colors, _ = plotStyles()
# Make a new figure
fig = make_subplots(rows=1, cols=2, subplot_titles=['Spectra', 'FID'])
# Add lines to figure
trace1 = go.Scatter(x=plotIn.getAxes(ppmlim=ppmlim),
y=np.real(plotIn.getSpectrum(ppmlim=ppmlim)),
mode='lines',
name='Original',
line=lines['in'])
trace2 = go.Scatter(x=plotOut.getAxes(ppmlim=ppmlim),
y=np.real(plotOut.getSpectrum(ppmlim=ppmlim)),
mode='lines',
name='Shifted',
line=lines['out'])
fig.add_trace(trace1, row=1, col=1)
fig.add_trace(trace2, row=1, col=1)
# Add lines to figure
trace3 = go.Scatter(x=plotIn.getAxes(axis='time'),
y=np.real(plotIn.FID),
mode='lines',
name='Original',
line=lines['emph'])
trace4 = go.Scatter(x=plotOut.getAxes(axis='time'),
y=np.real(plotOut.FID),
mode='lines',
name='Shifted',
line=lines['diff'])
fig.add_trace(trace3, row=1, col=2)
fig.add_trace(trace4, row=1, col=2)
# Axes layout
plotAxesStyle(fig, ppmlim, title=f'{function} summary')
fig.layout.xaxis2.update(title_text='Time (s)')
fig.layout.yaxis2.update(zeroline=True,
zerolinewidth=1,
zerolinecolor='Gray',
showgrid=False,
showticklabels=False)
if html is not None:
from plotly.offline import plot
from fsl_mrs.utils.preproc.reporting import figgroup, singleReport
from datetime import datetime
import os.path as op
if op.isdir(html):
filename = 'report_' + datetime.now().strftime("%Y%m%d_%H%M%S%f")[:-3]+'.html'
htmlfile = op.join(html, filename)
elif op.isdir(op.dirname(html)) and op.splitext(html)[1] == '.html':
htmlfile = html
else:
raise ValueError('Report html path must be file or directory. ')
opName = function
timestr = datetime.now().strftime("%H:%M:%S")
datestr = datetime.now().strftime("%d/%m/%Y")
headerinfo = f'Report for {function}.\n' + \
f'Generated at {timestr} on {datestr}.'
# Figures
div = plot(fig, output_type='div', include_plotlyjs='cdn')
figurelist = [figgroup(fig=div,
name='',
foretext='',
afttext='')]
singleReport(htmlfile, opName, headerinfo, figurelist)
return fig
else:
return fig
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk> # Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
# William Clarke <william.clarke@ndcn.ox.ac.uk> # William Clarke <william.clarke@ndcn.ox.ac.uk>
# #
# Copyright (C) 2019 University of Oxford # Copyright (C) 2019 University of Oxford
# SHBASECOPYRIGHT # SHBASECOPYRIGHT
import numpy as np import numpy as np
...@@ -13,13 +13,16 @@ from fsl_mrs.core import MRS ...@@ -13,13 +13,16 @@ from fsl_mrs.core import MRS
from fsl_mrs.utils.misc import extract_spectrum from fsl_mrs.utils.misc import extract_spectrum
from fsl_mrs.utils.preproc.shifting import pad from fsl_mrs.utils.preproc.shifting import pad
from fsl_mrs.utils.preproc.remove import hlsvd from fsl_mrs.utils.preproc.remove import hlsvd
def applyPhase(FID,phaseAngle):
def applyPhase(FID, phaseAngle):
""" """
Multiply FID by constant phase Multiply FID by constant phase
""" """
return FID * np.exp(1j*phaseAngle) return FID * np.exp(1j*phaseAngle)
def phaseCorrect(FID,bw,cf,ppmlim=(2.8,3.2),shift=True):
def phaseCorrect(FID, bw, cf, ppmlim=(2.8, 3.2), shift=True, no_hlsvd=False):
""" Phase correction based on the phase of a maximum point. """ Phase correction based on the phase of a maximum point.
HLSVD is used to remove peaks outside the limits to flatten baseline first. HLSVD is used to remove peaks outside the limits to flatten baseline first.
...@@ -30,18 +33,22 @@ def phaseCorrect(FID,bw,cf,ppmlim=(2.8,3.2),shift=True): ...@@ -30,18 +33,22 @@ def phaseCorrect(FID,bw,cf,ppmlim=(2.8,3.2),shift=True):
cf (float): central frequency in Hz cf (float): central frequency in Hz
ppmlim (tuple,optional) : Limit to this ppm range ppmlim (tuple,optional) : Limit to this ppm range
shift (bool,optional) : Apply H20 shft shift (bool,optional) : Apply H20 shft
no_hlsvd (bool,optional) : Disable hlsvd step
Returns: Returns:
FID (ndarray): Phase corrected FID FID (ndarray): Phase corrected FID
""" """
# Run HLSVD to remove peaks outside limits # Run HLSVD to remove peaks outside limits
try: if no_hlsvd:
fid_hlsvd = hlsvd(FID,1/bw,cf,(ppmlim[1]+0.5,ppmlim[1]+3.0),limitUnits='ppm+shift')
fid_hlsvd = hlsvd(fid_hlsvd,1/bw,cf,(ppmlim[0]-3.0,ppmlim[0]-0.5),limitUnits='ppm+shift')
except:
fid_hlsvd = FID fid_hlsvd = FID
print('Phasing HLSVD failed, proceeding to phaseing.') else:
#Find maximum of absolute spectrum in ppm limit try:
fid_hlsvd = hlsvd(FID,1/bw,cf,(ppmlim[1]+0.5,ppmlim[1]+3.0),limitUnits='ppm+shift')
fid_hlsvd = hlsvd(fid_hlsvd,1/bw,cf,(ppmlim[0]-3.0,ppmlim[0]-0.5),limitUnits='ppm+shift')
except:
fid_hlsvd = FID
print('HLSVD in phaseCorrect failed, proceeding to phasing.')
# Find maximum of absolute spectrum in ppm limit
padFID = pad(fid_hlsvd,FID.size*3) padFID = pad(fid_hlsvd,FID.size*3)
MRSargs = {'FID':padFID,'bw':bw,'cf':cf} MRSargs = {'FID':padFID,'bw':bw,'cf':cf}
mrs = MRS(**MRSargs) mrs = MRS(**MRSargs)
...@@ -52,9 +59,10 @@ def phaseCorrect(FID,bw,cf,ppmlim=(2.8,3.2),shift=True): ...@@ -52,9 +59,10 @@ def phaseCorrect(FID,bw,cf,ppmlim=(2.8,3.2),shift=True):
return applyPhase(FID,phaseAngle),phaseAngle,int(np.round(maxIndex/4)) return applyPhase(FID,phaseAngle),phaseAngle,int(np.round(maxIndex/4))
def phaseCorrect_report(inFID,outFID,hdr,position,ppmlim=(2.8,3.2),html=None): def phaseCorrect_report(inFID,outFID,hdr,position,ppmlim=(2.8,3.2),html=None):
""" """
Generate report Generate report for phaseCorrect
""" """
# from matplotlib import pyplot as plt # from matplotlib import pyplot as plt
from fsl_mrs.core import MRS from fsl_mrs.core import MRS
...@@ -103,7 +111,7 @@ def phaseCorrect_report(inFID,outFID,hdr,position,ppmlim=(2.8,3.2),html=None): ...@@ -103,7 +111,7 @@ def phaseCorrect_report(inFID,outFID,hdr,position,ppmlim=(2.8,3.2),html=None):
# Axes layout # Axes layout
plotAxesStyle(fig,widelimit,title = 'Phase correction summary') plotAxesStyle(fig,widelimit,title = 'Phase correction summary')
# Axea # Axes
if html is not None: if html is not None:
from plotly.offline import plot from plotly.offline import plot
from fsl_mrs.utils.preproc.reporting import figgroup, singleReport from fsl_mrs.utils.preproc.reporting import figgroup, singleReport
...@@ -134,25 +142,3 @@ def phaseCorrect_report(inFID,outFID,hdr,position,ppmlim=(2.8,3.2),html=None): ...@@ -134,25 +142,3 @@ def phaseCorrect_report(inFID,outFID,hdr,position,ppmlim=(2.8,3.2),html=None):
return fig return fig
else: else:
return fig return fig
# matplotlib version of report
# def phaseCorrect_report(inFID,outFID,hdr,position,ppmlim=(2.8,3.2)):
# from matplotlib import pyplot as plt
# from fsl_mrs.core import MRS
# from fsl_mrs.utils.plotting import styleSpectrumAxes
# toMRSobj = lambda fid : MRS(FID=fid,header=hdr)
# plotIn = toMRSobj(inFID)
# plotOut = toMRSobj(outFID)
# widelimit = (0,6)
# fig = plt.figure(figsize=(10,10))
# plt.plot(plotIn.getAxes(ppmlim=widelimit),np.real(plotIn.getSpectrum(ppmlim=widelimit)),'k',label='Unphased', linewidth=2)
# plt.plot(plotIn.getAxes(ppmlim=ppmlim),np.real(plotIn.getSpectrum(ppmlim=ppmlim)),'r',label='search region', linewidth=2)
# plt.plot(plotIn.getAxes(ppmlim=ppmlim)[position],np.real(plotIn.getSpectrum(ppmlim=ppmlim))[position],'rx',label='max point', linewidth=2)
# plt.plot(plotOut.getAxes(ppmlim=widelimit),np.real(plotOut.getSpectrum(ppmlim=widelimit)),'b--',label='Phased', linewidth=2)
# styleSpectrumAxes(ax=plt.gca())
# plt.legend()
# plt.rcParams.update({'font.size': 12})
# plt.show()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment