Commit 6c9dd697 authored by William Clarke's avatar William Clarke
Browse files

Merge branch 'dynamic_editing' into 'master'

ENH: Dynamic editing enabled by allowing multiple basis sets.

See merge request fsl/fsl_mrs!28
parents d256944c b19f52d2
...@@ -4,6 +4,8 @@ This document contains the FSL-MRS release history in reverse chronological orde ...@@ -4,6 +4,8 @@ This document contains the FSL-MRS release history in reverse chronological orde
------------------------------- -------------------------------
- Fixed typos in fsl_mrs_proc help. - Fixed typos in fsl_mrs_proc help.
- Fixed simulator bug for edited sequence coherence filters. - Fixed simulator bug for edited sequence coherence filters.
- Modified API of syntheticFromBasis function.
- Dynamic fitting now handles multiple different basis sets.
1.1.8 (Tuesday 5th October 2021) 1.1.8 (Tuesday 5th October 2021)
------------------------------- -------------------------------
......
...@@ -8,7 +8,6 @@ from fsl_mrs.utils.synthetic import syntheticFID ...@@ -8,7 +8,6 @@ from fsl_mrs.utils.synthetic import syntheticFID
from fsl_mrs.utils.synthetic.synthetic_from_basis import syntheticFromBasisFile from fsl_mrs.utils.synthetic.synthetic_from_basis import syntheticFromBasisFile
from fsl_mrs.core import MRS from fsl_mrs.core import MRS
from fsl_mrs.utils.fitting import fit_FSLModel from fsl_mrs.utils.fitting import fit_FSLModel
from fsl_mrs.utils import mrs_io
from pytest import fixture from pytest import fixture
import numpy as np import numpy as np
...@@ -142,14 +141,12 @@ def test_fit_FSLModel_lorentzian_MH(data): ...@@ -142,14 +141,12 @@ def test_fit_FSLModel_lorentzian_MH(data):
def test_fit_FSLModel_on_invivo_sim(): def test_fit_FSLModel_on_invivo_sim():
FIDs, hdr, trueconcs = syntheticFromBasisFile(basis_path, FIDs, mrs, trueconcs = syntheticFromBasisFile(basis_path,
noisecovariance=[[1E-3]], noisecovariance=[[1E-3]],
broadening=(9.0, 9.0), broadening=(9.0, 9.0),
concentrations={'Mac': 2.0}) concentrations={'Mac': 2.0})
basis = mrs_io.read_basis(basis_path) mrs.FID = FIDs
mrs = MRS(FID=FIDs, header=hdr, basis=basis)
mrs.processForFitting() mrs.processForFitting()
metab_groups = [0] * mrs.numBasis metab_groups = [0] * mrs.numBasis
...@@ -164,6 +161,6 @@ def test_fit_FSLModel_on_invivo_sim(): ...@@ -164,6 +161,6 @@ def test_fit_FSLModel_on_invivo_sim():
fittedRelconcs = res.getConc(scaling='internal', metab=mrs.names) fittedRelconcs = res.getConc(scaling='internal', metab=mrs.names)
answers = np.asarray(trueconcs) answers = np.asarray(trueconcs)
answers /= (answers[basis.names.index('Cr')] + trueconcs[basis.names.index('PCr')]) answers /= (answers[mrs.names.index('Cr')] + trueconcs[mrs.names.index('PCr')])
assert np.allclose(fittedRelconcs, answers, atol=5E-2) assert np.allclose(fittedRelconcs, answers, atol=5E-2)
...@@ -7,7 +7,6 @@ Copyright Will Clarke, University of Oxford, 2021''' ...@@ -7,7 +7,6 @@ Copyright Will Clarke, University of Oxford, 2021'''
from fsl_mrs.utils import synthetic as syn from fsl_mrs.utils import synthetic as syn
from fsl_mrs.utils.misc import FIDToSpec from fsl_mrs.utils.misc import FIDToSpec
from fsl_mrs.utils import mrs_io from fsl_mrs.utils import mrs_io
from fsl_mrs.core import MRS
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
...@@ -61,33 +60,33 @@ def test_syntheticFID(): ...@@ -61,33 +60,33 @@ def test_syntheticFID():
def test_syntheticFromBasis(): def test_syntheticFromBasis():
fid, header, _ = syn.syntheticFromBasisFile(str(basis_path), fid, _, _ = syn.syntheticFromBasisFile(str(basis_path),
ignore=['Scyllo'], ignore=['Scyllo'],
baseline=[0.0, 0.0], baseline=[0.0, 0.0],
concentrations={'Mac': 2.0}, concentrations={'Mac': 2.0},
coilamps=[1.0, 1.0], coilamps=[1.0, 1.0],
coilphase=[0.0, np.pi], coilphase=[0.0, np.pi],
noisecovariance=[[0.1, 0.0], [0.0, 0.1]]) noisecovariance=[[0.1, 0.0], [0.0, 0.1]])
assert fid.shape == (2048, 2) assert fid.shape == (2048, 2)
def test_syntheticFromBasis_baseline(): def test_syntheticFromBasis_baseline():
fid, header, _ = syn.syntheticFromBasisFile(str(basis_path), fid, mrs, _ = syn.syntheticFromBasisFile(str(basis_path),
baseline=[0.0, 0.0], baseline=[0.0, 0.0],
concentrations={'Mac': 2.0}, concentrations={'Mac': 2.0},
noisecovariance=[[0.0]]) noisecovariance=[[0.0]])
mrs = MRS(FID=fid, header=header) mrs.FID = fid
mrs.conj_FID = True mrs.conj_FID = True
fid, header, _ = syn.syntheticFromBasisFile(str(basis_path), fid, mrs2, _ = syn.syntheticFromBasisFile(str(basis_path),
baseline=[1.0, 1.0], baseline=[1.0, 1.0],
concentrations={'Mac': 2.0}, concentrations={'Mac': 2.0},
noisecovariance=[[0.0]]) noisecovariance=[[0.0]])
mrs2 = MRS(FID=fid, header=header) mrs2.FID = fid
mrs2.conj_FID = True mrs2.conj_FID = True
assert np.allclose(mrs2.get_spec(), mrs.get_spec() + complex(1.0, -1.0)) assert np.allclose(mrs2.get_spec(), mrs.get_spec() + complex(1.0, -1.0))
......
...@@ -235,8 +235,8 @@ class dynRes: ...@@ -235,8 +235,8 @@ class dynRes:
def calc_fit_from_flatmapped(mapped): def calc_fit_from_flatmapped(mapped):
fwd = [] fwd = []
for mp in mapped: for idx, mp in enumerate(mapped):
fwd.append(self._dyn.forward(mp)) fwd.append(self._dyn.forward[idx](mp))
return np.asarray(fwd) return np.asarray(fwd)
init_fit = calc_fit_from_flatmapped(self.mapped_parameters_init) init_fit = calc_fit_from_flatmapped(self.mapped_parameters_init)
......
...@@ -33,7 +33,8 @@ class dynMRS(object): ...@@ -33,7 +33,8 @@ class dynMRS(object):
model='voigt', model='voigt',
ppmlim=(.2, 4.2), ppmlim=(.2, 4.2),
baseline_order=2, baseline_order=2,
metab_groups=None): metab_groups=None,
rescale=True):
"""Create a dynMRS class object """Create a dynMRS class object
:param mrs_list: List of MRS objects, one per time_var :param mrs_list: List of MRS objects, one per time_var
...@@ -50,25 +51,27 @@ class dynMRS(object): ...@@ -50,25 +51,27 @@ class dynMRS(object):
:type baseline_order: int, optional :type baseline_order: int, optional
:param metab_groups: Metabolite group list, defaults to None :param metab_groups: Metabolite group list, defaults to None
:type metab_groups: list, optional :type metab_groups: list, optional
:param rescale: Apply basis and FID rescaling, defaults to True
:type rescale: bool, optional
""" """
self.time_var = time_var self.time_var = time_var
self.mrs_list = mrs_list self.mrs_list = mrs_list
self._process_mrs_list() if rescale:
self._process_mrs_list()
if metab_groups is None: if metab_groups is None:
metab_groups = [0] * len(self.metabolite_names) metab_groups = [0] * len(self.metabolite_names)
self.data = self._prepare_data(ppmlim)
self.constants = self._get_constants(model, ppmlim, baseline_order, metab_groups)
self.forward = self._get_forward(model)
self.gradient = self._get_gradient(model)
self._fit_args = {'model': model, self._fit_args = {'model': model,
'baseline_order': baseline_order,
'metab_groups': metab_groups, 'metab_groups': metab_groups,
'baseline_order': baseline_order, 'baseline_order': baseline_order,
'ppmlim': ppmlim} 'ppmlim': ppmlim}
self.data = self._prepare_data(ppmlim)
self.forward = self._get_forward()
self.gradient = self._get_gradient()
numBasis, numGroups = self.mrs_list[0].numBasis, max(metab_groups) + 1 numBasis, numGroups = self.mrs_list[0].numBasis, max(metab_groups) + 1
varNames, varSizes = models.FSLModel_vars(model, numBasis, numGroups, baseline_order) varNames, varSizes = models.FSLModel_vars(model, numBasis, numGroups, baseline_order)
self.vm = self._create_vm(model, config_file, varNames, varSizes) self.vm = self._create_vm(model, config_file, varNames, varSizes)
...@@ -217,9 +220,8 @@ class dynMRS(object): ...@@ -217,9 +220,8 @@ class dynMRS(object):
return {'x': init, 'resList': resList} return {'x': init, 'resList': resList}
# Utility methods # Utility methods
def _get_constants(self, model, ppmlim, baseline_order, metab_groups): def _get_constants(self, mrs, ppmlim, baseline_order, metab_groups):
"""collect constants for forward model""" """collect constants for forward model"""
mrs = self.mrs_list[0]
first, last = mrs.ppmlim_to_range(ppmlim) # data range first, last = mrs.ppmlim_to_range(ppmlim) # data range
freq, time, basis = mrs.frequencyAxis, mrs.timeAxis, mrs.basis freq, time, basis = mrs.frequencyAxis, mrs.timeAxis, mrs.basis
base_poly = fitting.prepare_baseline_regressor(mrs, baseline_order, ppmlim) base_poly = fitting.prepare_baseline_regressor(mrs, baseline_order, ppmlim)
...@@ -241,28 +243,54 @@ class dynMRS(object): ...@@ -241,28 +243,54 @@ class dynMRS(object):
data = [mrs.get_spec().copy()[first:last] for mrs in self.mrs_list] data = [mrs.get_spec().copy()[first:last] for mrs in self.mrs_list]
return data return data
def _get_forward(self, model): def _get_forward(self):
"""Get forward model""" """Get forward model"""
forward = models.getModelForward(model) fwd_lambdas = []
first, last = self.constants[-2:] for mrs in self.mrs_list:
return lambda x: forward(x, *self.constants[:-2])[first:last] forward = models.getModelForward(self._fit_args['model'])
constants = self._get_constants(
mrs,
self._fit_args['ppmlim'],
self._fit_args['baseline_order'],
self._fit_args['metab_groups'])
def raiser(const):
first, last = const[-2:]
return lambda x: forward(x, *const[:-2])[first:last]
fwd_lambdas.append(raiser(constants))
def _get_gradient(self, model): return fwd_lambdas
def _get_gradient(self):
"""Get gradient""" """Get gradient"""
gradient = models.getModelJac(model) gradient = models.getModelJac(self._fit_args['model'])
return lambda x: gradient(x, *self.constants) fwd_grads = []
for mrs in self.mrs_list:
constants = self._get_constants(
mrs,
self._fit_args['ppmlim'],
self._fit_args['baseline_order'],
self._fit_args['metab_groups'])
def raiser(const):
return lambda x: gradient(x, *const)
fwd_grads.append(raiser(constants))
return fwd_grads
# Loss functions # Loss functions
def loss(self, x, i): def loss(self, x, i):
"""Calc loss function""" """Calc loss function"""
loss_real = .5 * np.mean(np.real(self.forward(x) - self.data[i]) ** 2) loss_real = .5 * np.mean(np.real(self.forward[i](x) - self.data[i]) ** 2)
loss_imag = .5 * np.mean(np.imag(self.forward(x) - self.data[i]) ** 2) loss_imag = .5 * np.mean(np.imag(self.forward[i](x) - self.data[i]) ** 2)
return loss_real + loss_imag return loss_real + loss_imag
def loss_grad(self, x, i): def loss_grad(self, x, i):
"""Calc gradient of loss function""" """Calc gradient of loss function"""
g = self.gradient(x) g = self.gradient[i](x)
e = self.forward(x) - self.data[i] e = self.forward[i](x) - self.data[i]
grad_real = np.mean(np.real(g) * np.real(e[:, None]), axis=0) grad_real = np.mean(np.real(g) * np.real(e[:, None]), axis=0)
grad_imag = np.mean(np.imag(g) * np.imag(e[:, None]), axis=0) grad_imag = np.mean(np.imag(g) * np.imag(e[:, None]), axis=0)
return grad_real + grad_imag return grad_real + grad_imag
...@@ -326,12 +354,16 @@ class dynMRS(object): ...@@ -326,12 +354,16 @@ class dynMRS(object):
mapped = self.vm.free_to_mapped(x) mapped = self.vm.free_to_mapped(x)
for time_index in range(self.vm.ntimes): for time_index in range(self.vm.ntimes):
p = np.hstack(mapped[time_index, :]) p = np.hstack(mapped[time_index, :])
fwd[time_index, :] = self.forward(p) fwd[time_index, :] = self.forward[time_index](p)
return fwd.flatten() return fwd.flatten()
def _form_FitRes(self, x, model, method, ppmlim, baseline_order): def _form_FitRes(self, x, model, method, ppmlim, baseline_order):
"""Create list of FitRes object""" """Create list of FitRes object"""
_, _, _, base_poly, metab_groups, _, _, _ = self.constants _, _, _, base_poly, metab_groups, _, _, _ = self._get_constants(
self.mrs_list[0],
self._fit_args['ppmlim'],
self._fit_args['baseline_order'],
self._fit_args['metab_groups'])
if method.lower() == 'mh': if method.lower() == 'mh':
mapped = [] mapped = []
for xx in x: for xx in x:
......
...@@ -176,7 +176,6 @@ def FMRS(smooth=False, path='/Users/saad/Desktop/Spectroscopy/'): ...@@ -176,7 +176,6 @@ def FMRS(smooth=False, path='/Users/saad/Desktop/Spectroscopy/'):
basis, names, basisheader = mrs_io.read_basis(str(basisfile)) basis, names, basisheader = mrs_io.read_basis(str(basisfile))
# # Resample basis # # Resample basis
from fsl_mrs.utils import misc
basis = misc.ts_to_ts(basis, basis = misc.ts_to_ts(basis,
basisheader[0]['dwelltime'], basisheader[0]['dwelltime'],
FIDheader['dwelltime'], FIDheader['dwelltime'],
...@@ -214,33 +213,20 @@ def MPRESS(noise=1, path='/Users/saad/Desktop/Spectroscopy/'): ...@@ -214,33 +213,20 @@ def MPRESS(noise=1, path='/Users/saad/Desktop/Spectroscopy/'):
mrs Object list mrs Object list
list (time variable) list (time variable)
""" """
from fsl_mrs.utils import mrs_io
from fsl_mrs.utils.synthetic.synthetic_from_basis import syntheticFromBasisFile from fsl_mrs.utils.synthetic.synthetic_from_basis import syntheticFromBasisFile
path = Path(path) path = Path(path)
mpress_on = path / 'mpress_basis/ON' mpress_on = path / 'mpress_basis/ON'
mpress_off = path / 'mpress_basis/OFF' mpress_off = path / 'mpress_basis/OFF'
basis, names, basis_hdr = mrs_io.read_basis(mpress_on) FIDs, mrs1, conc = syntheticFromBasisFile(mpress_on, noisecovariance=[[noise]])
FIDs, header, conc = syntheticFromBasisFile(mpress_on, noisecovariance=[[noise]]) mrs1.FID = FIDs
mrs1 = MRS(FID=FIDs,
header=header,
basis=basis,
basis_hdr=basis_hdr[0],
names=names)
mrs1.check_FID(repair=True) mrs1.check_FID(repair=True)
mrs1.Spec = misc.FIDToSpec(mrs1.FID)
mrs1.check_Basis(repair=True) mrs1.check_Basis(repair=True)
basis, names, basis_hdr = mrs_io.read_basis(mpress_off) FIDs, mrs2, conc = syntheticFromBasisFile(mpress_off, noisecovariance=[[noise]])
FIDs, header, conc = syntheticFromBasisFile(mpress_off, noisecovariance=[[noise]]) mrs2.FID = FIDs
mrs2 = MRS(FID=FIDs,
header=header,
basis=basis,
basis_hdr=basis_hdr[0],
names=names)
mrs2.check_FID(repair=True) mrs2.check_FID(repair=True)
mrs2.Spec = misc.FIDToSpec(mrs2.FID)
mrs2.check_Basis(repair=True) mrs2.check_Basis(repair=True)
return [mrs1, mrs2], [0, 1] return [mrs1, mrs2], [0, 1]
...@@ -218,7 +218,7 @@ def syntheticFromBasisFile(basisFile, ...@@ -218,7 +218,7 @@ def syntheticFromBasisFile(basisFile,
ppmlim=baseline_ppm) ppmlim=baseline_ppm)
return FIDs, \ return FIDs, \
{'bandwidth': bandwidth, 'centralFrequency': empty_mrs.centralFrequency}, \ empty_mrs, \
concentrations concentrations
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment