Commit d8d2b204 authored by William Clarke's avatar William Clarke
Browse files

Fixes for new stateless tresults class format.

parent 8ffbcef9
Pipeline #14861 failed with stages
......@@ -483,13 +483,13 @@ class dynMRS(object):
dynresList = []
for t in range(self.vm.ntimes):
mrs = self.mrs_list[t]
results = FitRes(self._fit_args['model'],
results = FitRes(mrs,
mapped[t],
self._fit_args['model'],
method,
mrs.names,
metab_groups,
self._fit_args['baseline_order'],
base_poly,
self._fit_args['ppmlim'])
results.loadResults(mrs, mapped[t])
dynresList.append(results)
return dynresList
......@@ -90,9 +90,16 @@ def main():
B = prepare_baseline_regressor(mrs, baseline_order, ppmlim)
# Generate results object
res = results.FitRes(model, method, mrs.names, metab_groups, baseline_order, B, ppmlim)
res.loadResults(mrs, param_df['mean'].to_numpy())
# res.params = param_df['mean'].to_numpy()
print(metab_groups)
res = results.FitRes(
mrs,
param_df['mean'].to_numpy(),
model,
method,
metab_groups,
baseline_order,
B,
ppmlim)
if orig_args['combine'] is not None:
res.combine(orig_args['combine'])
......
......@@ -317,7 +317,7 @@ def fit_FSLModel(mrs,
results = FitRes(mrs, res.x, model, method, metab_groups, baseline_order, B, ppmlim)
elif method == 'init':
results.loadResults(mrs, x0)
results = FitRes(mrs, x0, model, method, metab_groups, baseline_order, B, ppmlim)
elif method == 'MH':
def forward_mh(p):
......@@ -406,7 +406,7 @@ def fit_FSLModel(mrs,
samples = mcmc.fit(p0, LB=LB, UB=UB, verbose=False, mask=mask)
# collect results
results.loadResults(mrs, samples)
results = FitRes(mrs, samples, model, method, metab_groups, baseline_order, B, ppmlim)
elif method == 'VB':
import warnings
......@@ -470,7 +470,7 @@ def fit_FSLModel(mrs,
x = p2x(np.exp(logcon), np.exp(loggamma), np.exp(logsigma), eps, phi0, phi1, b)
# collect results
results.loadResults(mrs, x)
results = FitRes(mrs, x, model, method, metab_groups, baseline_order, B, ppmlim, vb_optim=res_vb)
else:
raise Exception('Unknown optimisation method.')
......
......@@ -23,7 +23,7 @@ class FitRes(object):
Collects fitting results
"""
def __init__(self, mrs, results, model, method, metab_groups, baseline_order, B, ppmlim, runqc=True):
def __init__(self, mrs, results, model, method, metab_groups, baseline_order, B, ppmlim, runqc=True, vb_optim=None):
"""_summary_
_extended_summary_
......@@ -116,8 +116,8 @@ class FitRes(object):
# VB metrics
if self.method == 'VB':
self.vb_cov = self.optim_out.cov
self.vb_var = self.optim_out.var
self.vb_cov = vb_optim.cov
self.vb_var = vb_optim.var
std = np.sqrt(self.vb_var)
self.vb_corr = self.vb_cov / (std[:, np.newaxis] * std[np.newaxis, :])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment