Commit e4266bb9 authored by William Clarke's avatar William Clarke
Browse files

Updates from working through NV's data and some dynamic fitting fixes.

parent 78712745
......@@ -13,6 +13,7 @@ from fsl_mrs.auxiliary import configargparse
from fsl_mrs import __version__
from fsl_mrs.utils.splash import splash
from os import makedirs
from shutil import rmtree
import os.path as op
from fsl_mrs.utils.preproc import nifti_mrs_proc as preproc
from fsl_mrs.core import NIFTI_MRS, is_nifti_mrs
......@@ -316,6 +317,9 @@ def main():
# Create output folder if required
if not op.isdir(args.output):
makedirs(args.output)
elif op.isdir(args.output) and args.overwrite:
rmtree(args.output)
makedirs(args.output)
# Handle report generation output location.
# Bit of a hack, but I messed up the type expected by the
......
......@@ -15,6 +15,7 @@ from fsl_mrs.utils import models, fitting
from . import variable_mapping as varmap
from fsl_mrs.utils.results import FitRes
from fsl_mrs.utils.stats import mh, dist
from fsl_mrs.utils.misc import calculate_lap_cov
class dynMRS(object):
......@@ -36,6 +37,7 @@ class dynMRS(object):
def fit(self,
config_file,
method='Newton',
mh_jumps=600,
model='voigt',
ppmlim=(.2, 4.2),
baseline_order=2,
......@@ -50,6 +52,7 @@ class dynMRS(object):
config_file : string
method : string ('Newton' or 'MH')
model : string ('voigt' or 'lorentzian')
mh_jumps : int
ppmlim : tuple
baseline_order : int
metab_groups : array-like
......@@ -82,21 +85,29 @@ class dynMRS(object):
# MCMC or Newton
if method.lower() == 'newton':
sol = minimize(fun=self.dyn_loss, x0=x0, jac=self.dyn_loss_grad, method='TNC', bounds=bounds)
# breakpoint()
# calculate covariance
data = np.asarray(self.data).flatten()
x_cov = calculate_lap_cov(sol.x, self.full_fwd, data)
x = sol.x
x_out = x
x_all = x
elif method.lower() == 'mh':
self.prior_means = np.zeros_like(self.vm.nfree)
self.prior_stds = np.ones_like(self.vm.nfree) * 1E3
mcmc = mh.MH(self.dyn_loglik, self.dyn_logpr, burnin=10, njumps=300, sampleevery=1)
mcmc = mh.MH(self.dyn_loglik, self.dyn_logpr, burnin=100, njumps=mh_jumps, sampleevery=5)
LB, UB = mcmc.bounds_from_list(self.vm.nfree, self.vm.Bounds)
samples = mcmc.fit(x0, LB=LB, UB=UB, verbose=verbose)
x = np.mean(samples, axis=0)
x = mcmc.fit(x0, LB=LB, UB=UB, verbose=verbose)
x_out = np.mean(x, axis=0)
x_all = x
x_cov = np.cov(x.T)
else:
raise (Exception(f'Unrecognised method {method}'))
res_list = self.collect_results(x, model, method, ppmlim, baseline_order)
if verbose:
print(f"Fitting completed in {time.time()-start_time} seconds.")
return {'x': x, 'resList': res_list}
return {'x': x_out, 'cov': x_cov, 'samples': x_all, 'resList': res_list}
def get_constants(self, model, ppmlim, baseline_order, metab_groups):
"""collect constants for forward model"""
......@@ -178,6 +189,15 @@ class dynMRS(object):
grad_imag = np.mean(np.imag(g) * np.imag(e[:, None]), axis=0)
return grad_real + grad_imag
def full_fwd(self, x):
    """Evaluate the dynamic forward model at every time point.

    Parameters
    ----------
    x : array-like
        Vector of free (dynamic) parameters.

    Returns
    -------
    np.ndarray
        Flattened 1D complex64 vector of length ``ntimes * n_points``,
        i.e. the per-timepoint model predictions concatenated in time order.
    """
    # Map free parameters to the per-timepoint mapped parameter sets.
    mapped = self.vm.free_to_mapped(x)
    n_points = self.data[0].shape[0]
    predictions = np.empty((self.vm.ntimes, n_points), dtype=np.complex64)
    for t_idx in range(self.vm.ntimes):
        # Each row of `mapped` holds grouped parameters; hstack gives the
        # flat parameter vector expected by the single-spectrum forward model.
        p_flat = np.hstack(mapped[t_idx, :])
        predictions[t_idx, :] = self.forward(p_flat)
    return predictions.flatten()
def dyn_loss(self, x):
"""Add loss functions across data list"""
ret = 0
......@@ -231,7 +251,22 @@ class dynMRS(object):
def collect_results(self, x, model, method, ppmlim, baseline_order):
"""Create list of FitRes object"""
_, _, _, base_poly, metab_groups, _, _, _ = self.constants
mapped = self.vm.free_to_mapped(x)
if method.lower() == 'mh':
mapped = []
for xx in x:
tmp = self.vm.free_to_mapped(xx)
tmp_tmp = []
for tt in tmp:
tmp_tmp.append(np.hstack(tt))
mapped.append(np.asarray(tmp_tmp))
mapped = np.asarray(mapped)
mapped = np.moveaxis(mapped, 0, 1)
else:
tmp = self.vm.free_to_mapped(x)
tmp_tmp = []
for tt in tmp:
tmp_tmp.append(np.hstack(tt))
mapped = np.asarray(tmp_tmp)
dynresList = []
for t in range(self.vm.ntimes):
mrs = self.mrs_list[t]
......@@ -242,6 +277,6 @@ class dynMRS(object):
baseline_order,
base_poly,
ppmlim)
results.loadResults(mrs, np.hstack(mapped[t]))
results.loadResults(mrs, mapped[t])
dynresList.append(results)
return dynresList
......@@ -319,6 +319,7 @@ class VariableMapping(object):
func_name = self.Parameters[name]['dynamic']
time_var = self.time_variable
func = partial(self.fcns[func_name], t=time_var)
gradfunc = partial(self.get_gradient_fcn(name), t=time_var)
nfree = len(self.Parameters[name]['params'])
pp = np.stack(p[:, index][:], axis=0)
......@@ -326,10 +327,22 @@ class VariableMapping(object):
def loss(x):
    # Mean-squared error between the target mapped-parameter series
    # (`pp[:, ppp]`, one value per time point) and the dynamic model's
    # prediction for candidate free parameters `x`.
    # NOTE: `func`, `pp` and `ppp` are closed over from the enclosing
    # mapped-to-free routine — this objective is minimised per parameter.
    pred = func(x)
    return np.mean((pp[:, ppp] - pred)**2)
def loss_grad(x):
    # Analytic jacobian of `loss` w.r.t. each free parameter, built from
    # the per-parameter derivative series yielded by `gradfunc`.
    # NOTE(review): `loss` averages over time points (np.mean) while this
    # gradient accumulates with np.sum — the returned jacobian therefore
    # differs from d(loss)/dx by a factor of the number of time points.
    # The minimiser location is unchanged, but the scale mismatch may slow
    # TNC convergence — confirm whether np.mean was intended here.
    jac_out = []
    S = func(x)
    for ds in gradfunc(x):
        jac_out.append(np.sum((2 * S * ds)
                              - (2 * pp[:, ppp] * ds)))
    return np.asarray(jac_out)
bounds = self.Bounds[counter:counter + nfree]
vals = minimize(loss,
np.zeros(len(self.Parameters[name]['params'])),
method='TNC', bounds=bounds).x
jac=loss_grad,
method='TNC',
bounds=bounds).x
free_params[counter:counter + nfree] = vals
counter += nfree
......
......@@ -24,17 +24,16 @@ def simulated(ID=1):
basisfolder = fileDir / '../pkg_data/mrs_fitting_challenge/basisset_JMRUI'
# Load data and basis
FID, FIDheader = mrs_io.read_FID(str(datafolder / f'dataset{ID}_WS.txt'))
FIDW, _ = mrs_io.read_FID(str(datafolder / f'dataset{ID}_nWS.txt'))
FID = mrs_io.read_FID(str(datafolder / f'dataset{ID}_WS.txt'))
FIDW = mrs_io.read_FID(str(datafolder / f'dataset{ID}_nWS.txt'))
basis, names, Bheader = mrs_io.read_basis(basisfolder)
MRSArgs = {'header': FIDheader,
'basis': basis,
MRSArgs = {'basis': basis,
'names': names,
'basis_hdr': Bheader[0],
'H2O': FIDW}
mrs = MRS(FID=FID, **MRSArgs)
mrs = FID.mrs(**MRSArgs)
# Check orientation and rescale for extra robustness
mrs.processForFitting()
......@@ -134,12 +133,12 @@ def dMRS_SNR(avg=1, path='/Users/saad/Desktop/Spectroscopy/'):
bvals = [20, 3020, 6000, 10000, 20000, 30000, 50000]
MRSlist = []
for b in bvals:
FID, FIDheader = mrs_io.read_FID(str(FIDpath / f'b_{b:05}.nii.gz'))
MRSArgs = {'header': FIDheader,
'basis': basis,
FID = mrs_io.read_FID(str(FIDpath / f'b_{b:05}.nii.gz'))
MRSArgs = {'basis': basis,
'names': names,
'basis_hdr': Bheader[0]}
mrs = MRS(FID=FID, **MRSArgs)
mrs = FID.mrs(**MRSArgs)
MRSlist.append(mrs)
MRSlist[0].rescaleForFitting()
......
......@@ -12,7 +12,7 @@ import numpy as np
from fsl_mrs.utils.mrs_io import fsl_io as fsl
from fsl_mrs.utils.mrs_io import lcm_io as lcm
from fsl_mrs.utils.mrs_io import jmrui_io as jmrui
from fsl_mrs.core.nifti_mrs import NIFTI_MRS, NotNIFTI_MRS
from fsl_mrs.core import nifti_mrs # import NIFTI_MRS, NotNIFTI_MRS
import fsl.utils.path as fslpath
......@@ -34,8 +34,8 @@ def check_datatype(filename):
Returns one of: 'NIFTI_MRS', 'NIFTI','RAW','TXT','Unknown'
"""
try:
NIFTI_MRS(filename)
except (NotNIFTI_MRS, fslpath.PathError):
nifti_mrs.NIFTI_MRS(filename)
except (nifti_mrs.NotNIFTI_MRS, fslpath.PathError):
_, ext = filename.split(os.extsep, 1)
if ext.lower() == 'nii' or ext.lower() == 'nii.gz':
fsl.readNIFTI(filename)
......@@ -66,8 +66,8 @@ def read_FID(filename):
dict (header info)
"""
try:
return NIFTI_MRS(filename)
except (NotNIFTI_MRS, fslpath.PathError):
return nifti_mrs.NIFTI_MRS(filename)
except (nifti_mrs.NotNIFTI_MRS, fslpath.PathError):
data_type = check_datatype(filename)
if data_type == 'RAW':
......
......@@ -116,6 +116,10 @@ def phaseCorrect_report(inFID,
# re-estimate here.
position = np.argmax(np.abs(plotIn.get_spec(ppmlim=ppmlim)))
# Deal with rounding errors
if position >= len(plotIn.getAxes(ppmlim=ppmlim)):
position = len(plotIn.getAxes(ppmlim=ppmlim)) - 1
axis = [plotIn.getAxes(ppmlim=ppmlim)[position]]
y_data = [np.real(plotIn.get_spec(ppmlim=ppmlim))[position]]
trace = go.Scatter(x=axis, y=y_data,
......
......@@ -95,10 +95,13 @@ def synthetic_spectra_from_model(config_file,
'baseline': [0, 0] * (baseline_order + 1),
'conc': concentrations}
def_vals_int = {}
for key in defined_vals:
if isinstance(defined_vals[key], str) \
and defined_vals[key] in std_vals:
defined_vals[key] = std_vals[defined_vals[key]]
def_vals_int[key] = std_vals[defined_vals[key]]
else:
def_vals_int[key] = defined_vals[key]
rng = np.random.default_rng()
......@@ -106,14 +109,14 @@ def synthetic_spectra_from_model(config_file,
for index, param in enumerate(vm.mapped_names):
beh = vm.Parameters[param]
if beh == 'fixed':
if param in defined_vals:
if hasattr(defined_vals[param], "__len__") \
and len(defined_vals[param]) == vm.mapped_sizes[index]:
syn_free_params.extend(defined_vals[param])
elif hasattr(defined_vals[param], "__len__"):
if param in def_vals_int:
if hasattr(def_vals_int[param], "__len__") \
and len(def_vals_int[param]) == vm.mapped_sizes[index]:
syn_free_params.extend(def_vals_int[param])
elif hasattr(def_vals_int[param], "__len__"):
raise ValueError('Must be the same length as sizes.')
else:
syn_free_params.extend([defined_vals[param], ] * vm.mapped_sizes[index])
syn_free_params.extend([def_vals_int[param], ] * vm.mapped_sizes[index])
elif param in std_vals:
if hasattr(std_vals[param], "__len__") \
and len(std_vals[param]) == vm.mapped_sizes[index]:
......@@ -135,23 +138,54 @@ def synthetic_spectra_from_model(config_file,
current_bounds = [-1, 1]
syn_free_params.extend(rng.uniform(current_bounds[0], current_bounds[1], size=vm.mapped_sizes[index]))
elif beh == 'variable':
pass
if param in def_vals_int:
if hasattr(def_vals_int[param], "__len__") \
and len(def_vals_int[param]) == (vm.mapped_sizes[index] * vm.ntimes):
syn_free_params.extend(def_vals_int[param])
elif hasattr(def_vals_int[param], "__len__"):
raise ValueError('Must be the same length as sizes.')
else:
syn_free_params.extend([def_vals_int[param], ] * (vm.mapped_sizes[index] * vm.ntimes))
elif param in std_vals:
if hasattr(std_vals[param], "__len__") \
and len(std_vals[param]) == (vm.mapped_sizes[index] * vm.ntimes):
syn_free_params.extend(std_vals[param])
elif hasattr(std_vals[param], "__len__") \
and not len(std_vals[param]) % vm.mapped_sizes[index]:
syn_free_params.extend(std_vals[param] * vm.ntimes)
elif hasattr(std_vals[param], "__len__"):
raise ValueError('Must be the same length as sizes.')
else:
syn_free_params.extend([std_vals[param], ] * (vm.mapped_sizes[index] * vm.ntimes))
else:
if vm.defined_bounds is not None \
and param in vm.defined_bounds:
current_bounds = list(vm.defined_bounds[param])
if current_bounds[0] is None:
current_bounds[0] = -1
if current_bounds[1] is None:
current_bounds[1] = 1
else:
current_bounds = [-1, 1]
syn_free_params.extend(rng.uniform(current_bounds[0],
current_bounds[1],
size=(vm.mapped_sizes[index] * vm.ntimes)))
elif 'dynamic' in beh:
dyn_name = vm.Parameters[param]['params']
for x in range(vm.mapped_sizes[index]):
for y in dyn_name:
if y in defined_vals:
if hasattr(defined_vals[y], "__len__") \
and len(defined_vals[y]) == vm.mapped_sizes[index]:
syn_free_params.append(defined_vals[y][x])
elif hasattr(defined_vals[y], "__len__"):
if y in def_vals_int:
if hasattr(def_vals_int[y], "__len__") \
and len(def_vals_int[y]) == vm.mapped_sizes[index]:
syn_free_params.append(def_vals_int[y][x])
elif hasattr(def_vals_int[y], "__len__"):
raise ValueError('Must be the same length as sizes.')
else:
syn_free_params.append(defined_vals[y])
syn_free_params.append(def_vals_int[y])
elif y in std_vals:
if hasattr(std_vals[y], "__len__") \
and len(defined_vals[y]) == vm.mapped_sizes[index]:
and len(def_vals_int[y]) == vm.mapped_sizes[index]:
syn_free_params.append(std_vals[y][x])
elif hasattr(std_vals[y], "__len__"):
raise ValueError('Must be the same length as sizes.')
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment