Skip to content
Snippets Groups Projects
Commit 6e2c3e21 authored by Hossein Rafipoor's avatar Hossein Rafipoor
Browse files

changed name

parent e3f2a186
No related branches found
No related tags found
No related merge requests found
# Thi module contains implementation of the biophysical models for diffusion.
import numpy as np
import scipy.stats as st
from . import utils
class BallStick:
def __init__(self, n_sticks, bvals, bvecs):
self.param_names = ['d', 'f_0'] + [f'{p}_{i + 1}' for i in range(n_sticks) for p in ['f', 'phi', 'theta']]
self.bvals = bvals
self.bvecs = bvecs
self.n_sticks = n_sticks
self.angle_idx = np.array([[self.param_names.index(f'{p}_{i}')
for i in range(1, n_sticks + 1) if f'{p}_{i}' in self.param_names]
for p in ['phi', 'theta']]).T
self.priors = {'d': st.truncnorm(loc=1.7, scale=0.3, a=-1.7 / 0.3, b=np.inf),
'f_0': st.uniform(loc=0.0, scale=1)}
self.priors.update({f'{p}_{i + 1}': v for i in range(n_sticks)
for p, v in zip(['f', 'phi', 'theta'], [st.uniform(loc=0.0, scale=1),
st.uniform(loc=0.0, scale=np.pi),
st.uniform(loc=0.0, scale=np.pi)])})
self.priors.update({'sigma2_g': st.invgamma(a=40, ), 'sigma2_n': st.invgamma(a=1e2)})
def compute(self, params):
"""
gets params as a dictionary and compute the signal
"""
s = params['f_0'] * np.exp(-params['d'] * self.bvals)
for i in range(1, self.n_sticks + 1):
s += np.squeeze(params[f'f_{i}'] * np.exp(-params['d'] * self.bvals * utils.p2c(params[f'phi_{i}'],
params[f'theta_{i}'])[np.newaxis,
:].dot(self.bvecs.T) ** 2))
return s
def compute_vec(self, params_vec):
params = {k: v for k, v in zip(self.param_names, params_vec)}
return self.compute(params)
\ No newline at end of file
import numpy as np
import scipy.stats as st
from joblib import Parallel, delayed
from typing import Callable, List, Mapping
from dataclasses import dataclass
from . import utils
@dataclass
class hbm():
forward_model: Callable
param_names: List
param_priors: Mapping
angles_idx: np.ndarray
"""
fits hierarchical bayes model to data
args:
return:
"""
def __post_init__(self):
self.bounds_s = np.array([self.param_priors[p].support() for p in self.param_names] +
[self.param_priors['sigma2_n'].support()])
self.bounds_g = np.array([self.param_priors[p].support() for p in self.param_names] +
[self.param_priors['sigma2_g'].support()] * len(self.param_names))
def likelihood_single_subj(self, model_params, sigma2_n, data):
assert sigma2_n > 0
x = self.forward_model(model_params)
return st.multivariate_normal(mean=x, cov=sigma2_n ** 0.5).logpdf(data).sum()
def prior_single_subj(self, subj_params, group_params, group_sigma2, sigma2_n):
p1 = np.sum([st.norm(loc=subj_params[p], scale=group_sigma2[p] ** 0.5).logpdf(subj_params[p])
for p in subj_params.keys()])
p2 = self.param_priors['sigma2_n'].logpdf(sigma2_n)
return p1 + p2
def posterior_single_subj(self, all_subj_params, all_group_params, data):
"""
args:
all_subj_params: array (n+1,) the first n is model parameters the last is noise sigma
all_group_params: (2n) first row are the means, secound row is the variances.
"""
subj_params = {k: v for k, v in zip(self.param_names, all_subj_params[:-1])}
sigma2_n = all_subj_params[-1]
group_params = {k: v for k, v in zip(self.param_names, all_group_params[:len(self.param_names)])}
group_sigma2 = {k: v for k, v in zip(self.param_names, all_group_params[len(self.param_names):])}
return self.prior_single_subj(subj_params, group_params, group_sigma2, sigma2_n) + \
self.likelihood_single_subj(subj_params, sigma2_n, data)
def likelihood_g(self, group_params, group_sigma2, subj_params):
p1 = np.sum([st.norm(loc=group_params[p], scale=group_sigma2[p] ** 0.5).logpdf(subj_params[p])
for p in self.param_names])
return p1
def prior_g(self, group_params, group_sigma2):
p1 = np.sum([self.param_priors[p].logpdf(group_params[p]) for p in self.param_names])
p2 = np.sum([self.param_priors['sigma2_g'].logpdf(group_sigma2[p]) for p in self.param_names])
return p1 + p2
def posterior_g(self, all_group_params, all_subj_params):
"""
args:
all_group_params: (2n,) n is the number of params, first row is means, second row is variance
all_subj_params: (k, n+1) each row is params for one subj, each column one parameter,
last column is noise variance (not used in this function, passed for keeping consistency)
"""
group_params = {k: v for k, v in zip(self.param_names, all_group_params[:len(self.param_names)])}
group_sigma2 = {k: v for k, v in zip(self.param_names, all_group_params[len(self.param_names):])}
subj_params = {k: v for k, v in zip(self.param_names, all_subj_params[:, :, :-1].T)}
return self.prior_g(group_params, group_sigma2) + \
self.likelihood_g(group_params, group_sigma2, subj_params)
def fit(self, data, jumps=1000, skips=5, burnin=100, iters=20):
def single_subj_fit(s):
samples, probs = utils.mcmc(posterior=self.posterior_single_subj,
args=(current_g, data[s]),
p0=utils.average_samples(current_s[s], self.angles_idx),
bounds=self.bounds_s, burnin=burnin, jumps=jumps, skips=skips, step_size=1e-2)
return samples, probs
n_subj = data.shape[0]
n_samples = jumps // skips
current_g = np.array([self.param_priors[p].rvs(n_samples) for p in self.param_names] +
list(self.param_priors['sigma2_g'].rvs((len(self.param_names), n_samples)))).T
current_s = np.stack([self.param_priors[p].rvs((n_subj, n_samples)) for p in self.param_names] +
[self.param_priors['sigma2_n'].rvs((n_subj, n_samples))], axis=-1)
best_probs = -np.inf
for t in range(iters):
res = Parallel(n_jobs=-1)(delayed(single_subj_fit)(i) for i in range(n_subj))
current_s = np.stack([r[0] for r in res])
probs_s = np.mean([r[1].mean() for r in res])
res, probs_g = utils.mcmc(posterior=self.posterior_g, args=[current_s],
p0=utils.average_samples(current_g, self.angles_idx),
bounds=self.bounds_g, burnin=burnin, jumps=jumps, skips=skips)
current_g = np.array(res)
total_probs = probs_g.mean() + probs_s
if best_probs <= total_probs:
best_params = (current_g, current_s)
print(t, total_probs)
# current_g, current_s = best_params
res = Parallel(n_jobs=-1)(delayed(single_subj_fit)(i) for i in range(n_subj))
subj_samples = np.stack([np.squeeze(r[0]) for r in res], axis=0)
group_samples, group_probs = utils.mcmc(posterior=self.posterior_g, args=[current_s],
p0=current_g.mean(axis=0), bounds=self.bounds_g,
burnin=burnin, jumps=jumps, skips=skips, step_size=1e-2)
return group_samples, subj_samples
import numpy as np
import matplotlib.pyplot as plt
def watson_pdf(x, mu, k):
"""
computes un-normalized watson distribution.
args:
x: array (n, 3)
mu : vector(3,)
k : scalar
return:
(n,) watson pdf
"""
t = (x @ mu.T) / np.linalg.norm(x, axis=1) / np.linalg.norm(mu)
return np.exp(k * t ** 2)
def p2c(phi, theta, r=1):
return np.array([np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta)]).T * r
def c2p(x, y, z):
x, y, z = [np.atleast_1d(t) for t in [x, y, z]]
r = (x ** 2 + y ** 2 + z ** 2) ** 0.5
phi = np.zeros_like(r)
theta = np.zeros_like(r)
phi[r == z] = 0
theta[r == z] = 0
phi[r != z] = np.arccos(x[r != z] / ((r[r != z] ** 2 - z[r != z] ** 2) ** 0.5))
theta[r != z] = np.arccos(z[r != z] / r[r != z])
return np.array([phi, theta, r]).T
assert (p2c(0, 0, 1) == (0, 0, 1)).all()
assert (c2p(0, 0, 1) == (0, 0, 1)).all()
assert (c2p(*p2c(np.pi / 3, np.pi / 2, 1)) == (np.pi / 3, np.pi / 2, 1)).all()
def ball_sticks_1(d, f_0, f_1, phi_1, theta_1, bvals, bvecs):
"""
Takes the parameters of ball & stick model and acquistion parameters to generate diffusion signal.
"""
assert 0 <= f_1 <= 1
signal = f_0 * np.exp(-d * bvals) + f_1 * np.exp(-d * bvals * p2c(phi_1, theta_1)[np.newaxis, :].dot(bvecs.T) ** 2)
return np.squeeze(signal)
def fibonacci_sphere(samples=1):
"""
Creates N points uniformly-ish distributed on the sphere
Args:
samples : int
"""
points = np.array((samples, 3))
phi = np.pi * (3. - np.sqrt(5.)) # golden angle in radians
i = np.arange(samples)
y = 1 - 2. * (i / float(samples - 1))
r = np.sqrt(1 - y * y)
t = phi * i
x = np.cos(t) * r
z = np.sin(t) * r
points = np.asarray([x, y, z]).T
return points
def plot_sphere(r=1, alpha=0.5):
fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(projection='3d')
u, v = np.mgrid[0:2 * np.pi:50j, 0:np.pi:40j]
x = r * np.cos(u) * np.sin(v)
y = r * np.sin(u) * np.sin(v)
z = r * np.cos(v)
ax.plot_surface(x, y, z, color='lightgray', alpha=alpha)
return ax
def average_samples(samples, angle_idx):
"""
take samples and computes the average, everything is ordinary averaging except for angles dyadic average is used.
args:
samples: (n, d)
angle_idx : (k, 2) array, or nested list that contains index of phi and theta for each vector
return:
(1, d)
"""
angle_idx = np.atleast_2d(angle_idx)
mean_samples = np.mean(samples, axis=0)
for (phi_idx, theta_idx) in angle_idx:
d = p2c(samples[:, phi_idx], samples[:, theta_idx])
d_avg = np.linalg.eigh(np.einsum('ik,ip->kp', d, d))[1][..., -1]
phi_avg, theta_avg, _ = c2p(*d_avg).T
mean_samples[phi_idx] = phi_avg
mean_samples[theta_idx] = theta_avg
return mean_samples
def mcmc(posterior, args, p0, cov=None, bounds=None, step_size=1e0, jumps=5000, burnin=100, skips=10):
if cov is None:
cov = np.eye(len(p0))
cov = np.atleast_2d(cov)
if np.all(np.linalg.eigvals(cov) > 1e-14):
L1 = np.linalg.cholesky(cov) / np.sqrt(len(p0))
else:
L1 = np.linalg.cholesky(np.eye(cov.shape[0])) / np.sqrt(len(p0))
if bounds is None:
bounds = np.array([[-np.inf, np.inf]] * len(p0))
assert (p0 >= bounds[:, 0]).all() and (p0 <= bounds[:, 1]).all()
current = np.array(p0)
prob_cur = posterior(current, *args)
samples = []
all_probs = []
for j in range(jumps + burnin):
proposed = current + L1 @ np.random.randn(*current.shape) * step_size
if (proposed >= bounds[:, 0]).all() and (proposed <= bounds[:, 1]).all():
prob_next = posterior(proposed, *args)
if np.exp(prob_next - prob_cur) > np.random.rand():
current = proposed
prob_cur = prob_next
samples.append(current)
all_probs.append(prob_cur)
return np.squeeze(np.stack(samples, axis=0))[burnin::skips], np.array(all_probs)[burnin::skips]
def iterative_mcmc(posterior, args, p0, jumps=5000, bounds=None, repeats=10, burnin=100, skips=5, step_size=1e0):
current_cov = np.eye(len(p0)) * 1e-3
for _ in range(repeats):
results, _ = mcmc(posterior, args, p0, current_cov, bounds=bounds, step_size=step_size,
jumps=jumps // repeats, burnin=burnin // repeats, skips=1)
current_cov = np.cov(results.T)
return mcmc(posterior, args, p0, current_cov, bounds=bounds, jumps=jumps,
burnin=burnin, skips=skips, step_size=step_size)
def generate_acq(n_b0=10, n_dir=64, b=[1, 2, 3]):
bvals = np.zeros(n_b0)
bvecs = fibonacci_sphere(n_b0)
for b_ in b:
bvals = np.concatenate([bvals, np.ones(n_dir) * b_])
bvecs = np.concatenate([bvecs, fibonacci_sphere(n_dir)])
return bvals, bvecs
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment