Commit 1051ff99 authored by William Clarke's avatar William Clarke
Browse files

Merge branch 'newton_dynamic'

parents 762b292a 6379d358
from fsl_mrs.utils.dynamic.variable_mapping import VariableMapping
from fsl_mrs.utils.dynamic.dynmrs import dynMRS
\ No newline at end of file
# dynmrs.py - Class responsible for dynMRS fitting
#
# Author: Saad Jbabdi <saad@fmrib.ox.ac.uk>
# William Clarke <william.clarke@ndcn.ox.ac.uk>
#
# Copyright (C) 2019 University of Oxford
# SHBASECOPYRIGHT
import numpy as np
from scipy.optimize import minimize
import time
from fsl_mrs.utils import models, fitting
from . import variable_mapping as varmap
from fsl_mrs.utils.results import FitRes
from fsl_mrs.utils.stats import mh, dist
class dynMRS(object):
"""Dynamic MRS class"""
def __init__(self, mrs_list, time_var):
"""
mrs_list : list of MRS objects
time_var : array-like
"""
self.mrs_list = mrs_list
self.time_var = time_var
self.data = None
self.constants = None
self.forward = None
self.gradient = None
self.vm = None
def fit(self,
config_file,
method='Newton',
model='voigt',
ppmlim=(.2, 4.2),
baseline_order=2,
metab_groups=None,
init=None,
verbose=False):
"""
Fit dynamic MRS model
Parameters
----------
config_file : string
method : string ('Newton' or 'MH')
model : string ('voigt' or 'lorentzian')
ppmlim : tuple
baseline_order : int
metab_groups : array-like
init : dynMRSres object
verbose : bool
Returns
-------
dynMRSres object
"""
if verbose:
print('Start fitting')
start_time = time.time()
if metab_groups is None:
metab_groups = [0] * len(self.mrs_list[0].names)
self.data = self.prepare_data(ppmlim)
self.constants = self.get_constants(model, ppmlim, baseline_order, metab_groups)
self.forward = self.get_forward(model)
self.gradient = self.get_gradient(model)
numBasis, numGroups = self.mrs_list[0].numBasis, max(metab_groups) + 1
varNames, varSizes = models.FSLModel_vars(model, numBasis, numGroups, baseline_order)
self.vm = self.create_vm(model, config_file, varNames, varSizes)
bounds = self.vm.Bounds
if init is None:
init = self.initialise(model, metab_groups, ppmlim, baseline_order, verbose)
x0 = self.vm.mapped_to_free(init['x'])
# MCMC or Newton
if method.lower() == 'newton':
sol = minimize(fun=self.dyn_loss, x0=x0, jac=self.dyn_loss_grad, method='TNC', bounds=bounds)
x = sol.x
elif method.lower() == 'mh':
self.prior_means = np.zeros_like(self.vm.nfree)
self.prior_stds = np.ones_like(self.vm.nfree) * 1E3
mcmc = mh.MH(self.dyn_loglik, self.dyn_logpr, burnin=10, njumps=300, sampleevery=1)
LB, UB = mcmc.bounds_from_list(self.vm.nfree, self.vm.Bounds)
samples = mcmc.fit(x0, LB=LB, UB=UB, verbose=verbose)
x = np.mean(samples, axis=0)
else:
raise (Exception(f'Unrecognised method {method}'))
res_list = self.collect_results(x, model, method, ppmlim, baseline_order)
if verbose:
print(f"Fitting completed in {time.time()-start_time} seconds.")
return {'x': x, 'resList': res_list}
def get_constants(self, model, ppmlim, baseline_order, metab_groups):
"""collect constants for forward model"""
mrs = self.mrs_list[0]
first, last = mrs.ppmlim_to_range(ppmlim) # data range
freq, time, basis = mrs.frequencyAxis, mrs.timeAxis, mrs.basis
base_poly = fitting.prepare_baseline_regressor(mrs, baseline_order, ppmlim)
freq, time, basis = mrs.frequencyAxis, mrs.timeAxis, mrs.basis
g = max(metab_groups) + 1
return (freq, time, basis, base_poly, metab_groups, g, first, last)
def initialise(self, model, metab_groups, ppmlim, baseline_order, verbose=False):
"""initialise fitting"""
if verbose:
start_time = time.time()
FitArgs = {'model': model,
'metab_groups': metab_groups,
'ppmlim': ppmlim,
'method': 'Newton',
'baseline_order': baseline_order}
varNames = models.FSLModel_vars(model)
numMetabs = self.mrs_list[0].numBasis
numGroups = max(metab_groups) + 1
if FitArgs['model'] == 'lorentzian':
x2p = models.FSLModel_x2param
else:
x2p = models.FSLModel_x2param_Voigt
# Get init from fitting to individual time points
init = np.empty((len(self.time_var), len(varNames)), dtype=object)
resList = []
for t, mrs in enumerate(self.mrs_list):
if verbose:
print(f'Initialising {t + 1}/{len(self.mrs_list)}', end='\r')
res = fitting.fit_FSLModel(mrs, **FitArgs)
resList.append(res)
params = x2p(res.params, numMetabs, numGroups)
for i, p in enumerate(params):
init[t, i] = p
if verbose:
print(f'Init done in {time.time()-start_time} seconds.')
return {'x': init, 'resList': resList}
def create_vm(self, model, config_file, varNames, varSizes):
"""Create Variable Mapping object"""
vm = varmap.VariableMapping(param_names=varNames,
param_sizes=varSizes,
time_variable=self.time_var,
config_file=config_file)
return vm
def prepare_data(self, ppmlim):
"""FID to Spec and slice for fitting"""
first, last = self.mrs_list[0].ppmlim_to_range(ppmlim)
data = [mrs.get_spec().copy()[first:last] for mrs in self.mrs_list]
return data
def get_forward(self, model):
"""Get forward model"""
forward = models.getModelForward(model)
first, last = self.constants[-2:]
return lambda x: forward(x, *self.constants[:-2])[first:last]
def get_gradient(self, model):
"""Get gradient"""
gradient = models.getModelJac(model)
return lambda x: gradient(x, *self.constants)
def loss(self, x, i):
"""Calc loss function"""
loss_real = .5 * np.mean(np.real(self.forward(x) - self.data[i]) ** 2)
loss_imag = .5 * np.mean(np.imag(self.forward(x) - self.data[i]) ** 2)
return loss_real + loss_imag
def loss_grad(self, x, i):
"""Calc gradient of loss function"""
g = self.gradient(x)
e = self.forward(x) - self.data[i]
grad_real = np.mean(np.real(g) * np.real(e[:, None]), axis=0)
grad_imag = np.mean(np.imag(g) * np.imag(e[:, None]), axis=0)
return grad_real + grad_imag
def dyn_loss(self, x):
"""Add loss functions across data list"""
ret = 0
mapped = self.vm.free_to_mapped(x)
for time_index in range(len(self.vm.time_variable)):
p = np.hstack(mapped[time_index, :])
ret += self.loss(p, time_index)
return ret
def dyn_loss_grad(self, x):
"""Add gradients across data list"""
g = []
mapped = self.vm.free_to_mapped(x)
LUT = self.vm.free_to_mapped(np.arange(self.vm.nfree), copy_only=True)
dfdx = 0
for time_index, time_var in enumerate(self.vm.time_variable):
# dfdmapped
p = np.hstack(mapped[time_index, :])
dfdp = self.loss_grad(p, time_index)
# dmappeddfree
dpdx = []
for param_index, param in enumerate(self.vm.mapped_names):
grad_fcn = self.vm.get_gradient_fcn(param)
nparams = self.vm.mapped_sizes[param_index]
xindex = LUT[time_index, param_index]
for ip in range(nparams):
gg = np.zeros(self.vm.nfree)
gg[xindex[ip]] = grad_fcn(x[xindex[ip]], time_var)
dpdx.append(gg)
dpdx = np.asarray(dpdx)
dfdx += np.matmul(dfdp, dpdx)
return dfdx
def dyn_loglik(self, x):
"""neg log likelihood for MCMC"""
ll = 0.0
mapped = self.vm.free_to_mapped(x)
n_over_2 = len(self.data[0]) / 2
for time_index in range(len(self.vm.time_variable)):
p = np.hstack(mapped[time_index, :])
pred = self.forward(p)
ll += np.log(np.linalg.norm(pred - self.data[time_index])) * n_over_2
return ll
def dyn_logpr(self, p):
"""neg log prior for MCMC"""
return np.sum(dist.gauss_logpdf(p, loc=self.prior_means, scale=self.prior_stds))
# collect results
def collect_results(self, x, model, method, ppmlim, baseline_order):
"""Create list of FitRes object"""
_, _, _, base_poly, metab_groups, _, _, _ = self.constants
mapped = self.vm.free_to_mapped(x)
dynresList = []
for t in range(self.vm.ntimes):
mrs = self.mrs_list[t]
results = FitRes(model,
method,
mrs.names,
metab_groups,
baseline_order,
base_poly,
ppmlim)
results.loadResults(mrs, np.hstack(mapped[t]))
dynresList.append(results)
return dynresList
......@@ -20,7 +20,7 @@ class VariableMapping(object):
"""
Variable Mapping Class Constructor
Mapping betwee free and mapped:
Mapping between free and mapped:
Mapped = TxN matrix
Mapped[i,j] = float or 1D-array of floats with size param_sizes[j]
......@@ -79,6 +79,8 @@ class VariableMapping(object):
def __repr__(self) -> str:
return str(self)
def calc_nfree(self):
"""
Calculate number of free parameters based on mapped behaviour
......@@ -183,7 +185,7 @@ class VariableMapping(object):
name = [f'{param}_{x}' for x in range(self.mapped_sizes[index])]
names.extend(name)
elif (beh == 'variable'):
name = [f'{param}_{x}_t{t}' for x in range(self.mapped_sizes[index]) for t in range(self.ntimes)]
name = [f'{param}_{x}_t{t}' for t in range(self.ntimes) for x in range(self.mapped_sizes[index])]
names.extend(name)
else:
if 'dynamic' in beh:
......@@ -193,7 +195,7 @@ class VariableMapping(object):
return names
def free_to_mapped(self, p):
def free_to_mapped(self, p, copy_only=False):
"""
Convert free into mapped params over time
fixed params get copied over time domain
......@@ -203,7 +205,7 @@ class VariableMapping(object):
Parameters
----------
p : 1D array
copy_only : bool (copy params - don't use dynamic models)
Returns
-------
2D array (time X params)
......@@ -228,7 +230,7 @@ class VariableMapping(object):
elif (self.Parameters[name] == 'variable'): # copy one param for each time point
for t in range(self.ntimes):
mapped_params[t, index] = p[counter + t * nmapped:counter + t * nmapped + nmapped]
mapped_params[t, index] = p[counter :counter + nmapped]
counter += nmapped
else:
......@@ -237,15 +239,23 @@ class VariableMapping(object):
func_name = self.Parameters[name]['dynamic']
nfree = len(self.Parameters[name]['params'])
mapped = np.zeros((self.ntimes, nmapped))
for i in range(nmapped):
params = p[counter:counter + nfree]
mapped[:, i] = self.fcns[func_name](params, self.time_variable)
counter += nfree
for t in range(self.ntimes):
mapped_params[t, index] = mapped[t, :]
if not copy_only:
mapped = np.zeros((self.ntimes, nmapped))
for i in range(nmapped):
params = p[counter:counter + nfree]
mapped[:, i] = self.fcns[func_name](params, self.time_variable)
counter += nfree
for t in range(self.ntimes):
mapped_params[t, index] = mapped[t, :]
else:
mapped = np.empty((self.ntimes, nmapped),dtype=object)
for i in range(nmapped):
params = p[counter:counter + nfree]
for t in range(self.ntimes):
mapped[t, i] = params
counter += nfree
for t in range(self.ntimes):
mapped_params[t, index] = mapped[t, :]
else:
raise(Exception("Unknown Parameter type - should be one of 'fixed', 'variable', {'dynamic'}"))
......@@ -329,3 +339,21 @@ class VariableMapping(object):
raise(Exception("Unknown Parameter type - should be one of 'fixed', 'variable', {'dynamic'}"))
return free_params
def get_gradient_fcn(self,param_name):
"""
Get the gradient function for a given parameter
Returns:
function
"""
if (self.Parameters[param_name] == 'fixed') or (self.Parameters[param_name] == 'variable'):
return lambda x, t: 1
else:
if 'dynamic' in self.Parameters[param_name]:
func_name = self.Parameters[param_name]['dynamic']
grad_name = func_name + '_grad'
if grad_name not in self.fcns:
raise (Exception(f"Could not find gradient for parameter {param_name}"))
return self.fcns[grad_name]
else:
raise (Exception("Unknown Parameter type - should be one of 'fixed', 'variable', {'dynamic'}"))
......@@ -151,6 +151,65 @@ def FSLModel_err(x,nu,t,m,B,G,g,data,first,last):
return sse
def FSLModel_forward_jac(x, nu, t, m, B, G, g, data, first, last):
"""
x = [con[0],...,con[n-1],gamma,eps,phi0,phi1,baselineparams]
nu : array-like - frequency axis
t : array-like - time axis
m : basis time course
B : baseline functions
G : metabolite groups
g : number of metab groups
data : array like - frequency domain data
first,last : range for the fitting is data[first:last]
returns jacobian matrix
"""
n = m.shape[1] # get number of basis functions
# g = max(G)+1 # get number of metabolite groups
con, gamma, eps, phi0, phi1, b = FSLModel_x2param(x, n, g)
# Start
E = np.zeros((m.shape[0], g), dtype=np.complex)
for gg in range(g):
E[:, gg] = np.exp(-(1j * eps[gg] + gamma[gg]) * t).flatten()
e_term = np.zeros(m.shape, dtype=np.complex)
c = np.zeros((con.size, g))
for i, gg in enumerate(G):
e_term[:, i] = E[:, gg]
c[i, gg] = con[i]
m_term = m * e_term
phi_term = np.exp(-1j * (phi0 + phi1 * nu))
Fmet = FIDToSpec(m_term)
Ftmet = FIDToSpec(t * m_term)
Ftmetc = Ftmet @ c
Fmetcon = Fmet @ con[:, None]
# Gradients
dSdc = phi_term * Fmet
dSdgamma = phi_term * (-Ftmetc)
dSdeps = phi_term * (-1j * Ftmetc)
dSdphi0 = -1j * phi_term * (Fmetcon)
dSdphi1 = -1j * nu * phi_term * (Fmetcon)
dSdb = B
# Only compute within a range
dSdc = dSdc[first:last, :]
dSdgamma = dSdgamma[first:last, :]
dSdeps = dSdeps[first:last, :]
dSdphi0 = dSdphi0[first:last]
dSdphi1 = dSdphi1[first:last]
dSdb = dSdb[first:last]
jac = np.concatenate((dSdc, dSdgamma, dSdeps, dSdphi0, dSdphi1, dSdb), axis=1)
return jac
def FSLModel_forward_and_jac(x,nu,t,m,B,G,g,data,first,last):
"""
......@@ -291,7 +350,7 @@ def FSLModel_forward_Voigt(x,nu,t,m,B,G,g):
Returns forward prediction in the frequency domain
"""
n = m.shape[1] # get number of basis functions
con,gamma,sigma,eps,phi0,phi1,b = FSLModel_x2param_Voigt(x,n,g)
......@@ -411,6 +470,74 @@ def FSLModel_grad_Voigt(x,nu,t,m,B,G,g,data,first,last):
return grad
def FSLModel_jac_Voigt(x, nu, t, m, B, G, g, first, last):
"""
x = [con[0],...,con[n-1],gamma,eps,phi0,phi1,baselineparams]
nu : array-like - frequency axis
t : array-like - time axis
m : basis time course
B : baseline functions
G : metabolite groups
g : number of metab groups
data : array like - frequency domain data
first,last : range for the fitting is data[first:last]
returns gradient vector
"""
n = m.shape[1] # get number of basis functions
con, gamma, sigma, eps, phi0, phi1, b = FSLModel_x2param_Voigt(x, n, g)
# Start
E = np.zeros((m.shape[0], g), dtype=np.complex)
SIG = np.zeros((m.shape[0], g), dtype=np.complex)
for gg in range(g):
E[:, gg] = np.exp(-(1j * eps[gg] + gamma[gg] + t * sigma[gg] ** 2) * t).flatten()
SIG[:, gg] = sigma[gg]
e_term = np.zeros(m.shape, dtype=np.complex)
sig_term = np.zeros(m.shape, dtype=np.complex)
c = np.zeros((con.size, g))
for i, gg in enumerate(G):
e_term[:, i] = E[:, gg]
sig_term[:, i] = SIG[:, gg]
c[i, gg] = con[i]
m_term = m * e_term
phi_term = np.exp(-1j * (phi0 + phi1 * nu))
Fmet = FIDToSpec(m_term)
Ftmet = FIDToSpec(t * m_term)
Ft2sigmet = FIDToSpec(t * t * sig_term * m_term)
Ftmetc = Ftmet @ c
Ft2sigmetc = Ft2sigmet @ c
Fmetcon = Fmet @ con[:, None]
# Gradients
dSdc = phi_term * Fmet
dSdgamma = phi_term * (-Ftmetc)
dSdsigma = phi_term * (-2 * Ft2sigmetc)
dSdeps = phi_term * (-1j * Ftmetc)
dSdphi0 = -1j * phi_term * (Fmetcon)
dSdphi1 = -1j * nu * phi_term * (Fmetcon)
dSdb = B
# Only compute within a range
dSdc = dSdc[first:last, :]
dSdgamma = dSdgamma[first:last, :]
dSdsigma = dSdsigma[first:last, :]
dSdeps = dSdeps[first:last, :]
dSdphi0 = dSdphi0[first:last]
dSdphi1 = dSdphi1[first:last]
dSdb = dSdb[first:last]
dS = np.concatenate((dSdc, dSdgamma, dSdsigma, dSdeps, dSdphi0, dSdphi1, dSdb), axis=1)
return dS
def getModelFunctions(model):
""" Return the err, grad, forward and conversion functions appropriate for the model."""
if model == 'lorentzian':
......@@ -429,6 +556,26 @@ def getModelFunctions(model):
raise Exception('Unknown model {}.'.format(model))
return err_func,grad_func,forward,x2p,p2x
def getModelForward(model):
if model == 'lorentzian':
forward = FSLModel_forward # forward model
elif model == 'voigt':
forward = FSLModel_forward_Voigt # forward model
else:
raise Exception('Unknown model {}.'.format(model))
return forward
def getModelJac(model):
if model == 'lorentzian':
jac = FSLModel_jac
elif model == 'voigt':
jac = FSLModel_jac_Voigt
else:
raise Exception('Unknown model {}.'.format(model))
return jac
def getFittedModel(model,resParams,base_poly,metab_groups,mrs,basisSelect=None,baselineOnly = False,noBaseline = False):
""" Return the predicted model given some fitting parameters
......
......@@ -1042,6 +1042,77 @@ def plot_table_qc(res):
return fig
# ----------- Dyn MRS
# Visualisation
def plotly_dynMRS(mrs_list,time_var,ppmlim=(.2,4.2)):
"""
Plot dynamic MRS data with a slider though time
Args:
mrs_list: list of MRS objects
time_var: list of time variable (or bvals for dMRS)
ppmlim: list
Returns:
plotly Figure
"""
# Create figure
fig = go.Figure()
# Add traces, one for each slider step
for i, t in enumerate(time_var):
x = mrs_list[i].getAxes()
y = np.real(FIDToSpec(mrs_list[i].FID))
fig.add_trace(
go.Scatter(
visible=False,
line=dict(color="black", width=3),
name=f"{t}",
x=x,
y=y))
fig.update_layout(template='plotly_white')
fig.update_xaxes(title_text='Chemical shift (ppm)',
tick0=2, dtick=.5,
range=[ppmlim[1],ppmlim[0]])
# y-axis range
data = [np.real(FIDToSpec(mrs.FID)) for mrs in mrs_list]
data = np.asarray(data).flatten()
minval = np.min(data)
maxval = np.max(data)
ymin = minval - minval / 2
ymax = maxval + maxval / 30
fig.update_yaxes(zeroline=True,
zerolinewidth=1,
zerolinecolor='Gray',
showgrid=False, showticklabels=False,
range=[ymin, ymax])
# Make 0th trace visible
fig.data[0].visible = True
# Create and add slider
steps = []
for i in range(len(time_var)):
step = dict(
method="restyle",
label=f"t={time_var[i<