Commit 61089ace authored by William Clarke's avatar William Clarke
Browse files

Translation layer for NIFTI MRS processing complete.

parent f4632295
......@@ -20,6 +20,11 @@ class NIFTIMRS_DimDoesntExist(Exception):
class NIFTI_MRS(Image):
"""Load NIFTI MRS format data. Derived from nibabel's Nifti2Image."""
def __init__(self, *args, **kwargs):
# If generated from np array include conjugation
# to make sure storage is right-handed
if isinstance(args[0], np.ndarray):
args = list(args)
args[0] = args[0].conj()
super().__init__(*args, **kwargs)
# Extract key parameters from the header extension
......
#!/usr/bin/env python
# general.py - General preprocessing functions
#
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
# William Clarke <william.clarke@ndcn.ox.ac.uk>
#
# Copyright (C) 2019 University of Oxford
# Copyright (C) 2019 University of Oxford
# SHBASECOPYRIGHT
import numpy as np
from dataclasses import dataclass
@dataclass
class datacontainer:
'''Class for keeping track of data and reference data together.'''
......@@ -22,7 +20,8 @@ class datacontainer:
refheader: dict = None
reffilename: str = None
def get_target_FID(FIDlist,target='mean'):
def get_target_FID(FIDlist, target='mean'):
"""
target can be 'mean' or 'first' or 'nearest_to_mean' or 'median'
"""
......@@ -32,59 +31,73 @@ def get_target_FID(FIDlist,target='mean'):
return FIDlist[0].copy()
elif target == 'nearest_to_mean':
avg = sum(FIDlist) / len(FIDlist)
d = [np.linalg.norm(fid-avg) for fid in FIDlist]
d = [np.linalg.norm(fid - avg) for fid in FIDlist]
return FIDlist[np.argmin(d)].copy()
elif target == 'median':
return np.median(np.real(np.asarray(FIDlist)),axis=0)+1j*np.median(np.imag(np.asarray(FIDlist)),axis=0)
return np.median(np.real(np.asarray(FIDlist)), axis=0)\
+ 1j * np.median(np.imag(np.asarray(FIDlist)), axis=0)
else:
raise(Exception('Unknown target type {}'.format(target)))
def subtract(FID1,FID2):
def subtract(FID1, FID2):
""" Subtract FID2 from FID1."""
return (FID1-FID2)/2.0
return (FID1 - FID2) / 2.0
def add(FID1,FID2):
""" Add FID2 to FID1."""
return (FID1+FID2)/2.0
def add_subtract_report(inFID,inFID2,outFID,hdr,ppmlim=(0.2,4.2),function='Not specified',html=None):
def add(FID1, FID2):
""" Add FID2 to FID1."""
return (FID1 + FID2) / 2.0
def add_subtract_report(inFID,
inFID2,
outFID,
bw,
cf,
nucleus='1H',
ppmlim=(0.2, 4.2),
function='Not specified',
html=None):
"""
Generate report
"""
# from matplotlib import pyplot as plt
from fsl_mrs.core import MRS
import plotly.graph_objects as go
from fsl_mrs.utils.preproc.reporting import plotStyles,plotAxesStyle
import plotly.graph_objects as go
from fsl_mrs.utils.preproc.reporting import plotStyles, plotAxesStyle
# Turn input FIDs into mrs objects
toMRSobj = lambda fid : MRS(FID=fid,header=hdr)
def toMRSobj(fid):
return MRS(FID=fid, cf=cf, bw=bw, nucleus=nucleus)
plotIn = toMRSobj(inFID)
plotIn2 = toMRSobj(inFID2)
plotOut = toMRSobj(outFID)
# Fetch line styles
lines,colors,_ = plotStyles()
lines, colors, _ = plotStyles()
# Make a new figure
fig = go.Figure()
# Add lines to figure
def addline(fig,mrs,lim,name,linestyle):
def addline(fig, mrs, lim, name, linestyle):
trace = go.Scatter(x=mrs.getAxes(ppmlim=lim),
y=np.real(mrs.get_spec(ppmlim=lim)),
mode='lines',
name=name,
line=linestyle)
return fig.add_trace(trace)
fig = addline(fig,plotIn,ppmlim,'FID1',lines['in'])
fig = addline(fig,plotIn2,ppmlim,'FID2',lines['out'])
fig = addline(fig,plotOut,ppmlim,'Result',lines['diff'])
y=np.real(mrs.get_spec(ppmlim=lim)),
mode='lines',
name=name,
line=linestyle)
return fig.add_trace(trace)
fig = addline(fig, plotIn, ppmlim, 'FID1', lines['in'])
fig = addline(fig, plotIn2, ppmlim, 'FID2', lines['out'])
fig = addline(fig, plotOut, ppmlim, 'Result', lines['diff'])
# Axes layout
plotAxesStyle(fig,ppmlim,title = f'{function} summary')
# Axea
plotAxesStyle(fig, ppmlim, title=f'{function} summary')
# Axes
if html is not None:
from plotly.offline import plot
from fsl_mrs.utils.preproc.reporting import figgroup, singleReport
......@@ -92,32 +105,38 @@ def add_subtract_report(inFID,inFID2,outFID,hdr,ppmlim=(0.2,4.2),function='Not s
import os.path as op
if op.isdir(html):
filename = 'report_' + datetime.now().strftime("%Y%m%d_%H%M%S%f")[:-3]+'.html'
htmlfile=op.join(html,filename)
elif op.isdir(op.dirname(html)) and op.splitext(html)[1]=='.html':
filename = 'report_' + datetime.now().strftime("%Y%m%d_%H%M%S%f")[:-3] + '.html'
htmlfile = op.join(html, filename)
elif op.isdir(op.dirname(html)) and op.splitext(html)[1] == '.html':
htmlfile = html
else:
raise ValueError('Report html path must be file or directory. ')
opName = function
timestr = datetime.now().strftime("%H:%M:%S")
datestr = datetime.now().strftime("%d/%m/%Y")
headerinfo = f'Report for fsl_mrs.utils.preproc.general.{function}.\n'\
+ f'Generated at {timestr} on {datestr}.'
+ f'Generated at {timestr} on {datestr}.'
# Figures
div = plot(fig, output_type='div',include_plotlyjs='cdn')
figurelist = [figgroup(fig = div,
name= '',
foretext= f'Report for {function}.',
afttext= f'')]
div = plot(fig, output_type='div', include_plotlyjs='cdn')
figurelist = [figgroup(fig=div,
name='',
foretext=f'Report for {function}.',
afttext='')]
singleReport(htmlfile,opName,headerinfo,figurelist)
singleReport(htmlfile, opName, headerinfo, figurelist)
return fig
else:
return fig
def generic_report(inFID,outFID,inHdr,outHdr,ppmlim = (0.2,4.2),html=None,function=''):
def generic_report(inFID,
outFID,
inHdr,
outHdr,
ppmlim=(0.2, 4.2),
html=None,
function=''):
"""
Generate generic report
"""
......@@ -179,7 +198,7 @@ def generic_report(inFID,outFID,inHdr,outHdr,ppmlim = (0.2,4.2),html=None,functi
import os.path as op
if op.isdir(html):
filename = 'report_' + datetime.now().strftime("%Y%m%d_%H%M%S%f")[:-3]+'.html'
filename = 'report_' + datetime.now().strftime("%Y%m%d_%H%M%S%f")[:-3] + '.html'
htmlfile = op.join(html, filename)
elif op.isdir(op.dirname(html)) and op.splitext(html)[1] == '.html':
htmlfile = html
......
......@@ -509,18 +509,223 @@ def remove_unlike(data, ppmlim=None, sdlimit=1.96, niter=2, report=None):
goodFIDs = np.asarray(goodFIDs).T
goodFIDs = goodFIDs.reshape([1, 1, 1] + list(goodFIDs.shape))
# Conjugation here as it doesn't use the __setitem__ method
good_out = NIFTI_MRS(
goodFIDs.conj(),
goodFIDs,
header=data.header)
if len(badFIDs) > 0:
badFIDs = np.asarray(badFIDs).T
badFIDs = badFIDs.reshape([1, 1, 1] + list(badFIDs.shape))
bad_out = NIFTI_MRS(
badFIDs.conj(),
badFIDs,
header=data.header)
else:
bad_out = None
return good_out, bad_out
def phase_correct(data, ppmlim, hlsvd=True, report=None, report_all=False):
'''Zero-order phase correct based on peak maximum
:param NIFTI_MRS data: Data to truncate or pad
:param float ppmlim: Search for peak between limits
:param bool hlsvd: Use HLSVD to remove peaks outside the ppmlim
:param report: Provide output location as path to generate report
:param report_all: True to output all indicies
:return: Phased data in NIFTI_MRS format.
'''
phs_obj = data.copy()
for dd, idx in data.iterate_over_dims(iterate_over_space=True):
phs_obj[idx], _, pos = preproc.phaseCorrect(
dd,
data.bandwidth,
data.spectrometer_frequency[0],
nucleus=data.nucleus[0],
ppmlim=ppmlim,
use_hlsvd=hlsvd)
if report and (report_all or first_index(idx)):
from fsl_mrs.utils.preproc.phasing import phaseCorrect_report
phaseCorrect_report(dd,
phs_obj[idx],
pos,
data.bandwidth,
data.spectrometer_frequency[0],
nucleus=data.nucleus[0],
ppmlim=ppmlim,
html=report)
return phs_obj
def apply_fixed_phase(data, p0, p1=0.0, report=None, report_all=False):
'''Apply fixed phase correction
:param NIFTI_MRS data: Data to truncate or pad
:param float p0: Zero order phase correction in degrees
:param float p0: First order phase correction in seconds
:param report: Provide output location as path to generate report
:param report_all: True to output all indicies
:return: Phased data in NIFTI_MRS format.
'''
phs_obj = data.copy()
for dd, idx in data.iterate_over_dims(iterate_over_space=True):
phs_obj[idx] = preproc.applyPhase(dd,
p0 * (np.pi / 180.0))
if p1 != 0.0:
phs_obj[idx], _ = preproc.timeshift(
phs_obj[idx],
data.dwelltime,
p1,
p1,
samples=data.shape[3])
if report and (report_all or first_index(idx)):
from fsl_mrs.utils.preproc.general import generic_report
original_hdr = {'bandwidth': data.bandwidth,
'centralFrequency': data.spectrometer_frequency[0],
'ResonantNucleus': data.nucleus[0]}
generic_report(dd,
phs_obj[idx],
original_hdr,
original_hdr,
ppmlim=(0.2, 4.2),
html=report,
function='fixed phase')
return phs_obj
def subtract(data0, data1=None, dim=None, report=None, report_all=False):
'''Either subtract data1 from data0 or subtract index 1 from
index 0 along specified dimension
:param NIFTI_MRS data: Data to truncate or pad
:param data1: If specified data1 will be subtracted from data0
:param dim: If specified index 1 will be subtracted from 0 across this dimension.
:param report: Provide output location as path to generate report
:param report_all: True to output all indicies
:return: Subtracted data in NIFTI_MRS format.
'''
if dim is not None:
# Check dim is of correct size
if data0.shape[data0.dim_tags.index(dim)] != 2:
raise DimensionsDoNotMatch('Subtraction dimension must be of length 2.'
f' Currently {data0.shape[data0.dim_tags.index(dim)]}')
sub_ob = data0.copy(remove_dim=dim)
for dd, idx in data0.iterate_over_dims(dim=dim,
iterate_over_space=True,
reduce_dim_index=True):
sub_ob[idx] = preproc.subtract(dd.T[0], dd.T[1])
if report and (report_all or first_index(idx)):
from fsl_mrs.utils.preproc.general import generic_report
original_hdr = {'bandwidth': data0.bandwidth,
'centralFrequency': data0.spectrometer_frequency[0],
'ResonantNucleus': data0.nucleus[0]}
generic_report(dd,
sub_ob[idx],
original_hdr,
original_hdr,
ppmlim=(0.2, 4.2),
html=report,
function='subtract')
elif data1 is not None:
sub_ob = data0.copy()
sub_ob[:] = (data0[:] - data1[:]) / 2
else:
raise ValueError('One of data1 or dim arguments must not be None.')
return sub_ob
def add(data0, data1=None, dim=None, report=None, report_all=False):
'''Either add data1 to data0 or add index 1 to
index 0 along specified dimension
:param NIFTI_MRS data: Data to truncate or pad
:param data1: If specified data1 will be added to data0
:param dim: If specified index 1 will be added to 0 across this dimension.
:param report: Provide output location as path to generate report
:param report_all: True to output all indicies
:return: Subtracted data in NIFTI_MRS format.
'''
if dim is not None:
# Check dim is of correct size
if data0.shape[data0.dim_tags.index(dim)] != 2:
raise DimensionsDoNotMatch('Addition dimension must be of length 2.'
f' Currently {data0.shape[data0.dim_tags.index(dim)]}')
add_ob = data0.copy(remove_dim=dim)
for dd, idx in data0.iterate_over_dims(dim=dim,
iterate_over_space=True,
reduce_dim_index=True):
add_ob[idx] = preproc.add(dd.T[0], dd.T[1])
if report and (report_all or first_index(idx)):
from fsl_mrs.utils.preproc.general import generic_report
original_hdr = {'bandwidth': data0.bandwidth,
'centralFrequency': data0.spectrometer_frequency[0],
'ResonantNucleus': data0.nucleus[0]}
generic_report(dd,
add_ob[idx],
original_hdr,
original_hdr,
ppmlim=(0.2, 4.2),
html=report,
function='add')
elif data1 is not None:
add_ob = data0.copy()
add_ob[:] = (data0[:] + data1[:]) / 2
else:
raise ValueError('One of data1 or dim arguments must not be None.')
return add_ob
def conjugate(data, report=None, report_all=False):
'''Conjugate the data
:param NIFTI_MRS data: Data to truncate or pad
:param report: Provide output location as path to generate report
:param report_all: True to output all indicies
:return: Conjugated data in NIFTI_MRS format.
'''
conj_data = data.copy()
conj_data[:] = conj_data[:].conj()
if report:
for dd, idx in data.iterate_over_dims(iterate_over_space=True):
if report_all or first_index(idx):
from fsl_mrs.utils.preproc.general import generic_report
original_hdr = {'bandwidth': data.bandwidth,
'centralFrequency': data.spectrometer_frequency[0],
'ResonantNucleus': data.nucleus[0]}
generic_report(dd,
conj_data[idx],
original_hdr,
original_hdr,
ppmlim=(0.2, 4.2),
html=report,
function='conjugate')
return conj_data
......@@ -10,7 +10,7 @@
import numpy as np
from fsl_mrs.core import MRS
from fsl_mrs.utils.misc import extract_spectrum
from fsl_mrs.utils.misc import extract_spectrum, checkCFUnits
from fsl_mrs.utils.preproc.shifting import pad
from fsl_mrs.utils.preproc.remove import hlsvd
......@@ -19,10 +19,10 @@ def applyPhase(FID, phaseAngle):
"""
Multiply FID by constant phase
"""
return FID * np.exp(1j*phaseAngle)
return FID * np.exp(1j * phaseAngle)
def phaseCorrect(FID, bw, cf, ppmlim=(2.8, 3.2), shift=True, hlsvd=False):
def phaseCorrect(FID, bw, cf, nucleus='1H', ppmlim=(2.8, 3.2), shift=True, use_hlsvd=False):
""" Phase correction based on the phase of a maximum point.
HLSVD is used to remove peaks outside the limits to flatten baseline first.
......@@ -33,70 +33,84 @@ def phaseCorrect(FID, bw, cf, ppmlim=(2.8, 3.2), shift=True, hlsvd=False):
cf (float): central frequency in Hz
ppmlim (tuple,optional) : Limit to this ppm range
shift (bool,optional) : Apply H20 shft
hlsvd (bool,optional) : Enable hlsvd step
use_hlsvd (bool,optional) : Enable hlsvd step
Returns:
FID (ndarray): Phase corrected FID
phaseAngle (double): shift in radians
index (int): Index of phased point
"""
if hlsvd:
cf = checkCFUnits(cf, units='Hz')
if use_hlsvd:
# Run HLSVD to remove peaks outside limits
try:
fid_hlsvd = hlsvd(FID,1/bw,cf,(ppmlim[1]+0.5,ppmlim[1]+3.0),limitUnits='ppm+shift')
fid_hlsvd = hlsvd(fid_hlsvd,1/bw,cf,(ppmlim[0]-3.0,ppmlim[0]-0.5),limitUnits='ppm+shift')
except:
fid_hlsvd = hlsvd(FID, 1 / bw, cf, (ppmlim[1] + 0.5, ppmlim[1] + 3.0), limitUnits='ppm+shift')
fid_hlsvd = hlsvd(fid_hlsvd, 1 / bw, cf, (ppmlim[0] - 3.0, ppmlim[0] - 0.5), limitUnits='ppm+shift')
except Exception:
fid_hlsvd = FID
print('HLSVD in phaseCorrect failed, proceeding to phasing.')
else:
fid_hlsvd = FID
# Find maximum of absolute spectrum in ppm limit
padFID = pad(fid_hlsvd,FID.size*3)
MRSargs = {'FID':padFID,'bw':bw,'cf':cf}
padFID = pad(fid_hlsvd, FID.size * 3)
MRSargs = {'FID': padFID,
'bw': bw,
'cf': cf,
'nucleus': nucleus}
mrs = MRS(**MRSargs)
spec = extract_spectrum(mrs,padFID,ppmlim=ppmlim,shift=shift)
spec = extract_spectrum(mrs, padFID, ppmlim=ppmlim, shift=shift)
maxIndex = np.argmax(np.abs(spec))
phaseAngle = -np.angle(spec[maxIndex])
return applyPhase(FID,phaseAngle),phaseAngle,int(np.round(maxIndex/4))
return applyPhase(FID, phaseAngle), phaseAngle, int(np.round(maxIndex / 4))
def phaseCorrect_report(inFID,outFID,hdr,position,ppmlim=(2.8,3.2),html=None):
def phaseCorrect_report(inFID,
outFID,
position,
bw,
cf,
nucleus='1H',
ppmlim=(2.8, 3.2),
html=None):
"""
Generate report for phaseCorrect
"""
# from matplotlib import pyplot as plt
from fsl_mrs.core import MRS
import plotly.graph_objects as go
from fsl_mrs.utils.preproc.reporting import plotStyles,plotAxesStyle
import plotly.graph_objects as go
from fsl_mrs.utils.preproc.reporting import plotStyles, plotAxesStyle
# Turn input FIDs into mrs objects
toMRSobj = lambda fid : MRS(FID=fid,header=hdr)
def toMRSobj(fid):
return MRS(FID=fid, cf=cf, bw=bw, nucleus=nucleus)
plotIn = toMRSobj(inFID)
plotOut = toMRSobj(outFID)
widelimit = (0,6)
widelimit = (0, 6)
# Fetch line styles
lines,colors,_ = plotStyles()
lines, colors, _ = plotStyles()
# Make a new figure
fig = go.Figure()
# Add lines to figure
def addline(fig,mrs,lim,name,linestyle):
def addline(fig, mrs, lim, name, linestyle):
trace = go.Scatter(x=mrs.getAxes(ppmlim=lim),
y=np.real(mrs.get_spec(ppmlim=lim)),
mode='lines',
name=name,
line=linestyle)
return fig.add_trace(trace)
fig = addline(fig,plotIn,widelimit,'Unphased',lines['in'])
fig = addline(fig,plotIn,ppmlim,'Search region',lines['emph'])
y=np.real(mrs.get_spec(ppmlim=lim)),
mode='lines',
name=name,
line=linestyle)
return fig.add_trace(trace)
fig = addline(fig, plotIn, widelimit, 'Unphased', lines['in'])
fig = addline(fig, plotIn, ppmlim, 'Search region', lines['emph'])
if position is None:
# re-estimate here.
......@@ -105,17 +119,17 @@ def phaseCorrect_report(inFID,outFID,hdr,position,ppmlim=(2.8,3.2),html=None):
axis = [plotIn.getAxes(ppmlim=ppmlim)[position]]
y_data = [np.real(plotIn.get_spec(ppmlim=ppmlim))[position]]
trace = go.Scatter(x=axis, y=y_data,
mode='markers',
name='max point',
marker=dict(color=colors['emph'],symbol='x',size=8))
mode='markers',
name='max point',
marker=dict(color=colors['emph'], symbol='x', size=8))
fig.add_trace(trace)
fig = addline(fig,plotOut,widelimit,'Phased',lines['out'])
fig = addline(fig, plotOut, widelimit, 'Phased', lines['out'])
# Axes layout
plotAxesStyle(fig,widelimit,title = 'Phase correction summary')
# Axes
plotAxesStyle(fig, widelimit, title='Phase correction summary')
# Axes
if html is not None:
from plotly.offline import plot
from fsl_mrs.utils.preproc.reporting import figgroup, singleReport
......@@ -123,26 +137,27 @@ def phaseCorrect_report(inFID,outFID,hdr,position,ppmlim=(2.8,3.2),html=None):
import os.path as op
if op.isdir(html):
filename = 'report_' + datetime.now().strftime("%Y%m%d_%H%M%S%f")[:-3]+'.html'
htmlfile=op.join(html,filename)
elif op.isdir(op.dirname(html)) and op.splitext(html)[1]=='.html':
filename = 'report_' + datetime.now().strftime("%Y%m%d_%H%M%S%f")[:-3] + '.html'
htmlfile = op.join(html, filename)
elif op.isdir(op.dirname(html)) and op.splitext(html</