Skip to content
Snippets Groups Projects
Commit 1b8fe46d authored by William Clarke's avatar William Clarke
Browse files

Checked tests after adding MRSI.

parent a2b45421
No related branches found
No related tags found
No related merge requests found
Source diff could not be displayed: it is stored in LFS. Options to address this: view the blob.
Source diff could not be displayed: it is stored in LFS. Options to address this: view the blob.
......@@ -21,8 +21,6 @@ def test_read_FID_SVS():
data_nifti,header_nifti = mrsio.read_FID(SVSTestData['nifti'])
data_raw,header_raw = mrsio.read_FID(SVSTestData['raw'])
data_txt,header_txt = mrsio.read_FID(SVSTestData['txt'])
data_raw = data_raw.conj()
# Check that the data from each of these matches - it should they are all the same bit of data.
datamean = np.mean([data_nifti,data_raw,data_txt],axis=0)
......@@ -95,7 +93,10 @@ def test_read_Basis():
assert np.isclose(headers_txt[0][r],headerMean)
assert np.isclose(headers_lcm[0][r],headerMean)
# Conjugate fsl and jMRUI
basis_fsl = basis_fsl.conj()
basis_txt = basis_txt.conj()
# Test that all contain roughly the same data when scaled.
metabToCheck = 'Cr'
checkIdx = names_raw.index('Cr')
......@@ -105,11 +106,17 @@ def test_read_Basis():
convertToLimitedSpec(basis_raw[:,checkIdx]),
convertToLimitedSpec(basis_txt[:,checkIdx]),
convertToLimitedSpec(basis_lcm[:,checkIdx])],axis=0)
assert np.isclose(convertToLimitedSpec(basis_fsl[:,checkIdx]),meanSpec,rtol=2e-01, atol=1e-03).all()
assert np.isclose(convertToLimitedSpec(basis_raw[:,checkIdx]),meanSpec,rtol=2e-01, atol=1e-03).all
assert np.isclose(convertToLimitedSpec(basis_txt[:,checkIdx]),meanSpec,rtol=2e-01, atol=1e-03).all
assert np.isclose(convertToLimitedSpec(basis_lcm[:,checkIdx]),meanSpec,rtol=2e-01, atol=1e-03).all
# breakpoint()
# import matplotlib.pyplot as plt
# plt.plot(convertToLimitedSpec(basis_fsl[:,checkIdx]))
# plt.plot(convertToLimitedSpec(basis_raw[:,checkIdx]),'--')
# plt.plot(convertToLimitedSpec(basis_txt[:,checkIdx]),'-.')
# plt.plot(convertToLimitedSpec(basis_lcm[:,checkIdx]),':')
# plt.show()
assert np.allclose(convertToLimitedSpec(basis_fsl[:,checkIdx]),meanSpec,rtol=2e-01, atol=1e-03)
assert np.allclose(convertToLimitedSpec(basis_raw[:,checkIdx]),meanSpec,rtol=2e-01, atol=1e-03)
assert np.allclose(convertToLimitedSpec(basis_txt[:,checkIdx]),meanSpec,rtol=2e-01, atol=1e-03)
assert np.allclose(convertToLimitedSpec(basis_lcm[:,checkIdx]),meanSpec,rtol=2e-01, atol=1e-03)
def test_fslBasisRegen():
pointsToGen = 10
......
......@@ -24,7 +24,8 @@ def test_preproc(tmp_path):
['--ecc',]+allfiles_ecc+
['--hlsvd',
'--leftshift','1',
'--overwrite'])
'--overwrite',
'--report' ])
print(retcode)
assert retcode==0
assert op.isfile(op.join(tmp_path,'coil_combined.png'))
\ No newline at end of file
assert op.isfile(op.join(tmp_path,'mergedReports.html'))
\ No newline at end of file
......@@ -50,9 +50,8 @@ def test_fit_FSLModel_Newton(data):
res = fit_FSLModel(mrs,**Fitargs)
fittedconcs = res.getConc()
fittedRelconcs = res.getConc(scaling='internal')
fittedconcs = res.getConc(metab = mrs.names)
fittedRelconcs = res.getConc(scaling='internal',metab = mrs.names)
assert np.allclose(fittedconcs,amplitudes,atol=1E-1)
assert np.allclose(fittedRelconcs,amplitudes/(amplitudes[0]+amplitudes[1]),atol=1E-1)
......@@ -69,8 +68,8 @@ def test_fit_FSLModel_MH(data):
res = fit_FSLModel(mrs,**Fitargs)
fittedconcs = res.getConc()
fittedRelconcs = res.getConc(scaling='internal')
fittedconcs = res.getConc(metab = mrs.names)
fittedRelconcs = res.getConc(scaling='internal',metab = mrs.names)
assert np.allclose(fittedconcs,amplitudes,atol=1E-1)
assert np.allclose(fittedRelconcs,amplitudes/(amplitudes[0]+amplitudes[1]),atol=1E-1)
\ No newline at end of file
......@@ -14,7 +14,12 @@ def test_ppm2hz_hz2ppm():
assert misc.hz2ppm(cf,hz,shift=True)==(1.0+shift)
def test_FIDToSpec_SpecToFID():
testFID,hdr = synth.syntheticFID(amplitude=[1],chemicalshift=[0],phase=[0],damping=[20])
testFID,hdr = synth.syntheticFID(amplitude=[10],chemicalshift=[0],phase=[0],damping=[20])
# SVS case
spec = misc.FIDToSpec(testFID[0])
reformedFID = misc.SpecToFID(spec)
assert np.allclose(reformedFID,testFID)
testMRSI = np.tile(testFID,(4,4,4,1)).T
testspec = misc.FIDToSpec(testMRSI)
......@@ -24,11 +29,11 @@ def test_FIDToSpec_SpecToFID():
assert np.argmax(np.abs(testspec[2,2,2,:]))==1024
reformedFID = misc.SpecToFID(testspec,axis=3)
assert np.isclose(reformedFID,testMRSI.T).all()
assert np.allclose(reformedFID,testMRSI.T)
reformedFID = misc.SpecToFID(testspec.T)
assert np.isclose(reformedFID,testMRSI).all()
assert np.allclose(reformedFID,testMRSI)
# Odd number of points - guard against fftshift/ifftshift errors
testFID,hdr = synth.syntheticFID(amplitude=[1],chemicalshift=[0],phase=[0],damping=[20],points=1025)
assert np.isclose(misc.SpecToFID(misc.FIDToSpec(testFID)),testFID).all()
assert np.allclose(misc.SpecToFID(misc.FIDToSpec(testFID[0])),testFID)
......@@ -17,18 +17,17 @@ def test_freqshift():
assert freqOfMax < 5 and freqOfMax > -5
# Test timeshift by 1) counting points, 2) undoing 1st order phase of fid with begin time.
# Test timeshift
def test_timeshift():
# Create data with lots of points and some begin time delay
testFID,testHdrs = syn.syntheticFID(begintime=0.001,points=4096,noisecovariance=[[0.0]])
assert ~(np.real(FIDToSpec(testFID))>0.0).all() # Check starting conditions
testFID,testHdrs = syn.syntheticFID(begintime=-0.001,points=4096,noisecovariance=[[0.0]])
testFID2,testHdrs2 = syn.syntheticFID(begintime=0.000,points=4096,noisecovariance=[[0.0]])
# Reduce points and pad to remove first order phase
shiftedFID,_ = preproc.timeshift(testFID[0],1/testHdrs['inputopts']['bandwidth'],-0.001,0.0,samples=2048)
shiftedFID,_ = preproc.timeshift(testFID[0],1/testHdrs['inputopts']['bandwidth'],0.001,0.0,samples=4096)
assert shiftedFID.size == 2048
assert (np.real(FIDToSpec(shiftedFID)+0.005)>0.0).all()
# assert shiftedFID.size == 2048
assert np.allclose(shiftedFID,testFID2[0],atol=1E-1)
# Test combine_FIDs:
# Test mean by calculating mean of anti phase signals
......@@ -169,11 +168,10 @@ def test_phaseCorrect():
def test_add_subtract():
mockFID = np.random.random(1024)+1j*np.random.random(1024)
mockFID2 = mockFID.copy()
testFID = preproc.add(mockFID,mockFID2)
assert np.allclose(testFID,mockFID*2.0)
testFID = preproc.add(mockFID.copy(),mockFID.copy())
assert np.allclose(testFID,(mockFID*2.0)/2.0) # Averaging op
testFID = preproc.subtract(mockFID,mockFID2)
testFID = preproc.subtract(mockFID.copy(),mockFID.copy())
assert np.allclose(testFID,np.zeros(1024))
def test_align_diff():
......
......@@ -52,4 +52,4 @@ def test_quantifyWater():
print(res.getConc(scaling='molarity'))
assert np.allclose(res.getConc(scaling='internal'),1.0)
assert np.allclose(res.getConc(scaling='molarity'),10.72,atol=1E-1)
\ No newline at end of file
assert np.allclose(res.getConc(scaling='molarity'),10.59,atol=1E-1)
\ No newline at end of file
......@@ -17,7 +17,7 @@ def data():
phases = [0,0,0]
g = [0,0,0]
basisNames = ['Cr','PCr','NAA']
begintime = 0.0001
begintime = 0.00005
basisFIDs = []
for idx,_ in enumerate(amplitude):
......@@ -27,7 +27,7 @@ def data():
linewidth=[lw[idx]/5],
phase=[phases[idx]],
g=[g[idx]],
begintime=begintime)
begintime=0)
basisFIDs.append(tmp[0])
basisFIDs = np.asarray(basisFIDs)
......@@ -36,7 +36,8 @@ def data():
amplitude=amplitude,
linewidth=lw,
phase=phases,
g=g)
g=g,
begintime=begintime)
synMRS = MRS(FID =synFID[0],header=synHdr,basis =basisFIDs,basis_hdr=basisHdr,names=basisNames)
......@@ -44,7 +45,8 @@ def data():
Fitargs = {'ppmlim':[0.2,4.2],
'method':'MH','baseline_order':-1,
'metab_groups':metab_groups,
'MHSamples':100}
'MHSamples':100,
'disable_mh_priors':True}
res = fit_FSLModel(synMRS,**Fitargs)
......@@ -64,7 +66,7 @@ def test_peakcombination(data):
assert 'Cr+PCr' in res.metabs
assert np.allclose(fittedconcs,amplitudes,atol=1E-1)
assert np.allclose(fittedRelconcs,amplitudes/(amplitudes[0]+amplitudes[1]),atol=1E-1)
assert np.allclose(fittedRelconcs,amplitudes/(amplitudes[0]+amplitudes[1]),atol=2E-1)
def test_units(data):
res = data[0]
......@@ -72,7 +74,7 @@ def test_units(data):
# Phase
p0,p1 = res.getPhaseParams(phi0='degrees',phi1='seconds')
assert np.isclose(p0,0.0,atol=1E-1)
assert np.isclose(p1,0.0001,atol=1E-5)
assert np.isclose(p1,0.00005,atol=3E-5)
# Shift
shift = res.getShiftParams(units='ppm')
......@@ -81,8 +83,8 @@ def test_units(data):
assert np.isclose(shift_hz,0.1*123.0,atol=1E-1)
# Linewidth
lw = res.getLineShapeParams(units='Hz')
lw_ppm = res.getLineShapeParams(units='ppm')
lw = res.getLineShapeParams(units='Hz')[0]
lw_ppm = res.getLineShapeParams(units='ppm')[0]
assert np.isclose(lw,8.0,atol=1E-1) #10-2
assert np.isclose(lw_ppm,8.0/123.0,atol=1E-1)
......
......@@ -14,9 +14,9 @@ def test_noisecov():
points= 32768)
outcov = np.cov(np.asarray(testFID))
print(inputnoisecov)
print(outcov)
assert np.isclose(outcov,inputnoisecov,atol=1E-1).all()
# Noise cov is for both real and imag, so multiply by 2
assert np.isclose(outcov,2*inputnoisecov,atol=1E-1).all()
def test_syntheticFID():
testFID,hdr = syn.syntheticFID(noisecovariance=[[0.0]],points=16384)
......@@ -36,4 +36,4 @@ def test_syntheticFID():
spec /= np.max(np.abs(spec))
testSpec /= np.max(np.abs(testSpec))
assert np.isclose(spec,FIDToSpec(testFID),atol = 1E-2,rtol = 1E0).all()
\ No newline at end of file
assert np.isclose(spec,FIDToSpec(testFID[0]),atol = 1E-2,rtol = 1E0).all()
\ No newline at end of file
......@@ -276,6 +276,7 @@ def fit_FSLModel(mrs,
model='lorentzian',
x0=None,
MHSamples=500,
disable_mh_priors = False,
vb_iter=50):
"""
A simplified version of LCModel
......@@ -337,29 +338,31 @@ def fit_FSLModel(mrs,
def loglik(p):
return np.log(np.linalg.norm(y-forward_mh(p)[first:last]))*numPoints_over_2
# def logpr(p):
# return np.sum(dist.gauss_logpdf(p,loc=np.zeros_like(p),scale=np.ones_like(p)*1E2))
def logpr(p):
prior = 0
if model.lower()=='lorentzian':
con,gamma,eps,phi0,phi1,b = x2p(p,mrs.numBasis,g)
prior += np.sum(dist.gauss_logpdf(con,loc=np.zeros_like(con),scale=np.ones_like(con)*1E0))
prior += np.sum(dist.gauss_logpdf(gamma,loc=np.ones_like(gamma)*5*np.pi,scale=np.ones_like(gamma)*2.5*np.pi))
prior += np.sum(dist.gauss_logpdf(eps,loc=np.zeros_like(eps),scale=np.ones_like(eps)*0.005*(2*np.pi*mrs.centralFrequency/1E6)))
prior += np.sum(dist.gauss_logpdf(phi0,loc=np.zeros_like(phi0),scale=np.ones_like(phi0)*(np.pi*10/180)))
prior += np.sum(dist.gauss_logpdf(phi1,loc=np.zeros_like(phi1),scale=np.ones_like(phi1)*(1E-5*2*np.pi)))
prior += 0
elif model.lower()=='voigt':
con,gamma,sigma,eps,phi0,phi1,b = x2p(p,mrs.numBasis,g)
prior += np.sum(dist.gauss_logpdf(con,loc=np.zeros_like(con),scale=np.ones_like(con)*1E0))
prior += np.sum(dist.gauss_logpdf(gamma,loc=np.ones_like(gamma)*5*np.pi,scale=np.ones_like(gamma)*2.5*np.pi))
prior += np.sum(dist.gauss_logpdf(sigma,loc=np.ones_like(sigma)*5*np.pi,scale=np.ones_like(sigma)*2.5*np.pi))
prior += np.sum(dist.gauss_logpdf(eps,loc=np.zeros_like(eps),scale=np.ones_like(eps)*0.005*(2*np.pi*mrs.centralFrequency/1E6)))
prior += np.sum(dist.gauss_logpdf(phi0,loc=np.zeros_like(phi0),scale=np.ones_like(phi0)*(np.pi*5/180)))
prior += np.sum(dist.gauss_logpdf(phi1,loc=np.zeros_like(phi1),scale=np.ones_like(phi1)*(1E-5*2*np.pi)))
prior += 0
return prior
if disable_mh_priors:
def logpr(p):
return np.sum(dist.gauss_logpdf(p,loc=np.zeros_like(p),scale=np.ones_like(p)*1E2))
else:
def logpr(p):
prior = 0
if model.lower()=='lorentzian':
con,gamma,eps,phi0,phi1,b = x2p(p,mrs.numBasis,g)
prior += np.sum(dist.gauss_logpdf(con,loc=np.zeros_like(con),scale=np.ones_like(con)*1E0))
prior += np.sum(dist.gauss_logpdf(gamma,loc=np.ones_like(gamma)*5*np.pi,scale=np.ones_like(gamma)*2.5*np.pi))
prior += np.sum(dist.gauss_logpdf(eps,loc=np.zeros_like(eps),scale=np.ones_like(eps)*0.005*(2*np.pi*mrs.centralFrequency/1E6)))
prior += np.sum(dist.gauss_logpdf(phi0,loc=np.zeros_like(phi0),scale=np.ones_like(phi0)*(np.pi*10/180)))
prior += np.sum(dist.gauss_logpdf(phi1,loc=np.zeros_like(phi1),scale=np.ones_like(phi1)*(1E-5*2*np.pi)))
prior += 0
elif model.lower()=='voigt':
con,gamma,sigma,eps,phi0,phi1,b = x2p(p,mrs.numBasis,g)
prior += np.sum(dist.gauss_logpdf(con,loc=np.zeros_like(con),scale=np.ones_like(con)*1E0))
prior += np.sum(dist.gauss_logpdf(gamma,loc=np.ones_like(gamma)*5*np.pi,scale=np.ones_like(gamma)*2.5*np.pi))
prior += np.sum(dist.gauss_logpdf(sigma,loc=np.ones_like(sigma)*5*np.pi,scale=np.ones_like(sigma)*2.5*np.pi))
prior += np.sum(dist.gauss_logpdf(eps,loc=np.zeros_like(eps),scale=np.ones_like(eps)*0.005*(2*np.pi*mrs.centralFrequency/1E6)))
prior += np.sum(dist.gauss_logpdf(phi0,loc=np.zeros_like(phi0),scale=np.ones_like(phi0)*(np.pi*5/180)))
prior += np.sum(dist.gauss_logpdf(phi1,loc=np.zeros_like(phi1),scale=np.ones_like(phi1)*(1E-5*2*np.pi)))
prior += 0
return prior
#loglik = lambda p : np.log(np.linalg.norm(y-forward_mh(p)[first:last]))*numPoints_over_2
......
......@@ -47,10 +47,13 @@ def FIDToSpec(FID,axis=0):
Returns:
x (np.array) : array of spectra
"""
# By convention the first point of the fid is special cased
FID[0] *=0.5
# By convention the first point of the fid is special cased
ss = [slice(None) for i in range(FID.ndim)]
ss[axis] = slice(0,1)
ss = tuple(ss)
FID[ss] *=0.5
out = scipy.fft.fftshift(scipy.fft.fft(FID,axis=axis,norm='ortho'),axes=axis)
FID[0] *=2
FID[ss] *=2
return out
def SpecToFID(spec,axis=0):
......@@ -65,7 +68,10 @@ def SpecToFID(spec,axis=0):
x (np.array) : array of FIDs
"""
fid = scipy.fft.ifft(scipy.fft.ifftshift(spec,axes=axis),axis=axis,norm='ortho')
fid[0] *= 2
ss = [slice(None) for i in range(fid.ndim)]
ss[axis] = slice(0,1)
ss = tuple(ss)
fid[ss] *= 2
return fid
def calculateAxes(bandwidth,centralFrequency,points):
......
......@@ -291,8 +291,13 @@ class FitRes(object):
# Extract concentrations from parameters.
if metab is not None:
if metab not in self.metabs:
raise ValueError(f'{metab} is not a recognised metabolite.')
if isinstance(metab,list):
for mm in metab:
if mm not in self.metabs:
raise ValueError(f'{mm} is not a recognised metabolite.')
else:
if metab not in self.metabs:
raise ValueError(f'{metab} is not a recognised metabolite.')
rawConc = dfFunc(metab)
else:
rawConc = dfFunc(self.metabs)
......@@ -333,11 +338,11 @@ class FitRes(object):
raise ValueError('phi0 must be degrees or radians')
if phi1.lower() == 'seconds':
p1 *= 1/(2*np.pi)
p1 *= -1.0/(2*np.pi)
elif phi1.lower() == 'deg_per_ppm':
p1 *= 180.0/np.pi * self.hzperppm
p1 *= -180.0/np.pi * self.hzperppm
elif phi1.lower() == 'deg_per_hz':
p1 *= 180.0/np.pi * 1.0
p1 *= -180.0/np.pi * 1.0
else:
raise ValueError('phi1 must be seconds or deg_per_ppm or deg_per_hz ')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment