Commit 21eccc87 authored by William Clarke's avatar William Clarke
Browse files

Initial save and load for dynamic results implementation.

parent 82d09bed
......@@ -7,6 +7,7 @@ This document contains the FSL-MRS release history in reverse chronological orde
- Modified API of syntheticFromBasis function.
- Dynamic fitting now handles multiple different basis sets.
- Fix mapped parameter uncertainties in dynamic MRS results.
- Dynamic fitting results can now be saved to and loaded from a directory.
1.1.8 (Tuesday 5th October 2021)
-------------------------------
......
......@@ -52,6 +52,8 @@ def test_dynRes(fixed_ratio_mrs):
rescale=False)
res = dyn_obj.fit()
resinit = dyn_obj.initialise()
res_obj = res['result']
import plotly
......@@ -66,6 +68,9 @@ def test_dynRes(fixed_ratio_mrs):
assert isinstance(res_obj.x, np.ndarray)
assert res_obj.x.shape[0] == res_obj.data_frame.shape[1]
assert isinstance(res_obj._init_x, pd.DataFrame)
assert np.allclose(dyn_obj.vm.mapped_to_free(resinit['x']), res_obj.free_parameters_init)
assert isinstance(res_obj.mapped_parameters, np.ndarray)
assert res_obj.mapped_parameters.shape == (1, len(mrs_list), len(dyn_obj.mapped_names))
......@@ -138,3 +143,31 @@ def test_dynRes_mcmc(fixed_ratio_mrs):
assert res_obj.std.shape == (10,)
assert np.allclose(res_obj.std, np.sqrt(np.diagonal(res_obj.cov)))
def test_load_save(fixed_ratio_mrs, tmp_path):
mrs_list = fixed_ratio_mrs
dyn_obj = dyn.dynMRS(
mrs_list,
[0, 1],
'fsl_mrs/tests/testdata/dynamic/simple_linear_model.py',
model='lorentzian',
baseline_order=0,
metab_groups=[0, 0],
rescale=False)
res = dyn_obj.fit()['result']
res.save(tmp_path / 'res_save_test')
res_loaded = dyn.load_dyn_result(tmp_path / 'res_save_test', dyn_obj)
from pandas._testing import assert_frame_equal
assert_frame_equal(res._data, res_loaded._data)
assert_frame_equal(res._init_x, res_loaded._init_x)
res.save(tmp_path / 'res_save_test2', pickle_dyn=True)
res_loaded2 = dyn.load_dyn_result(tmp_path / 'res_save_test2')
assert_frame_equal(res._data, res_loaded2._data)
assert_frame_equal(res._init_x, res_loaded2._init_x)
from fsl_mrs.utils.dynamic.variable_mapping import VariableMapping
from fsl_mrs.utils.dynamic.dynmrs import dynMRS
from fsl_mrs.utils.dynamic.dyn_results import load_dyn_result
......@@ -6,14 +6,56 @@
# Copyright (C) 2021 University of Oxford
import copy
import warnings
import dill as pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from fsl_mrs.utils.misc import calculate_lap_cov, gradient
class ResultLoadError(Exception):
pass
# Loading function
def load_dyn_result(load_dir, dyn_obj=None):
"""Load a saved dynamic fitting result from directory.
The directory should cointain two csv files (dyn_results and init_results).
And either the user must pass the asociated dynMRS object as the dyn_obj
argument, or the directory must also contain a dyn.pkl file.s
:param load_dir: Directory to load. Creaed using dynMRS.save method
:type load_dir: str or pathlib.Path
:param dyn_obj: Associated dynMRS object or if None will attempt to load a dyn.pkl file, defaults to None
:type dyn_obj: fsl_mrs.utils.dynamic.dynMRS, optional
:return: Dynamic results object
:rtype: dynRes_newton or dynRes_mcmc
"""
if not isinstance(load_dir, Path):
load_dir = Path(load_dir)
sample_df = pd.read_csv(load_dir / 'dyn_results.csv', index_col='samples')
init_df = pd.read_csv(load_dir / 'init_results.csv', index_col=0)
if sample_df.shape[0] == 1:
cls = dynRes_newton
else:
cls = dynRes_mcmc
if dyn_obj:
return cls(sample_df, dyn_obj, init_df)
elif (load_dir / 'dyn.pkl').is_file():
with open(load_dir / 'dyn.pkl', 'rb') as pickle_file:
dyn_obj = pickle.load(pickle_file)
return cls(sample_df, dyn_obj, init_df)
else:
raise ResultLoadError('Dynamic object required. Pass directly or ensure dyn.pkl is availible')
# Plotting functions:
def subplot_shape(plots_needed):
"""Calculates the number and shape of subplots needed for
......@@ -47,18 +89,50 @@ class dynRes:
:param samples: Array of free parameters returned by optimiser, can be 2D in mcmc case.
:type samples: numpy.ndarray
:type init: pd.DataFrame
:param dyn: Copy of dynMRS class object.
:type dyn: fsl_mrs.utils.dynamic.dynMRS
:param init: Results of the initilisation optimisation, containing 'resList' and 'x'.
:type init: dict
:type init: pd.DataFrame
"""
self._data = pd.DataFrame(data=samples, columns=dyn.free_names)
self._data.index.name = 'samples'
if isinstance(samples, pd.DataFrame):
self._data = samples
else:
self._data = pd.DataFrame(data=samples, columns=dyn.free_names)
self._data.index.name = 'samples'
self._dyn = copy.deepcopy(dyn)
# Store the init mapped representation
self._init_x = init['x']
# Store the init as dataframe
if isinstance(init, pd.DataFrame):
self._init_x = init
else:
self._init_x = pd.DataFrame(self.flatten_mapped(init['x']), columns=self._dyn.mapped_names)
def save(self, save_dir, pickle_dyn=False):
"""Save the results to a directory
Saves the two dataframes to csv format. If pickle_dyn=True then the ._dyn object
is also saved.
:param save_dir: Location to save to, created if neccesary.
:type save_dir: str or pathlib.Path
:param pickle_dyn: Save _dyn dynMRS object to pickle file, defaults to False
:type pickle_dyn: bool, optional
"""
if not isinstance(save_dir, Path):
save_dir = Path(save_dir)
save_dir.mkdir(exist_ok=True, parents=True)
# Save the two dataframes
self._data.to_csv(save_dir / 'dyn_results.csv')
self._init_x.to_csv(save_dir / 'init_results.csv')
# If selected save the dynamic object a pickle file
if pickle_dyn:
with open(save_dir / 'dyn.pkl', 'wb') as fp:
pickle.dump(self._dyn, fp)
@property
def data_frame(self):
......@@ -85,6 +159,23 @@ class dynRes:
flattened.append(np.hstack(mp))
return np.asarray(flattened)
@staticmethod
def nest_mapped(mapped, vm):
"""Nest a flattened array of mapped parameters.
:param mapped: Flattened array representation
:type mapped: np.array
:param vm: VariableMapping object
:type vm: fsl_mrs.utisl.dynamic.VariableMapping
:return: Nested array
:rtype: np.array
"""
nested = []
for mp in mapped:
tmp = np.split(mp, np.cumsum(vm.mapped_sizes))[:-1]
nested.append(tmp)
return np.asarray(nested, dtype=object)
@property
def mapped_parameters(self):
"""Flattened mapped parameters. Shape is samples x timepoints x parameters.
......@@ -106,7 +197,7 @@ class dynRes:
:return: Flattened mapped parameters from initilisation
:rtype: np.array
"""
return self.flatten_mapped(self._init_x)
return self._init_x.to_numpy()
@property
def free_parameters_init(self):
......@@ -115,7 +206,8 @@ class dynRes:
:return: Free parameters estimated from initilisation
:rtype: np.array
"""
return self._dyn.vm.mapped_to_free(self._init_x)
nested_init = self.nest_mapped(self.mapped_parameters_init, self._dyn.vm)
return self._dyn.vm.mapped_to_free(nested_init)
@property
def init_dataframe(self):
......@@ -411,13 +503,16 @@ class dynRes_newton(dynRes):
:param init: Results of the initilisation optimisation, containing 'resList' and 'x'.
:type init: dict
"""
super().__init__(samples[np.newaxis, :], dyn, init)
if isinstance(samples, pd.DataFrame):
super().__init__(samples, dyn, init)
else:
super().__init__(samples[np.newaxis, :], dyn, init)
# Calculate covariance, correlation and uncertainty
data = np.asarray(dyn.data).flatten()
# Dynamic (free) parameters
self._cov_dyn = calculate_lap_cov(samples, dyn.full_fwd, data)
self._cov_dyn = calculate_lap_cov(self.x, dyn.full_fwd, data)
crlb_dyn = np.diagonal(self._cov_dyn)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', r'invalid value encountered in sqrt')
......@@ -425,12 +520,12 @@ class dynRes_newton(dynRes):
self._corr_dyn = self._cov_dyn / (self._std_dyn[:, np.newaxis] * self._std_dyn[np.newaxis, :])
# Mapped parameters
p = dyn.vm.free_to_mapped(samples)
p = dyn.vm.free_to_mapped(self.x)
self._mapped_params = dyn.vm.mapped_to_dict(p)
# Mapped parameters covariance etc.
grad_all = np.transpose(gradient(samples, dyn.vm.free_to_mapped), (2, 0, 1))
grad_all = np.transpose(gradient(self.x, dyn.vm.free_to_mapped), (2, 0, 1))
N = dyn.vm.ntimes
M = len(samples)
M = len(self.x)
std = {}
for i, name in enumerate(dyn.vm.mapped_names):
s = []
......
......@@ -191,6 +191,7 @@ class dynMRS(object):
return {'result': results, 'resList': res_list, 'optimisation_sol': sol}
def fit_mean_spectrum(self):
"""Return the parameters from the fit of the mean spectra stored in mrs_list."""
from fsl_mrs.utils.preproc.combine import combine_FIDs
from copy import deepcopy
......
......@@ -10,3 +10,4 @@ hlsvdpropy
fslpy>=3.0
pillow
spec2nii>=0.3.0
dill
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment