Commit 796eae2c authored by William Clarke's avatar William Clarke
Browse files

Rework of mrs_io for new format. test_scripts_proc passes.

parent c47449b4
......@@ -2,5 +2,5 @@ from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
from fsl_mrs.core import MRS
from fsl_mrs.core import MRSI
\ No newline at end of file
# from fsl_mrs.core import MRS
# from fsl_mrs.core import MRSI
\ No newline at end of file
......@@ -13,7 +13,7 @@ from fsl_mrs.core import MRS
from fsl_mrs.utils import mrs_io, misc
import matplotlib.pyplot as plt
import nibabel as nib
from fsl_mrs.utils.mrs_io.fsl_io import saveNIFTI, readNIFTI
from fsl_mrs.utils.mrs_io import fsl_io
class MRSI(object):
......@@ -288,7 +288,7 @@ class MRSI(object):
data[data > 1e10] = 0
if nt == self.FID_points:
saveNIFTI(file_path_name, data, self.header)
fsl_io.saveNIFTI(file_path_name, data, self.header)
else:
img = nib.Nifti1Image(data, self.header['nifti'].affine)
nib.save(img, file_path_name)
......@@ -304,7 +304,8 @@ class MRSI(object):
""" Load MRSI data directly from files """
data, hdr = mrs_io.read_FID(data_file)
if mask_file is not None:
mask, _ = readNIFTI(mask_file)
nib_img = nib.load(mask_file)
mask = np.asanyarray(nib_img.dataobj)
else:
mask = None
......
......@@ -6,12 +6,14 @@
# Copyright (C) 2021 University of Oxford
# SHBASECOPYRIGHT
import nibabel as nib
import numpy as np
import json
from fsl.data.image import Image
from fsl_mrs.core import MRS, MRSI
import fsl.utils.path as fslpath
from nibabel.nifti2 import Nifti2Header
from nibabel.nifti1 import Nifti1Extension
from fsl_mrs.utils.misc import checkCFUnits
class NIFTIMRS_DimDoesntExist(Exception):
......@@ -25,12 +27,51 @@ class NotNIFTI_MRS(Exception):
def is_nifti_mrs(file_path):
'''Check that a file is of the NIFTI-MRS format type.'''
try:
obj = NIFTI_MRS(file_path)
if float(obj.mrs_nifti_version) < 0.2:
raise NotNIFTI_MRS('NIFTI-MRS > v0.2 required.')
NIFTI_MRS(file_path)
return True
except fslpath.PathError:
raise NotNIFTI_MRS("File isn't NIFTI-MRS, wrong extension type.")
raise NotNIFTI_MRS("File isn't NIFTI-MRS, wrong extension type.")
def gen_new_nifti_mrs(data, dwelltime, spec_freq, nucleus='1H', affine=None, dim_tags=[None, None, None]):
'''Generate a NIFTI_MRS object from a np array and header info.
:param np.ndarray data: FID (time-domain) data. Must be atleast 4D.
:param float dwelltime: Dwelltime of FID data in seconds.
:param float spec_freq: Spectrometer (or central) frequency in MHz.
:param str nucleus: Nucleus string, defaults to '1H'
:param np.ndarray affine: Optional 4x4 position affine.
:param [str] affine: List of dimension tags.
:return: NIFTI_MRS object
'''
if not np.iscomplex(data).all():
raise ValueError('data must be complex')
if data.ndim < 4 or data.ndim > 7:
raise ValueError(f'data must between 4 and 7 dimensions, currently has {data.ndim}')
header = Nifti2Header()
header['pixdim'][4] = dwelltime
hdr_dict = {'SpectrometerFrequency': [checkCFUnits(spec_freq, units='MHz'), ],
'ResonantNucleus': [nucleus, ]}
for idx, dt in enumerate(dim_tags):
if dt is not None:
if (idx + 4) > data.ndim:
raise ValueError('Too many dimension tags passed.')
hdr_dict[f'dim_{idx+5}'] = dt
json_s = json.dumps(hdr_dict)
extension = Nifti1Extension(44, json_s.encode('UTF-8'))
header.extensions.append(extension)
header.set_qform(affine)
header.set_sform(affine)
header['intent_name'] = 'mrs_v0_2'.encode()
return NIFTI_MRS(data, header=header)
class NIFTI_MRS(Image):
......@@ -43,6 +84,19 @@ class NIFTI_MRS(Image):
args[0] = args[0].conj()
super().__init__(*args, **kwargs)
# Check that file meets minimum requirements
if float(self.mrs_nifti_version) < 0.2:
raise NotNIFTI_MRS('NIFTI-MRS > V0.2 required.')
if 44 not in self.header.extensions.get_codes():
raise NotNIFTI_MRS('NIFTI-MRS must have a header extension.')
try:
self.nucleus
self.spectrometer_frequency
except KeyError:
raise NotNIFTI_MRS('NIFTI-MRS header extension must have nucleus and spectrometerFrequency keys.')
# Extract key parameters from the header extension
self._set_dim_tags()
......@@ -113,7 +167,7 @@ class NIFTI_MRS(Image):
def hdr_ext(self, hdr_dict):
'''Update MRS JSON header extension from python dict'''
json_s = json.dumps(hdr_dict)
extension = nib.nifti1.Nifti1Extension(44, json_s.encode('UTF-8'))
extension = Nifti1Extension(44, json_s.encode('UTF-8'))
self.header.extensions.clear()
self.header.extensions.append(extension)
......
......@@ -82,8 +82,9 @@ def main():
align_group = alignparser.add_argument_group('Align arguments')
align_group.add_argument('--file', type=str, required=True,
help='List of files to align')
align_group.add_argument('--dim', type=str, required=True,
help='NIFTI-MRS dimension tag to align across')
align_group.add_argument('--dim', type=str, default='DIM_DYN',
help='NIFTI-MRS dimension tag to align across.'
'Default = DIM_DYN')
align_group.add_argument('--ppm', type=float, nargs=2,
metavar='<lower-limit upper-limit>',
default=(0.2, 4.2),
......@@ -102,7 +103,7 @@ def main():
alignD_group = alignDparser.add_argument_group('Align subspec arguments')
alignD_group.add_argument('--file', type=str, required=True,
help='Subspectra 1 - List of files to align')
alignD_group.add_argument('--dim', type=str, required=True,
alignD_group.add_argument('--dim', type=str, default='DIM_DYN',
help='NIFTI-MRS dimension tag to align across')
alignD_group.add_argument('--dim_diff', type=str, default='DIM_EDIT',
help='NIFTI-MRS dimension tag to difference across')
......@@ -310,8 +311,7 @@ def main():
# Handle data loading
dataList = loadData(datafiles,
refdatafiles=reffiles,
conjugate=args.conjugate)
refdatafile=reffiles)
# Create output folder if required
if not op.isdir(args.output):
......@@ -363,8 +363,8 @@ def add_common_args(p):
optional.add_argument('--allreports', action="store_true",
help='Generate reports for all inputs.'
' Overrides arguments to reportIndicies.')
optional.add_argument('--conjugate', action="store_true",
help='apply conjugate to FID')
# optional.add_argument('--conjugate', action="store_true",
# help='apply conjugate to FID')
optional.add_argument('--filename', type=str, metavar='<str>',
help='Override output file name.')
optional.add_argument('--verbose', action="store_true",
......@@ -405,14 +405,13 @@ def loadData(datafile, refdatafile=None):
' spec2nii.')
if refdatafile:
loaded_data = datacontainer(NIFTI_MRS(datafile),
datafile,
NIFTI_MRS(datafile))
else:
loaded_data = datacontainer(NIFTI_MRS(datafile),
datafile,
NIFTI_MRS(refdatafile),
refdatafile)
else:
loaded_data = datacontainer(NIFTI_MRS(datafile),
datafile)
return loaded_data
......@@ -641,9 +640,9 @@ def add(dataobj, args):
def conj(dataobj, args):
conjugated = preproc.apply_fixed_phase(dataobj.data,
report=args['generateReports'],
report_all=args['allreports'])
conjugated = preproc.conjugate(dataobj.data,
report=args['generateReports'],
report_all=args['allreports'])
return datacontainer(conjugated, dataobj.datafilename)
......
This diff is collapsed.
from fsl_mrs.utils.mrs_io.main import read_FID, read_basis, check_datatype
\ No newline at end of file
from fsl_mrs.utils.mrs_io.main import read_FID, read_basis, check_datatype
......@@ -3,27 +3,25 @@
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
# Will Clarke <william.clarke@ndcn.ox.ac.uk>
#
# Copyright (C) 2020 University of Oxford
# Copyright (C) 2020 University of Oxford
# SHBASECOPYRIGHT
import numpy as np
import json
import nibabel as nib
import sys, os, glob
import os
import glob
import re
import scipy.signal as ss
from fsl_mrs.core.NIFTI_MRS import gen_new_nifti_mrs
# NIFTI I/O
def readNIFTI(datafile,squeezeSVS=True):
""" Read nifti format file.
Args:
datafile (str)
squeezeSVS (optional,bool)
Returns:
data
header
def readNIFTI(datafile):
""" Read old (pre NIFTI-MRS) nifti format file.
:param str datafile: Path to file
:return: NIFTI-MRS
"""
data_hdr = nib.load(datafile)
data = np.asanyarray(data_hdr.dataobj)
......@@ -35,88 +33,71 @@ def readNIFTI(datafile,squeezeSVS=True):
# Reciever bandwidth and dwelltime can either be fetched from the nifti header or the json
# central frequency is currently only sotred in the json.
if jsonParams is None:
dwell = data_hdr.header['pixdim'][4]
bw = 1/dwell
header ={'nifti':data_hdr,'centralFrequency':None,'dwelltime':dwell,'bandwidth':bw}
else:
header ={'nifti':data_hdr,
'json':jsonParams,
'centralFrequency':jsonParams['ImagingFrequency'],
'dwelltime':jsonParams['Dwelltime'],
'bandwidth':1/jsonParams['Dwelltime']}
if "EchoTime" in jsonParams:
header['TE'] = jsonParams['EchoTime']
# If there is only one FID (SVS) and squeezeSVS is true then
# remove singleton dimensions
numVoxels = np.product(data.shape[0:3])
if numVoxels==1 and squeezeSVS:
data = np.squeeze(data)
return data,header
def saveNIFTI(datafile,data,header,affine=None):
if 'nifti' not in header and affine is None:
raise ValueError('To save a nifti file the header must contain a nifti field or an affine must be specifed')
if affine is not None:
affineToUse = affine
raise Exception('Unable to load files without JSON sidecar')
else:
affineToUse = header['nifti'].affine
if data.ndim == 1:
data = data.reshape((1,1,1,data.size))
if 'Dwelltime' in jsonParams:
dwelltime = jsonParams['Dwelltime']
else:
dwelltime = data_hdr.header['pixdim'][4]
spec_freq = jsonParams['ImagingFrequency']
if 'ResonantNucleus' in jsonParams:
nucleus = jsonParams['ResonantNucleus']
else:
nucleus = '1H'
img = nib.Nifti2Image(data,affine=affineToUse)
return gen_new_nifti_mrs(data, dwelltime, spec_freq, nucleus=nucleus, affine=data_hdr.affine)
# insert the correct dwell time into the nifti file, it will then plot in fsleyes with the correct faxis
img.header['pixdim'][4] = header['dwelltime']
nib.save(img,datafile)
if 'json' in header:
writeJSONSidecar(datafile,header['json'])
elif ('centralFrequency' in header) and (header['centralFrequency'] is not None): # Store the essential parameters
jsonheader ={'ImagingFrequency':header['centralFrequency'],
'Dwelltime':header['dwelltime']}
writeJSONSidecar(datafile,jsonheader)
return
def saveNIFTI(filepath, data, header, affine=None):
'''Provide translation layer from old interface to new NIFTI_MRS'''
gen_new_nifti_mrs(data,
1 / header['bandwidth'],
header['centralFrequency'],
nucleus=header['ResonantNucleus'],
affine=affine).save(filepath)
# JSON sidecar I/O
def readJSONSidecar(niftiFile):
# Determine if there is a json file
rePattern = re.compile(r'\.nii(\.gz)?')
jsonFile = rePattern.sub('.json', niftiFile)
if os.path.isfile(jsonFile):
if os.path.isfile(jsonFile):
return readJSON(jsonFile)
else:
return None
def writeJSONSidecar(niftiFile,paramDict):
def writeJSONSidecar(niftiFile, paramDict):
rePattern = re.compile(r'\.nii(\.gz)?')
jsonFile = rePattern.sub('.json', niftiFile)
writeJSON(jsonFile,paramDict)
writeJSON(jsonFile, paramDict)
def readJSON(file):
with open(file,'r') as jsonFile:
with open(file, 'r') as jsonFile:
jsonString = jsonFile.read()
return json.loads(jsonString)
def writeJSON(fileOut,outputDict):
def writeJSON(fileOut, outputDict):
with open(fileOut, 'w', encoding='utf-8') as f:
json.dump(outputDict, f, ensure_ascii=False, indent='\t')
# Read a folder containing json files in the FSL basis style.
# Optionally allows recalculation of the FID using the stored density matrix.
# This will take longer but avoids the need for interpolation.
# It also allows for arbitrary shifting of the readout central frequency
def readFSLBasisFiles(basisFolder,readoutShift=4.65,bandwidth=None,points=None):
def readFSLBasisFiles(basisFolder, readoutShift=4.65, bandwidth=None, points=None):
if not os.path.isdir(basisFolder):
raise ValueError(' ''basisFolder'' must be a folder containing basis json files.')
# loop through all files in folder
basisfiles = sorted(glob.glob(os.path.join(basisFolder,'*.json')))
basis,names,header = [],[],[]
for bfile in basisfiles:
# loop through all files in folder
basisfiles = sorted(glob.glob(os.path.join(basisFolder, '*.json')))
basis, names, header = [], [], []
for bfile in basisfiles:
if bandwidth is None or points is None:
# If simple read operation call readFSLBasis
b, n, h = readFSLBasis(bfile)
......@@ -125,60 +106,62 @@ def readFSLBasisFiles(basisFolder,readoutShift=4.65,bandwidth=None,points=None):
header.append(h)
else:
# If recalculation requested loop through files calling readAndGenFSLBasis
b, n, h = readAndGenFSLBasis(bfile,readoutShift,bandwidth,points)
b, n, h = readAndGenFSLBasis(bfile, readoutShift, bandwidth, points)
basis.append(b)
names.append(n)
header.append(h)
basis = np.array(basis).conj().T
return basis,names,header
basis = np.array(basis).conj().T
return basis, names, header
# Read the FID within the FSL basis json file. Returns equivalent outputs to the LCModel style basis files.
def readFSLBasis(filename,N=None,dofft=False):
with open(filename,'r') as basisFile:
def readFSLBasis(filename, N=None, dofft=False):
with open(filename, 'r') as basisFile:
jsonString = basisFile.read()
basisFileParams = json.loads(jsonString)
if 'basis' in basisFileParams:
basis = basisFileParams['basis']
data = np.array(basis['basis_re'])+1j*np.array(basis['basis_im'])
if dofft: # Go to frequency domain from timedomain
data = np.array(basis['basis_re']) + 1j * np.array(basis['basis_im'])
if dofft: # Go to frequency domain from timedomain
data = np.fft.fftshift(np.fft.fft(data))
# Resample if necessary? --> should not be allowed actually
if N is not None:
if N != data.shape[0]:
data = ss.resample(data,N)
data = ss.resample(data, N)
header = {'centralFrequency':basis['basis_centre']*1E6,
'bandwidth':1/basis['basis_dwell'],
'dwelltime':basis['basis_dwell'],
'fwhm':basis['basis_width']}
header = {'centralFrequency': basis['basis_centre'] * 1E6,
'bandwidth': 1 / basis['basis_dwell'],
'dwelltime': basis['basis_dwell'],
'fwhm': basis['basis_width']}
# header['echotime'] Not clear how to calculate this in the general case.
metabo = basis['basis_name']
else: #No basis information found
else: # No basis information found
raise ValueError('FSL basis file must have a ''basis'' field.')
return data, metabo, header
# Load an FSL basis file.
# Recalculate the FID on a defined time axis.
# Relies on all fields being populated appropriately
def readAndGenFSLBasis(file,readoutShift,bandwidth,points):
def readAndGenFSLBasis(file, readoutShift, bandwidth, points):
from fsl_mrs.denmatsim import utils as simutils
with open(file,'r') as basisFile:
with open(file, 'r') as basisFile:
jsonString = basisFile.read()
basisFileParams = json.loads(jsonString)
if 'MM' in basisFileParams:
import fsl_mrs.utils.misc as misc
FID, metabo, header = readFSLBasis(file)
old_dt = 1/header['bandwidth']
new_dt = 1/bandwidth
FID = misc.ts_to_ts(FID,old_dt,new_dt,points)
old_dt = 1 / header['bandwidth']
new_dt = 1 / bandwidth
FID = misc.ts_to_ts(FID, old_dt, new_dt, points)
return FID, metabo, header
if 'seq' not in basisFileParams:
......@@ -189,7 +172,7 @@ def readAndGenFSLBasis(file,readoutShift,bandwidth,points):
rxphs = 0
else:
rxphs = basisFileParams['seq']['Rx_Phase']
if 'spinSys' not in basisFileParams:
raise ValueError('To recalculate the basis json must contain a spinSys field.')
else:
......@@ -199,19 +182,27 @@ def readAndGenFSLBasis(file,readoutShift,bandwidth,points):
raise ValueError('To recalculate the basis json must contain a outputDensityMatrix field.')
else:
p = []
for re,im in zip(basisFileParams['outputDensityMatrix']['re'],basisFileParams['outputDensityMatrix']['im']):
p.append(np.array(re)+1j*np.array(im))
if len(p)==1:
p = p[0] # deal with single spin system case
for real, imag in zip(basisFileParams['outputDensityMatrix']['re'],
basisFileParams['outputDensityMatrix']['im']):
p.append(np.array(real) + 1j * np.array(imag))
if len(p) == 1:
p = p[0] # deal with single spin system case
lw = basisFileParams['basis']['basis_width']
FID = simutils.FIDFromDensityMat(p,spins,B0,points,1/bandwidth,lw,offset=readoutShift,recieverPhs=rxphs)
FID = simutils.FIDFromDensityMat(p,
spins,
B0,
points,
1 / bandwidth,
lw,
offset=readoutShift,
recieverPhs=rxphs)
metabo = basisFileParams['basis']['basis_name']
cf = basisFileParams['basis']['basis_centre']*1E6
header = {'centralFrequency':cf,
'bandwidth':bandwidth,
'dwelltime':1/bandwidth,
'fwhm':lw}
cf = basisFileParams['basis']['basis_centre'] * 1E6
header = {'centralFrequency': cf,
'bandwidth': bandwidth,
'dwelltime': 1 / bandwidth,
'fwhm': lw}
return FID, metabo, header
......@@ -3,46 +3,47 @@
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
# Will Clarke <william.clarke@ndcn.ox.ac.uk>
#
# Copyright (C) 2020 University of Oxford
# Copyright (C) 2020 University of Oxford
# SHBASECOPYRIGHT
import numpy as np
import re
import os.path as op
from fsl_mrs.core.NIFTI_MRS import gen_new_nifti_mrs
# Read jMRUI style text files
def readjMRUItxt(filename,unpack_header=True):
def readjMRUItxt(filename, unpack_header=True):
"""
Read .txt format file
Parameters
----------
filename : string
Name of jmrui .txt file
Returns
-------
array-like
Complex data
list (or dict if unpack_header==True)
Header information
"""
signalRe = re.compile(r'Signal (\d{1,}) out of (\d{1,}) in file')
headerRe = re.compile(r'(\w*):(.*)')
header = {}
data = []
data = []
recordData = False
with open(filename,'r') as txtfile:
with open(filename, 'r') as txtfile:
for line in txtfile:
headerComp = headerRe.match(line)
if headerComp:
value = headerComp[2].strip()
header.update({headerComp[1]:num(value)})
value = headerComp[2].strip()
header.update({headerComp[1]: num(value)})
signalIndices = signalRe.match(line)
if signalIndices:
recordData = True
continue
if recordData:
curr_data = line.split()
if len(curr_data) > 2:
......@@ -51,21 +52,28 @@ def readjMRUItxt(filename,unpack_header=True):
# Reshape data
data = np.concatenate([np.array(i) for i in data])
data = (data[0::2] + 1j*data[1::2]).astype(np.complex)
data = (data[0::2] + 1j * data[1::2]).astype(np.complex)
# Clean up header
header = translateHeader(header)
return data, header
if 'TypeOfNucleus' in header['jmrui']:
nucleus = header['jmrui']['TypeOfNucleus']
else:
nucleus = '1H'
return gen_new_nifti_mrs(data, header['dwelltime'], header['centralFrequency'], nucleus=nucleus)
# Translate jMRUI header to mandatory fields
def translateHeader(header):
newHeader = {'jmrui':header}
newHeader.update({'centralFrequency':header['TransmitterFrequency']})
newHeader.update({'bandwidth':1/(header['SamplingInterval']*1E-3)})
newHeader.update({'dwelltime':header['SamplingInterval']*1E-3})
newHeader = {'jmrui': header}
newHeader.update({'centralFrequency': header['TransmitterFrequency']})
newHeader.update({'bandwidth': 1 / (header['SamplingInterval'] * 1E-3)})
newHeader.update({'dwelltime': header['SamplingInterval'] * 1E-3})
return newHeader
def num(s):
try:
return int(s)
......@@ -75,13 +83,14 @@ def num(s):
except ValueError:
return s
# Read jMRUI .txt files containing basis
def read_txtBasis_files(txtfiles):