Commit 6c9dd697 authored by William Clarke's avatar William Clarke
Browse files

Merge branch 'dynamic_editing' into 'master'

ENH: Dynamic editing enabled by allowing multiple basis sets.

See merge request !28
parents d256944c b19f52d2
......@@ -4,6 +4,8 @@ This document contains the FSL-MRS release history in reverse chronological orde
-------------------------------
- Fixed typos in fsl_mrs_proc help.
- Fixed simulator bug for edited sequence coherence filters.
- Modified API of syntheticFromBasis function.
- Dynamic fitting now handles multiple different basis sets.
1.1.8 (Tuesday 5th October 2021)
-------------------------------
......
......@@ -8,7 +8,6 @@ from fsl_mrs.utils.synthetic import syntheticFID
from fsl_mrs.utils.synthetic.synthetic_from_basis import syntheticFromBasisFile
from fsl_mrs.core import MRS
from fsl_mrs.utils.fitting import fit_FSLModel
from fsl_mrs.utils import mrs_io
from pytest import fixture
import numpy as np
......@@ -142,14 +141,12 @@ def test_fit_FSLModel_lorentzian_MH(data):
def test_fit_FSLModel_on_invivo_sim():
FIDs, hdr, trueconcs = syntheticFromBasisFile(basis_path,
FIDs, mrs, trueconcs = syntheticFromBasisFile(basis_path,
noisecovariance=[[1E-3]],
broadening=(9.0, 9.0),
concentrations={'Mac': 2.0})
basis = mrs_io.read_basis(basis_path)
mrs = MRS(FID=FIDs, header=hdr, basis=basis)
mrs.FID = FIDs
mrs.processForFitting()
metab_groups = [0] * mrs.numBasis
......@@ -164,6 +161,6 @@ def test_fit_FSLModel_on_invivo_sim():
fittedRelconcs = res.getConc(scaling='internal', metab=mrs.names)
answers = np.asarray(trueconcs)
answers /= (answers[basis.names.index('Cr')] + trueconcs[basis.names.index('PCr')])
answers /= (answers[mrs.names.index('Cr')] + trueconcs[mrs.names.index('PCr')])
assert np.allclose(fittedRelconcs, answers, atol=5E-2)
......@@ -7,7 +7,6 @@ Copyright Will Clarke, University of Oxford, 2021'''
from fsl_mrs.utils import synthetic as syn
from fsl_mrs.utils.misc import FIDToSpec
from fsl_mrs.utils import mrs_io
from fsl_mrs.core import MRS
import numpy as np
from pathlib import Path
......@@ -61,33 +60,33 @@ def test_syntheticFID():
def test_syntheticFromBasis():
fid, header, _ = syn.syntheticFromBasisFile(str(basis_path),
ignore=['Scyllo'],
baseline=[0.0, 0.0],
concentrations={'Mac': 2.0},
coilamps=[1.0, 1.0],
coilphase=[0.0, np.pi],
noisecovariance=[[0.1, 0.0], [0.0, 0.1]])
fid, _, _ = syn.syntheticFromBasisFile(str(basis_path),
ignore=['Scyllo'],
baseline=[0.0, 0.0],
concentrations={'Mac': 2.0},
coilamps=[1.0, 1.0],
coilphase=[0.0, np.pi],
noisecovariance=[[0.1, 0.0], [0.0, 0.1]])
assert fid.shape == (2048, 2)
def test_syntheticFromBasis_baseline():
fid, header, _ = syn.syntheticFromBasisFile(str(basis_path),
baseline=[0.0, 0.0],
concentrations={'Mac': 2.0},
noisecovariance=[[0.0]])
fid, mrs, _ = syn.syntheticFromBasisFile(str(basis_path),
baseline=[0.0, 0.0],
concentrations={'Mac': 2.0},
noisecovariance=[[0.0]])
mrs = MRS(FID=fid, header=header)
mrs.FID = fid
mrs.conj_FID = True
fid, header, _ = syn.syntheticFromBasisFile(str(basis_path),
baseline=[1.0, 1.0],
concentrations={'Mac': 2.0},
noisecovariance=[[0.0]])
fid, mrs2, _ = syn.syntheticFromBasisFile(str(basis_path),
baseline=[1.0, 1.0],
concentrations={'Mac': 2.0},
noisecovariance=[[0.0]])
mrs2 = MRS(FID=fid, header=header)
mrs2.FID = fid
mrs2.conj_FID = True
assert np.allclose(mrs2.get_spec(), mrs.get_spec() + complex(1.0, -1.0))
......
......@@ -235,8 +235,8 @@ class dynRes:
def calc_fit_from_flatmapped(mapped):
fwd = []
for mp in mapped:
fwd.append(self._dyn.forward(mp))
for idx, mp in enumerate(mapped):
fwd.append(self._dyn.forward[idx](mp))
return np.asarray(fwd)
init_fit = calc_fit_from_flatmapped(self.mapped_parameters_init)
......
......@@ -33,7 +33,8 @@ class dynMRS(object):
model='voigt',
ppmlim=(.2, 4.2),
baseline_order=2,
metab_groups=None):
metab_groups=None,
rescale=True):
"""Create a dynMRS class object
:param mrs_list: List of MRS objects, one per time_var
......@@ -50,25 +51,27 @@ class dynMRS(object):
:type baseline_order: int, optional
:param metab_groups: Metabolite group list, defaults to None
:type metab_groups: list, optional
:param rescale: Apply basis and FID rescaling, defaults to True
:type rescale: bool, optional
"""
self.time_var = time_var
self.mrs_list = mrs_list
self._process_mrs_list()
if rescale:
self._process_mrs_list()
if metab_groups is None:
metab_groups = [0] * len(self.metabolite_names)
self.data = self._prepare_data(ppmlim)
self.constants = self._get_constants(model, ppmlim, baseline_order, metab_groups)
self.forward = self._get_forward(model)
self.gradient = self._get_gradient(model)
self._fit_args = {'model': model,
'baseline_order': baseline_order,
'metab_groups': metab_groups,
'baseline_order': baseline_order,
'ppmlim': ppmlim}
self.data = self._prepare_data(ppmlim)
self.forward = self._get_forward()
self.gradient = self._get_gradient()
numBasis, numGroups = self.mrs_list[0].numBasis, max(metab_groups) + 1
varNames, varSizes = models.FSLModel_vars(model, numBasis, numGroups, baseline_order)
self.vm = self._create_vm(model, config_file, varNames, varSizes)
......@@ -217,9 +220,8 @@ class dynMRS(object):
return {'x': init, 'resList': resList}
# Utility methods
def _get_constants(self, model, ppmlim, baseline_order, metab_groups):
def _get_constants(self, mrs, ppmlim, baseline_order, metab_groups):
"""collect constants for forward model"""
mrs = self.mrs_list[0]
first, last = mrs.ppmlim_to_range(ppmlim) # data range
freq, time, basis = mrs.frequencyAxis, mrs.timeAxis, mrs.basis
base_poly = fitting.prepare_baseline_regressor(mrs, baseline_order, ppmlim)
......@@ -241,28 +243,54 @@ class dynMRS(object):
data = [mrs.get_spec().copy()[first:last] for mrs in self.mrs_list]
return data
def _get_forward(self, model):
def _get_forward(self):
"""Get forward model"""
forward = models.getModelForward(model)
first, last = self.constants[-2:]
return lambda x: forward(x, *self.constants[:-2])[first:last]
fwd_lambdas = []
for mrs in self.mrs_list:
forward = models.getModelForward(self._fit_args['model'])
constants = self._get_constants(
mrs,
self._fit_args['ppmlim'],
self._fit_args['baseline_order'],
self._fit_args['metab_groups'])
def raiser(const):
first, last = const[-2:]
return lambda x: forward(x, *const[:-2])[first:last]
fwd_lambdas.append(raiser(constants))
def _get_gradient(self, model):
return fwd_lambdas
def _get_gradient(self):
"""Get gradient"""
gradient = models.getModelJac(model)
return lambda x: gradient(x, *self.constants)
gradient = models.getModelJac(self._fit_args['model'])
fwd_grads = []
for mrs in self.mrs_list:
constants = self._get_constants(
mrs,
self._fit_args['ppmlim'],
self._fit_args['baseline_order'],
self._fit_args['metab_groups'])
def raiser(const):
return lambda x: gradient(x, *const)
fwd_grads.append(raiser(constants))
return fwd_grads
# Loss functions
def loss(self, x, i):
"""Calc loss function"""
loss_real = .5 * np.mean(np.real(self.forward(x) - self.data[i]) ** 2)
loss_imag = .5 * np.mean(np.imag(self.forward(x) - self.data[i]) ** 2)
loss_real = .5 * np.mean(np.real(self.forward[i](x) - self.data[i]) ** 2)
loss_imag = .5 * np.mean(np.imag(self.forward[i](x) - self.data[i]) ** 2)
return loss_real + loss_imag
def loss_grad(self, x, i):
"""Calc gradient of loss function"""
g = self.gradient(x)
e = self.forward(x) - self.data[i]
g = self.gradient[i](x)
e = self.forward[i](x) - self.data[i]
grad_real = np.mean(np.real(g) * np.real(e[:, None]), axis=0)
grad_imag = np.mean(np.imag(g) * np.imag(e[:, None]), axis=0)
return grad_real + grad_imag
......@@ -326,12 +354,16 @@ class dynMRS(object):
mapped = self.vm.free_to_mapped(x)
for time_index in range(self.vm.ntimes):
p = np.hstack(mapped[time_index, :])
fwd[time_index, :] = self.forward(p)
fwd[time_index, :] = self.forward[time_index](p)
return fwd.flatten()
def _form_FitRes(self, x, model, method, ppmlim, baseline_order):
"""Create list of FitRes object"""
_, _, _, base_poly, metab_groups, _, _, _ = self.constants
_, _, _, base_poly, metab_groups, _, _, _ = self._get_constants(
self.mrs_list[0],
self._fit_args['ppmlim'],
self._fit_args['baseline_order'],
self._fit_args['metab_groups'])
if method.lower() == 'mh':
mapped = []
for xx in x:
......
......@@ -176,7 +176,6 @@ def FMRS(smooth=False, path='/Users/saad/Desktop/Spectroscopy/'):
basis, names, basisheader = mrs_io.read_basis(str(basisfile))
# # Resample basis
from fsl_mrs.utils import misc
basis = misc.ts_to_ts(basis,
basisheader[0]['dwelltime'],
FIDheader['dwelltime'],
......@@ -214,33 +213,20 @@ def MPRESS(noise=1, path='/Users/saad/Desktop/Spectroscopy/'):
mrs Object list
list (time variable)
"""
from fsl_mrs.utils import mrs_io
from fsl_mrs.utils.synthetic.synthetic_from_basis import syntheticFromBasisFile
path = Path(path)
mpress_on = path / 'mpress_basis/ON'
mpress_off = path / 'mpress_basis/OFF'
basis, names, basis_hdr = mrs_io.read_basis(mpress_on)
FIDs, header, conc = syntheticFromBasisFile(mpress_on, noisecovariance=[[noise]])
mrs1 = MRS(FID=FIDs,
header=header,
basis=basis,
basis_hdr=basis_hdr[0],
names=names)
FIDs, mrs1, conc = syntheticFromBasisFile(mpress_on, noisecovariance=[[noise]])
mrs1.FID = FIDs
mrs1.check_FID(repair=True)
mrs1.Spec = misc.FIDToSpec(mrs1.FID)
mrs1.check_Basis(repair=True)
basis, names, basis_hdr = mrs_io.read_basis(mpress_off)
FIDs, header, conc = syntheticFromBasisFile(mpress_off, noisecovariance=[[noise]])
mrs2 = MRS(FID=FIDs,
header=header,
basis=basis,
basis_hdr=basis_hdr[0],
names=names)
FIDs, mrs2, conc = syntheticFromBasisFile(mpress_off, noisecovariance=[[noise]])
mrs2.FID = FIDs
mrs2.check_FID(repair=True)
mrs2.Spec = misc.FIDToSpec(mrs2.FID)
mrs2.check_Basis(repair=True)
return [mrs1, mrs2], [0, 1]
......@@ -218,7 +218,7 @@ def syntheticFromBasisFile(basisFile,
ppmlim=baseline_ppm)
return FIDs, \
{'bandwidth': bandwidth, 'centralFrequency': empty_mrs.centralFrequency}, \
empty_mrs, \
concentrations
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment