Commit 796eae2c authored by William Clarke's avatar William Clarke
Browse files

Rework of mrs_io for new format. test_scripts_proc passes.

parent c47449b4
......@@ -2,5 +2,5 @@ from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
from fsl_mrs.core import MRS
from fsl_mrs.core import MRSI
\ No newline at end of file
# from fsl_mrs.core import MRS
# from fsl_mrs.core import MRSI
\ No newline at end of file
......@@ -13,7 +13,7 @@ from fsl_mrs.core import MRS
from fsl_mrs.utils import mrs_io, misc
import matplotlib.pyplot as plt
import nibabel as nib
from fsl_mrs.utils.mrs_io.fsl_io import saveNIFTI, readNIFTI
from fsl_mrs.utils.mrs_io import fsl_io
class MRSI(object):
......@@ -288,7 +288,7 @@ class MRSI(object):
data[data > 1e10] = 0
if nt == self.FID_points:
saveNIFTI(file_path_name, data, self.header)
fsl_io.saveNIFTI(file_path_name, data, self.header)
img = nib.Nifti1Image(data, self.header['nifti'].affine), file_path_name)
......@@ -304,7 +304,8 @@ class MRSI(object):
""" Load MRSI data directly from files """
data, hdr = mrs_io.read_FID(data_file)
if mask_file is not None:
mask, _ = readNIFTI(mask_file)
nib_img = nib.load(mask_file)
mask = np.asanyarray(nib_img.dataobj)
mask = None
......@@ -6,12 +6,14 @@
# Copyright (C) 2021 University of Oxford
import nibabel as nib
import numpy as np
import json
from import Image
from fsl_mrs.core import MRS, MRSI
import fsl.utils.path as fslpath
from nibabel.nifti2 import Nifti2Header
from nibabel.nifti1 import Nifti1Extension
from fsl_mrs.utils.misc import checkCFUnits
class NIFTIMRS_DimDoesntExist(Exception):
......@@ -25,12 +27,51 @@ class NotNIFTI_MRS(Exception):
def is_nifti_mrs(file_path):
'''Check that a file is of the NIFTI-MRS format type.'''
obj = NIFTI_MRS(file_path)
if float(obj.mrs_nifti_version) < 0.2:
raise NotNIFTI_MRS('NIFTI-MRS > v0.2 required.')
return True
except fslpath.PathError:
raise NotNIFTI_MRS("File isn't NIFTI-MRS, wrong extension type.")
raise NotNIFTI_MRS("File isn't NIFTI-MRS, wrong extension type.")
def gen_new_nifti_mrs(data, dwelltime, spec_freq, nucleus='1H', affine=None, dim_tags=[None, None, None]):
'''Generate a NIFTI_MRS object from a np array and header info.
:param np.ndarray data: FID (time-domain) data. Must be atleast 4D.
:param float dwelltime: Dwelltime of FID data in seconds.
:param float spec_freq: Spectrometer (or central) frequency in MHz.
:param str nucleus: Nucleus string, defaults to '1H'
:param np.ndarray affine: Optional 4x4 position affine.
:param [str] affine: List of dimension tags.
:return: NIFTI_MRS object
if not np.iscomplex(data).all():
raise ValueError('data must be complex')
if data.ndim < 4 or data.ndim > 7:
raise ValueError(f'data must between 4 and 7 dimensions, currently has {data.ndim}')
header = Nifti2Header()
header['pixdim'][4] = dwelltime
hdr_dict = {'SpectrometerFrequency': [checkCFUnits(spec_freq, units='MHz'), ],
'ResonantNucleus': [nucleus, ]}
for idx, dt in enumerate(dim_tags):
if dt is not None:
if (idx + 4) > data.ndim:
raise ValueError('Too many dimension tags passed.')
hdr_dict[f'dim_{idx+5}'] = dt
json_s = json.dumps(hdr_dict)
extension = Nifti1Extension(44, json_s.encode('UTF-8'))
header['intent_name'] = 'mrs_v0_2'.encode()
return NIFTI_MRS(data, header=header)
class NIFTI_MRS(Image):
......@@ -43,6 +84,19 @@ class NIFTI_MRS(Image):
args[0] = args[0].conj()
super().__init__(*args, **kwargs)
# Check that file meets minimum requirements
if float(self.mrs_nifti_version) < 0.2:
raise NotNIFTI_MRS('NIFTI-MRS > V0.2 required.')
if 44 not in self.header.extensions.get_codes():
raise NotNIFTI_MRS('NIFTI-MRS must have a header extension.')
except KeyError:
raise NotNIFTI_MRS('NIFTI-MRS header extension must have nucleus and spectrometerFrequency keys.')
# Extract key parameters from the header extension
......@@ -113,7 +167,7 @@ class NIFTI_MRS(Image):
def hdr_ext(self, hdr_dict):
'''Update MRS JSON header extension from python dict'''
json_s = json.dumps(hdr_dict)
extension = nib.nifti1.Nifti1Extension(44, json_s.encode('UTF-8'))
extension = Nifti1Extension(44, json_s.encode('UTF-8'))
......@@ -82,8 +82,9 @@ def main():
align_group = alignparser.add_argument_group('Align arguments')
align_group.add_argument('--file', type=str, required=True,
help='List of files to align')
align_group.add_argument('--dim', type=str, required=True,
help='NIFTI-MRS dimension tag to align across')
align_group.add_argument('--dim', type=str, default='DIM_DYN',
help='NIFTI-MRS dimension tag to align across.'
'Default = DIM_DYN')
align_group.add_argument('--ppm', type=float, nargs=2,
metavar='<lower-limit upper-limit>',
default=(0.2, 4.2),
......@@ -102,7 +103,7 @@ def main():
alignD_group = alignDparser.add_argument_group('Align subspec arguments')
alignD_group.add_argument('--file', type=str, required=True,
help='Subspectra 1 - List of files to align')
alignD_group.add_argument('--dim', type=str, required=True,
alignD_group.add_argument('--dim', type=str, default='DIM_DYN',
help='NIFTI-MRS dimension tag to align across')
alignD_group.add_argument('--dim_diff', type=str, default='DIM_EDIT',
help='NIFTI-MRS dimension tag to difference across')
......@@ -310,8 +311,7 @@ def main():
# Handle data loading
dataList = loadData(datafiles,
# Create output folder if required
if not op.isdir(args.output):
......@@ -363,8 +363,8 @@ def add_common_args(p):
optional.add_argument('--allreports', action="store_true",
help='Generate reports for all inputs.'
' Overrides arguments to reportIndicies.')
optional.add_argument('--conjugate', action="store_true",
help='apply conjugate to FID')
# optional.add_argument('--conjugate', action="store_true",
# help='apply conjugate to FID')
optional.add_argument('--filename', type=str, metavar='<str>',
help='Override output file name.')
optional.add_argument('--verbose', action="store_true",
......@@ -405,14 +405,13 @@ def loadData(datafile, refdatafile=None):
' spec2nii.')
if refdatafile:
loaded_data = datacontainer(NIFTI_MRS(datafile),
loaded_data = datacontainer(NIFTI_MRS(datafile),
loaded_data = datacontainer(NIFTI_MRS(datafile),
return loaded_data
......@@ -641,9 +640,9 @@ def add(dataobj, args):
def conj(dataobj, args):
conjugated = preproc.apply_fixed_phase(,
conjugated = preproc.conjugate(,
return datacontainer(conjugated, dataobj.datafilename)
This diff is collapsed.
from fsl_mrs.utils.mrs_io.main import read_FID, read_basis, check_datatype
\ No newline at end of file
from fsl_mrs.utils.mrs_io.main import read_FID, read_basis, check_datatype
......@@ -3,27 +3,25 @@
# Author: Saad Jbabdi <>
# Will Clarke <>
# Copyright (C) 2020 University of Oxford
# Copyright (C) 2020 University of Oxford
import numpy as np
import json
import nibabel as nib
import sys, os, glob
import os
import glob
import re
import scipy.signal as ss
from fsl_mrs.core.NIFTI_MRS import gen_new_nifti_mrs
def readNIFTI(datafile,squeezeSVS=True):
""" Read nifti format file.
datafile (str)
squeezeSVS (optional,bool)
def readNIFTI(datafile):
""" Read old (pre NIFTI-MRS) nifti format file.
:param str datafile: Path to file
:return: NIFTI-MRS
data_hdr = nib.load(datafile)
data = np.asanyarray(data_hdr.dataobj)
......@@ -35,88 +33,71 @@ def readNIFTI(datafile,squeezeSVS=True):
# Reciever bandwidth and dwelltime can either be fetched from the nifti header or the json
# central frequency is currently only sotred in the json.
if jsonParams is None:
dwell = data_hdr.header['pixdim'][4]
bw = 1/dwell
header ={'nifti':data_hdr,'centralFrequency':None,'dwelltime':dwell,'bandwidth':bw}
header ={'nifti':data_hdr,
if "EchoTime" in jsonParams:
header['TE'] = jsonParams['EchoTime']
# If there is only one FID (SVS) and squeezeSVS is true then
# remove singleton dimensions
numVoxels = np.product(data.shape[0:3])
if numVoxels==1 and squeezeSVS:
data = np.squeeze(data)
return data,header
def saveNIFTI(datafile,data,header,affine=None):
if 'nifti' not in header and affine is None:
raise ValueError('To save a nifti file the header must contain a nifti field or an affine must be specifed')
if affine is not None:
affineToUse = affine
raise Exception('Unable to load files without JSON sidecar')
affineToUse = header['nifti'].affine
if data.ndim == 1:
data = data.reshape((1,1,1,data.size))
if 'Dwelltime' in jsonParams:
dwelltime = jsonParams['Dwelltime']
dwelltime = data_hdr.header['pixdim'][4]
spec_freq = jsonParams['ImagingFrequency']
if 'ResonantNucleus' in jsonParams:
nucleus = jsonParams['ResonantNucleus']
nucleus = '1H'
img = nib.Nifti2Image(data,affine=affineToUse)
return gen_new_nifti_mrs(data, dwelltime, spec_freq, nucleus=nucleus, affine=data_hdr.affine)
# insert the correct dwell time into the nifti file, it will then plot in fsleyes with the correct faxis
img.header['pixdim'][4] = header['dwelltime'],datafile)
if 'json' in header:
elif ('centralFrequency' in header) and (header['centralFrequency'] is not None): # Store the essential parameters
jsonheader ={'ImagingFrequency':header['centralFrequency'],
def saveNIFTI(filepath, data, header, affine=None):
'''Provide translation layer from old interface to new NIFTI_MRS'''
1 / header['bandwidth'],
# JSON sidecar I/O
def readJSONSidecar(niftiFile):
# Determine if there is a json file
rePattern = re.compile(r'\.nii(\.gz)?')
jsonFile = rePattern.sub('.json', niftiFile)
if os.path.isfile(jsonFile):
if os.path.isfile(jsonFile):
return readJSON(jsonFile)
return None
def writeJSONSidecar(niftiFile,paramDict):
def writeJSONSidecar(niftiFile, paramDict):
rePattern = re.compile(r'\.nii(\.gz)?')
jsonFile = rePattern.sub('.json', niftiFile)
writeJSON(jsonFile, paramDict)
def readJSON(file):
with open(file,'r') as jsonFile:
with open(file, 'r') as jsonFile:
jsonString =
return json.loads(jsonString)
def writeJSON(fileOut,outputDict):
def writeJSON(fileOut, outputDict):
with open(fileOut, 'w', encoding='utf-8') as f:
json.dump(outputDict, f, ensure_ascii=False, indent='\t')
# Read a folder containing json files in the FSL basis style.
# Optionally allows recalculation of the FID using the stored density matrix.
# This will take longer but avoids the need for interpolation.
# It also allows for arbitrary shifting of the readout central frequency
def readFSLBasisFiles(basisFolder,readoutShift=4.65,bandwidth=None,points=None):
def readFSLBasisFiles(basisFolder, readoutShift=4.65, bandwidth=None, points=None):
if not os.path.isdir(basisFolder):
raise ValueError(' ''basisFolder'' must be a folder containing basis json files.')
# loop through all files in folder
basisfiles = sorted(glob.glob(os.path.join(basisFolder,'*.json')))
basis,names,header = [],[],[]
for bfile in basisfiles:
# loop through all files in folder
basisfiles = sorted(glob.glob(os.path.join(basisFolder, '*.json')))
basis, names, header = [], [], []
for bfile in basisfiles:
if bandwidth is None or points is None:
# If simple read operation call readFSLBasis
b, n, h = readFSLBasis(bfile)
......@@ -125,60 +106,62 @@ def readFSLBasisFiles(basisFolder,readoutShift=4.65,bandwidth=None,points=None):
# If recalculation requested loop through files calling readAndGenFSLBasis
b, n, h = readAndGenFSLBasis(bfile,readoutShift,bandwidth,points)
b, n, h = readAndGenFSLBasis(bfile, readoutShift, bandwidth, points)
basis = np.array(basis).conj().T
return basis,names,header
basis = np.array(basis).conj().T
return basis, names, header
# Read the FID within the FSL basis json file. Returns equivalent outputs to the LCModel style basis files.
def readFSLBasis(filename,N=None,dofft=False):
with open(filename,'r') as basisFile:
def readFSLBasis(filename, N=None, dofft=False):
with open(filename, 'r') as basisFile:
jsonString =
basisFileParams = json.loads(jsonString)
if 'basis' in basisFileParams:
basis = basisFileParams['basis']
data = np.array(basis['basis_re'])+1j*np.array(basis['basis_im'])
if dofft: # Go to frequency domain from timedomain
data = np.array(basis['basis_re']) + 1j * np.array(basis['basis_im'])
if dofft: # Go to frequency domain from timedomain
data = np.fft.fftshift(np.fft.fft(data))
# Resample if necessary? --> should not be allowed actually
if N is not None:
if N != data.shape[0]:
data = ss.resample(data,N)
data = ss.resample(data, N)
header = {'centralFrequency':basis['basis_centre']*1E6,
header = {'centralFrequency': basis['basis_centre'] * 1E6,
'bandwidth': 1 / basis['basis_dwell'],
'dwelltime': basis['basis_dwell'],
'fwhm': basis['basis_width']}
# header['echotime'] Not clear how to calculate this in the general case.
metabo = basis['basis_name']
else: #No basis information found
else: # No basis information found
raise ValueError('FSL basis file must have a ''basis'' field.')
return data, metabo, header
# Load an FSL basis file.
# Recalculate the FID on a defined time axis.
# Relies on all fields being populated appropriately
def readAndGenFSLBasis(file,readoutShift,bandwidth,points):
def readAndGenFSLBasis(file, readoutShift, bandwidth, points):
from fsl_mrs.denmatsim import utils as simutils
with open(file,'r') as basisFile:
with open(file, 'r') as basisFile:
jsonString =
basisFileParams = json.loads(jsonString)
if 'MM' in basisFileParams:
import fsl_mrs.utils.misc as misc
FID, metabo, header = readFSLBasis(file)
old_dt = 1/header['bandwidth']
new_dt = 1/bandwidth
FID = misc.ts_to_ts(FID,old_dt,new_dt,points)
old_dt = 1 / header['bandwidth']
new_dt = 1 / bandwidth
FID = misc.ts_to_ts(FID, old_dt, new_dt, points)
return FID, metabo, header
if 'seq' not in basisFileParams:
......@@ -189,7 +172,7 @@ def readAndGenFSLBasis(file,readoutShift,bandwidth,points):
rxphs = 0
rxphs = basisFileParams['seq']['Rx_Phase']
if 'spinSys' not in basisFileParams:
raise ValueError('To recalculate the basis json must contain a spinSys field.')
......@@ -199,19 +182,27 @@ def readAndGenFSLBasis(file,readoutShift,bandwidth,points):
raise ValueError('To recalculate the basis json must contain a outputDensityMatrix field.')
p = []
for re,im in zip(basisFileParams['outputDensityMatrix']['re'],basisFileParams['outputDensityMatrix']['im']):
if len(p)==1:
p = p[0] # deal with single spin system case
for real, imag in zip(basisFileParams['outputDensityMatrix']['re'],
p.append(np.array(real) + 1j * np.array(imag))
if len(p) == 1:
p = p[0] # deal with single spin system case
lw = basisFileParams['basis']['basis_width']
FID = simutils.FIDFromDensityMat(p,spins,B0,points,1/bandwidth,lw,offset=readoutShift,recieverPhs=rxphs)
FID = simutils.FIDFromDensityMat(p,
1 / bandwidth,
metabo = basisFileParams['basis']['basis_name']
cf = basisFileParams['basis']['basis_centre']*1E6
header = {'centralFrequency':cf,
cf = basisFileParams['basis']['basis_centre'] * 1E6
header = {'centralFrequency': cf,
'bandwidth': bandwidth,
'dwelltime': 1 / bandwidth,
'fwhm': lw}
return FID, metabo, header
......@@ -3,46 +3,47 @@
# Author: Saad Jbabdi <>
# Will Clarke <>
# Copyright (C) 2020 University of Oxford
# Copyright (C) 2020 University of Oxford
import numpy as np
import re
import os.path as op
from fsl_mrs.core.NIFTI_MRS import gen_new_nifti_mrs
# Read jMRUI style text files
def readjMRUItxt(filename,unpack_header=True):
def readjMRUItxt(filename, unpack_header=True):
Read .txt format file
filename : string
Name of jmrui .txt file
Complex data
list (or dict if unpack_header==True)
Header information
signalRe = re.compile(r'Signal (\d{1,}) out of (\d{1,}) in file')
headerRe = re.compile(r'(\w*):(.*)')
header = {}
data = []
data = []
recordData = False
with open(filename,'r') as txtfile:
with open(filename, 'r') as txtfile:
for line in txtfile:
headerComp = headerRe.match(line)
if headerComp:
value = headerComp[2].strip()
value = headerComp[2].strip()
header.update({headerComp[1]: num(value)})
signalIndices = signalRe.match(line)
if signalIndices:
recordData = True
if recordData:
curr_data = line.split()
if len(curr_data) > 2:
......@@ -51,21 +52,28 @@ def readjMRUItxt(filename,unpack_header=True):
# Reshape data
data = np.concatenate([np.array(i) for i in data])
data = (data[0::2] + 1j*data[1::2]).astype(np.complex)
data = (data[0::2] + 1j * data[1::2]).astype(np.complex)
# Clean up header
header = translateHeader(header)
return data, header
if 'TypeOfNucleus' in header['jmrui']:
nucleus = header['jmrui']['TypeOfNucleus']
nucleus = '1H'
return gen_new_nifti_mrs(data, header['dwelltime'], header['centralFrequency'], nucleus=nucleus)
# Translate jMRUI header to mandatory fields
def translateHeader(header):
newHeader = {'jmrui':header}
newHeader = {'jmrui': header}
newHeader.update({'centralFrequency': header['TransmitterFrequency']})
newHeader.update({'bandwidth': 1 / (header['SamplingInterval'] * 1E-3)})
newHeader.update({'dwelltime': header['SamplingInterval'] * 1E-3})
return newHeader
def num(s):
return int(s)
......@@ -75,13 +83,14 @@ def num(s):
except ValueError:
return s
# Read jMRUI .txt files containing basis
def read_txtBasis_files(txtfiles):