Commit f4305cef authored by Hossein Rafipoor

added mcmc function

added parallel option for iterative fitting
parent 96175907
@@ -22,7 +22,7 @@ class BallStick:
st.uniform(loc=0.0, scale=np.pi),
st.uniform(loc=0.0, scale=np.pi)])})
-        self.priors.update({'sigma2_g': st.invgamma(a=40, ), 'sigma2_n': st.invgamma(a=1e2)})
+        # self.priors.update({'sigma2_g': st.invgamma(a=40, ), 'sigma2_n': st.invgamma(a=1e2)})
def compute(self, params):
"""
......
@@ -7,6 +7,7 @@ from . import utils
from scipy.optimize import minimize
+from joblib import Parallel, delayed  # used by the Parallel/delayed calls added below
@dataclass
class hbm:
forward_model: Callable
@@ -120,6 +121,100 @@ class hbm:
return group_samples, subj_samples
def hbm_mcmc(data, forward_model, priors, angles_idx, a_n, a_g, jumps=1000, skips=5, burnin=100, iters=20):
param_names = priors.keys()
sigma2_g_prior = st.invgamma(a=a_g)
sigma2_s_prior = st.invgamma(a=a_n)
bounds_s = np.array([priors[p].support() for p in param_names] + [sigma2_s_prior.support()])
bounds_g = np.array([priors[p].support() for p in param_names] + [sigma2_g_prior.support()] * len(param_names))
def likelihood_subj(theta_s, sigma2_s, data_s):
assert sigma2_s > 0
x = forward_model(theta_s)
        # multivariate_normal expects a covariance, i.e. the noise variance itself (not its square root)
        return st.multivariate_normal(mean=x, cov=sigma2_s).logpdf(data_s).sum()
def prior_subj(theta_s, theta_g, sigma2_g, sigma2_s):
p1 = np.sum([st.norm(loc=tg, scale=sg ** 0.5).logpdf(ts) for tg, sg, ts in zip(theta_g, sigma2_g, theta_s)])
p2 = sigma2_s_prior.logpdf(sigma2_s)
return p1 + p2
def posterior_subj(subj_params, group_params, data_s):
"""
args:
all_subj_params: array (n+1,) the first n is model parameters the last is noise sigma
all_group_params: (2n) first row are the means, secound row is the variances.
"""
theta_s = subj_params[:-1]
sigma2_s = subj_params[-1]
theta_g = group_params[:len(param_names)]
sigma2_g = group_params[len(param_names):]
return prior_subj(theta_s, theta_g, sigma2_g, sigma2_s) + likelihood_subj(theta_s, sigma2_s, data_s)
def likelihood_group(theta_g, sigma2_g, theta_s):
p1 = np.sum([st.norm(loc=theta_g[p], scale=sigma2_g[p] ** 0.5).logpdf(theta_s[p]) for p in range(len(theta_g))])
return p1
def prior_group(theta_g, sigma2_g):
p1 = np.sum([p.logpdf(t) for p, t in zip(priors.values(), theta_g)])
p2 = sigma2_g_prior.logpdf(sigma2_g).sum()
return p1 + p2
def posterior_group(group_params, subj_params):
"""
args:
all_group_params: (2n,) n is the number of params, first row is means, second row is variance
all_subj_params: (k, n+1) each row is params for one subj, each column one parameter,
last column is noise variance (not used in this function, passed for keeping consistency)
"""
theta_g = group_params[:len(param_names)]
sigma2_g = group_params[len(param_names):]
theta_s = subj_params[:, :, :-1].T
return prior_group(theta_g, sigma2_g) + likelihood_group(theta_g, sigma2_g, theta_s)
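For reference, the log-densities above jointly define a two-level hierarchical model (scipy's invgamma defaults to scale b = 1, and f denotes forward_model):

\begin{aligned}
y_s \mid \theta_s, \sigma_s^2 &\sim \mathcal{N}\big(f(\theta_s),\, \sigma_s^2 I\big) \\
\theta_{s,j} \mid \theta_{g,j}, \sigma_{g,j}^2 &\sim \mathcal{N}\big(\theta_{g,j},\, \sigma_{g,j}^2\big) \\
\sigma_s^2 \sim \mathrm{InvGamma}(a_n, 1), \qquad \sigma_{g,j}^2 &\sim \mathrm{InvGamma}(a_g, 1), \qquad \theta_{g,j} \sim \mathtt{priors}[j]
\end{aligned}

posterior_subj and posterior_group are the two conditional log-posteriors that the alternating, Metropolis-within-Gibbs-style loop below targets.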
def single_subj_fit(s):
samples, probs = utils.mcmc(posterior=posterior_subj,
args=[current_g, data[s]],
p0=utils.average_samples(current_s[s], angles_idx),
bounds=bounds_s, burnin=burnin, jumps=jumps, skips=skips, step_size=1e-2)
return samples, probs
n_subj = data.shape[0]
n_samples = jumps // skips
current_g = np.array([priors[p].rvs(n_samples) for p in param_names] +
list(sigma2_g_prior.rvs((len(param_names), n_samples)))).T
current_s = np.stack([priors[p].rvs((n_subj, n_samples)) for p in param_names] +
[sigma2_s_prior.rvs((n_subj, n_samples))], axis=-1)
best_probs = -np.inf
for t in range(iters):
res = Parallel(n_jobs=-1)(delayed(single_subj_fit)(i) for i in range(n_subj))
current_s = np.stack([r[0] for r in res])
probs_s = np.mean([r[1].mean() for r in res])
res, probs_g = utils.mcmc(posterior=posterior_group, args=[current_s],
p0=utils.average_samples(current_g, angles_idx),
bounds=bounds_g, burnin=burnin, jumps=jumps, skips=skips)
current_g = np.array(res)
total_probs = probs_g.mean() + probs_s
        if best_probs <= total_probs:
            best_probs = total_probs  # track the best score, otherwise the comparison is always against -inf
            best_params = (current_g, current_s)
print(t, total_probs)
# current_g, current_s = best_params
res = Parallel(n_jobs=-1)(delayed(single_subj_fit)(i) for i in range(n_subj))
subj_samples = np.stack([np.squeeze(r[0]) for r in res], axis=0)
group_samples, group_probs = utils.mcmc(posterior=posterior_group, args=[current_s],
p0=current_g.mean(axis=0), bounds=bounds_g,
burnin=burnin, jumps=jumps, skips=skips, step_size=1e-2)
return group_samples, subj_samples
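The sampler relies on utils.mcmc, which is not part of this diff. As orientation, here is a minimal random-walk Metropolis sketch consistent with how it is called above; the proposal scheme, the bounds handling, and the (samples, probs) return layout are assumptions inferred from the call sites, not the repository's implementation:

import numpy as np

def mcmc(posterior, args, p0, bounds, burnin=100, jumps=1000, skips=5, step_size=1e-2):
    """Random-walk Metropolis: `jumps` proposals after `burnin`, keeping every `skips`-th state."""
    rng = np.random.default_rng()
    x = np.asarray(p0, dtype=float)
    log_p = posterior(x, *args)
    samples, probs = [], []
    for i in range(burnin + jumps):
        proposal = x + step_size * rng.standard_normal(x.shape)
        # reject proposals outside the supports collected in `bounds` (shape (n_params, 2))
        if np.all((proposal >= bounds[:, 0]) & (proposal <= bounds[:, 1])):
            log_p_new = posterior(proposal, *args)
            # symmetric proposal, so the acceptance ratio is just the posterior ratio
            if np.log(rng.uniform()) < log_p_new - log_p:
                x, log_p = proposal, log_p_new
        if i >= burnin and (i - burnin) % skips == 0:
            samples.append(x.copy())
            probs.append(log_p)
    return np.stack(samples), np.array(probs)

Calling posterior(x, *args) with args=[current_g, data[s]] or args=[current_s] matches the posterior_subj and posterior_group signatures above, and jumps // skips kept states matches the n_samples used to shape current_g and current_s.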
def fit_full_posterior(data, forward_model, bounds, a_n, b_n, a_g, b_g):
"""
Computes the parameters by optimizing the full posterior distribution.
@@ -149,38 +244,47 @@ def fit_full_posterior(data, forward_model, bounds, a_n, b_n, a_g, b_g):
return (a_n + dims / 2) * c_d + (a_g + (n_subj - 1) / 2) * c_g
-    obj = minimize(fun=cost_func, x0=np.random.rand(n_subj * n_params), bounds=bounds)
+    obj = minimize(fun=cost_func, x0=np.random.rand(n_subj * n_params), bounds=np.tile(bounds, (n_subj, 1)))
x = np.reshape(obj.x, (n_subj, n_params))
-    h = utils.hessian(cost_func, obj.x, bounds)
+    h = utils.hessian(cost_func, obj.x, np.tile(bounds, (n_subj, 1)))
return x, h
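The coefficient (a_n + dims / 2) in cost_func is what survives after the noise variance, with its Inv-Gamma(a_n, b_n) prior, is integrated out of the Gaussian likelihood analytically (c_d and c_g are defined in lines elided from this hunk; they are assumed here to be the corresponding log terms):

\int \mathcal{N}\big(y;\, \mu, \sigma^2 I\big)\, \mathrm{InvGamma}(\sigma^2; a, b)\, \mathrm{d}\sigma^2 \;\propto\; \big(\lVert y - \mu \rVert^2 + 2b\big)^{-(a + d/2)}

Each collapsed variance therefore contributes (a + d/2) log(||y − mu||^2 + 2b) to the negative log posterior; the (n_subj − 1)/2 in the group factor suggests one degree of freedom is absorbed when the group mean is eliminated as well.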
-def fit_iterative_posterior(data, forward_model, bounds, a_n, b_n, a_g, b_g, max_iters=100):
+def fit_iterative_posterior(data, forward_model, bounds, a_n, b_n, a_g, b_g, max_iters=100, parallel=True):
n_subj, dims = data.shape
n_params = bounds.shape[0]
current_g = np.array([p[0] + np.random.rand() * (p[1] - p[0]) for p in bounds] +
list(st.invgamma(a=a_g).rvs(n_params))).T
current_s = np.stack([p[0] + np.random.rand(n_subj) * (p[1] - p[0]) for p in bounds], axis=-1)
# MLE parameters for each subject without hierarchy as initial guess:
for s in range(n_subj):
f = lambda x: np.linalg.norm(forward_model(x) - data[s])
obj = minimize(fun=f, x0=current_s[s], bounds=bounds)
current_s[s] = obj.x
def cost_func_subj(theta_s, data_s):
theta_g, sigma2_g = current_g[:n_params], current_g[n_params:]
-        c_d = np.log(np.linalg.norm(forward_model(theta_s) - data_s) ** 2 + 2 * b_n)
-        c_g = np.sum([st.norm(loc=t, scale=s ** 0.5).logpdf(m) for t, s, m in zip(theta_g, sigma2_g, theta_s)])
+        c_d = -np.log(np.linalg.norm(forward_model(theta_s) - data_s) ** 2 + 2 * b_n)
+        c_g = np.sum([st.norm(loc=tg, scale=sg ** 0.5).logpdf(ts) for tg, sg, ts in zip(theta_g, sigma2_g, theta_s)])
-        return (a_n + dims / 2) * c_d - c_g
+        return -((a_n + dims / 2) * c_d + c_g)
probs_s = np.zeros((max_iters, n_subj))
for t in range(max_iters):
-        for s in range(n_subj):
-            obj = minimize(fun=cost_func_subj, x0=current_s[s], args=(data[s]), bounds=bounds)
-            current_s[s] = obj.x
-            probs_s[t, s] = obj.fun
theta_g = current_s.mean(axis=0)
-        sigma_g = (np.linalg.norm(current_s - theta_g[np.newaxis, :]) ** 2 / 2 + b_g) / (a_g + n_subj / 2)
-        current_g = np.r_[theta_g, sigma_g]
+        sigma2_g = (np.linalg.norm(current_s - theta_g[np.newaxis, :], axis=0) ** 2 + 2 * b_g) / (2 * a_g + n_subj + 1)
+        current_g = np.r_[theta_g, sigma2_g]
+        def optimize_subj(s):
+            # note: `args` must be a tuple, hence the trailing comma
+            obj = minimize(fun=cost_func_subj, x0=current_s[s], args=(data[s],), bounds=bounds)
+            return obj.x, obj.fun
+        if parallel:
+            res = Parallel(n_jobs=-1)(delayed(optimize_subj)(i) for i in range(n_subj))
+            current_s = np.stack([r[0] for r in res], axis=0)
+            probs_s[t] = np.stack([r[1] for r in res], axis=0)
+        else:
+            for s in range(n_subj):
+                current_s[s], probs_s[t, s] = optimize_subj(s)
        print(t, probs_s[t].sum())
......
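A note on the new group-variance update in fit_iterative_posterior: given the subject parameters, each component of sigma2_g is conditionally conjugate under its Inv-Gamma(a_g, b_g) prior, which is what permits a closed-form point update (a standard result, with k = n_subj and SS_j = sum_s (theta_{s,j} − theta_{g,j})^2):

\sigma_{g,j}^2 \mid \theta_{1:k,j}, \theta_{g,j} \;\sim\; \mathrm{InvGamma}\!\Big(a_g + \tfrac{k}{2},\; b_g + \tfrac{1}{2}\, SS_j\Big)

The committed update (SS_j + 2 b_g) / (2 a_g + n_subj + 1) is a point estimate of this distribution; its denominator sits between the posterior mode's (2 a_g + n_subj + 2) and the posterior mean's (2 a_g + n_subj − 2).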