Commit 1ed554d5 authored by Saad Jbabdi

Added data

parent 71cfec62
......@@ -288,8 +288,8 @@
"from sbi.inference.base import infer\n",
"\n",
"# read in bvals and bvecs\n",
"bvals = torch.tensor(np.loadtxt('/Users/saad/data/dtidata/bvals'),dtype=float)/1000.0\n",
"bvecs = torch.tensor(np.loadtxt('/Users/saad/data/dtidata/bvecs'),dtype=float)\n",
"bvals = torch.tensor(np.loadtxt('./data/bvals'),dtype=float)/1000.0\n",
"bvecs = torch.tensor(np.loadtxt('./data/bvecs'),dtype=float)\n",
"\n",
"# forward model is ball and stick\n",
"def forward(p):\n",
......
%% Cell type:markdown id: tags:
## Simulation Based Inference - Practical
In this practical we will:
- Create a toy SBI thingy from scratch
- Learn to use a proper SBI toolbox called ```sbi```...
Suppose you have a model of the form: ```param ---simulator---> data```
We are interested in inferring the params given the data. In SBI, this is done in two steps:
- Step 1: Learn a nonlinear mapping ```data --> posterior```, where the posterior is parameterised in some way. This is done via forward simulations.
- Step 2: Given some actual data, use the mapping learned in step 1 to make inference directly.
Before you start, make sure the following Python packages are installed (sorry about this):
* numpy
* matplotlib
* keras
* tensorflow
* scipy
* torch
Let's build our own dumb SBI
%% Cell type:code id: tags:
``` python
# Simple 1D example
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# Simulator does theta -> x
# here x = theta^3+noise
def simulator(theta, noise_std=.1):
    noise = noise_std * np.random.randn()
    return theta**3 + noise

# Generate pairs of theta/data by sampling theta from the prior
# (brute-force prior sampling like this becomes very inefficient for real problems!)
theta_prior_mean = 0.
theta_prior_std = 1.
thetas = []
xs = []
for n in range(1000):
    th = np.random.randn()*theta_prior_std + theta_prior_mean
    x = simulator(th)
    thetas.append(th)
    xs.append(x)
```
%% Cell type:code id: tags:
``` python
# Plot our samples
plt.plot(thetas,xs,'.')
plt.xlabel('theta')
plt.ylabel('x')
```
%% Cell type:markdown id: tags:
Now build a silly little neural net that outputs the parameters of the approximate posterior
%% Cell type:code id: tags:
``` python
from keras.models import Model
from keras.layers import Dense, Input
import tensorflow as tf
L_0 = Input(shape=(1,)) # Layer 0 takes in 1D data
L_1 = Dense(100,activation='tanh')(L_0) # Layer 1 has 100 nodes and a tanh activation
L_2 = Dense(2,activation='linear')(L_1) # Layer 2 outputs mean and logsig of the posterior
model = Model(L_0,L_2)
# This is the loss function: mean[ -log q(theta;x) ] where q is Gaussian
def my_loss(thetas,out_pred):
    mu = out_pred[:,0][:,None]     # mean
    logsig = out_pred[:,1][:,None] # logsig
    sig = tf.math.exp(logsig)
    loss = (thetas - mu)**2/sig**2/2 + logsig
    return tf.reduce_mean(loss)
# some random optimizer
optim = tf.keras.optimizers.Adagrad(
learning_rate=0.01, initial_accumulator_value=0.1, epsilon=1e-07,
name='Adagrad')
# compile the model (whatever that means in Keras world)
model.compile(loss=my_loss,optimizer=optim)
# train the model
history = model.fit(xs, thetas, epochs=100, batch_size=16,workers=4)
```
%% Cell type:markdown id: tags:
### What have we just done?
We have trained a neural net to take in some *simulated* data and output the mean and the log(std) of a Gaussian distribution.
Given that the data were simulated from the prior on theta, and given the form of our cost function, the optimal mean and log(std) are those of the posterior distribution (assuming it is a Gaussian).
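To see why, note that averaging the loss over many simulated $(\theta, x)$ pairs approximates
$$\mathbb{E}_{p(\theta)\,p(x|\theta)}\big[-\log q(\theta|x)\big] \;=\; \mathbb{E}_{p(x)}\big[\mathrm{KL}\big(p(\theta|x)\,\|\,q(\theta|x)\big)\big] + \text{const},$$
where the constant does not depend on $q$. Minimising the loss therefore pushes $q(\theta|x)$ as close as possible (in KL) to the true posterior $p(\theta|x)$ for every $x$ the simulator is likely to produce; with a Gaussian $q$, the best it can do is match the posterior mean and spread.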
For more "expressive" posterior distibutions, one can use mixtures of Gaussians for example.
Let's now visualise the fit to the training data by plotting the data on the x-axis and the parameters on the y-axis (compare this to the previous plot). We also plot the mean posterior distribution.
%% Cell type:code id: tags:
``` python
# Visualise the fit to the training data
mus = model.predict(xs)[:,0]
plt.plot(xs,thetas,'.',label='training thetas')
plt.plot(xs,mus,'.',label='predicted posterior mean')
plt.xlabel('data')
plt.ylabel('theta')
plt.legend()
```
%% Cell type:markdown id: tags:
Now let us have a look at the posterior for a new data point that is not part of the training set. For this we need to 'observe' new data.
We will just generate those using our simulator
%% Cell type:code id: tags:
``` python
true_theta = np.random.randn()*theta_prior_std+theta_prior_mean # generate a parameter theta (here using the prior but we don't have to)
observation = simulator(true_theta) # simulate the data
mu,logsig = model.predict([observation])[0] # use our neural net to predict the parameters of the posterior
# Plotting the prior, posterior, and actual theta value
from scipy.stats import norm
x_axis = np.linspace(-2, 2,1000)
plt.figure()
plt.plot(x_axis, norm.pdf(x_axis,mu,np.exp(logsig)),label='approx-posterior')
plt.plot(x_axis, norm.pdf(x_axis,theta_prior_mean,theta_prior_std),c='k',alpha=.3,label='prior')
plt.axvline(x=true_theta,c='r',label='actual')
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:
Not bad! ~~Bad~~
(delete as appropriate. It doesn't work for every theta in my experience)
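One quick way to see where it struggles (a sketch, reusing the ```model``` trained above): sweep a range of true theta values, simulate data for each, and compare the predicted posterior mean with the truth. Values far out in the tails of the prior are rarely simulated during training, so the fit tends to be worse there.
%% Cell type:code id: tags:
``` python
# Sweep true thetas, simulate data for each, and compare the predicted
# posterior means to the true values
test_thetas = np.linspace(-2.5, 2.5, 50)
test_xs = np.array([simulator(t) for t in test_thetas]).reshape(-1, 1)
pred_mus = model.predict(test_xs)[:, 0]

plt.plot(test_thetas, pred_mus, '.', label='predicted posterior mean')
plt.plot(test_thetas, test_thetas, 'k--', alpha=.5, label='true theta')
plt.xlabel('true theta')
plt.legend()
```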
%% Cell type:markdown id: tags:
Now let's use an SBI toolbox to do the same thing
%% Cell type:code id: tags:
``` python
import torch
import sbi.utils as utils
from sbi.inference.base import infer
# Same simulator as before
# (note: we did not try to infer the internal state parameter: noise_std)
def simulator(theta, noise_std=.1):
    noise = noise_std * np.random.randn()
    return theta**3 + noise
# Priors
theta_prior_mean = 0.
theta_prior_std = 1.
prior = torch.distributions.Normal(loc=torch.tensor([theta_prior_mean]),
scale=torch.tensor([theta_prior_std]))
# Below is SNPE-C method if I am not mistaken
# Feel free to check the documentation and try out different methods
# Simulator is user-defined function
# Prior must be a ...
posterior = infer(simulator, prior, method='SNPE', num_simulations=1000,num_workers=8)
```
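%% Cell type:markdown id: tags:
If you want more control than the ```infer``` convenience function gives you, the same thing can be written with the step-by-step interface. This is only a sketch, assuming a reasonably recent version of ```sbi``` at the time of writing; the exact imports and names have changed across releases, so check the documentation of the version you have installed.
%% Cell type:code id: tags:
``` python
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi

# wrap the user simulator and prior so that sbi can work with them
sim_wrapped, prior_wrapped = prepare_for_sbi(simulator, prior)

# step 1: run the simulations and train the neural density estimator
theta, x = simulate_for_sbi(sim_wrapped, proposal=prior_wrapped,
                            num_simulations=1000, num_workers=8)
inference = SNPE(prior=prior_wrapped)
density_estimator = inference.append_simulations(theta, x).train()

# step 2: build the (amortised) posterior object
posterior = inference.build_posterior(density_estimator)
```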
%% Cell type:markdown id: tags:
### What have we done? Same thing as before! We have trained a neural net to map from data to posterior
Let us now see what happens when we have a new observation
%% Cell type:code id: tags:
``` python
# Generate an observation
true_theta = np.random.randn()*theta_prior_std+theta_prior_mean
observation = simulator(true_theta)
# Sample from the *now amortised* posterior
samples = np.asarray(posterior.sample((1000,), x=observation))
mu = np.mean(samples)
sig = np.std((samples))
# Plot
from scipy.stats import norm
x_axis = np.linspace(-2, 2,1000)
plt.plot (x_axis, norm.pdf(x_axis,mu,sig),label='approx-posterior')
plt.plot (x_axis, norm.pdf(x_axis,theta_prior_mean,theta_prior_std),c='k',alpha=.3,label='prior')
plt.axvline (x=true_theta,c='r',label='actual')
plt.legend()
plt.show()
```
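%% Cell type:markdown id: tags:
Because the posterior is amortised, we can reuse it for any new observation without retraining. The short check below simply repeats the sampling step for a second simulated data point.
%% Cell type:code id: tags:
``` python
# Re-use the same trained posterior for a second, independent observation
true_theta2 = np.random.randn()*theta_prior_std + theta_prior_mean
observation2 = simulator(true_theta2)
samples2 = np.asarray(posterior.sample((1000,), x=observation2)).ravel()

plt.hist(samples2, bins=50, density=True, label='posterior samples')
plt.axvline(x=true_theta2, c='r', label='actual')
plt.legend()
plt.show()
```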
%% Cell type:markdown id: tags:
Soup-Herb!
%% Cell type:markdown id: tags:
Now let us use this SBI toolbox on a more interesting simulator: the ball and stick model
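The forward model predicts the diffusion-weighted signal for a b-value $b$ and unit gradient direction $\mathbf{g}$ as
$$ S(b,\mathbf{g}) = S_0\Big[(1-f)\,e^{-bd} + f\,e^{-bd\,(\mathbf{g}\cdot\mathbf{v})^2}\Big], $$
with diffusivity $d$, stick volume fraction $f$, and fibre orientation $\mathbf{v}$ (parameterised below by the spherical angles $\theta$ and $\phi$).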
%% Cell type:code id: tags:
``` python
import torch
import numpy as np
import sbi.utils as utils
from sbi.inference.base import infer
# read in bvals and bvecs
bvals = torch.tensor(np.loadtxt('./data/bvals'),dtype=float)/1000.0
bvecs = torch.tensor(np.loadtxt('./data/bvecs'),dtype=float)
# forward model is ball and stick
def forward(p):
    s0,d,th,ph,f = p
    # spherical -> cartesian fibre orientation
    x = torch.stack([torch.sin(th)*torch.cos(ph), torch.sin(th)*torch.sin(ph), torch.cos(th)])
    # squared dot product between the fibre orientation and each bvec
    ang = torch.sum(x[:,None]*bvecs, axis=0)**2
    # predicted signal: ball + stick compartments
    return s0*((1-f)*torch.exp(-bvals*d) + f*torch.exp(-bvals*d*ang))

# simulator generates noisy data
def simulator(parameter_set, noise_std=.01):
    noise = torch.randn(bvals.shape) * noise_std
    return forward(parameter_set) + noise

# Priors (uniform in this case)
#                          s0  d   th    ph   f
prior = utils.BoxUniform(low  = torch.tensor([0., 0., -10., -10., 0.]),
                         high = torch.tensor([2., 3.,  10.,  10., 1.]))
```
%% Cell type:code id: tags:
``` python
# Before we see any data, train neural net
posterior = infer(simulator, prior, method='SNPE', num_simulations=1000,num_workers=8)
```
%% Cell type:code id: tags:
``` python
# Generate a new observation
true_p = torch.tensor([0.7000, 0.3000, 0.7854, 1.0472, 0.3000])
observation = simulator(true_p)
plt.plot(observation)
```
%% Cell type:code id: tags:
``` python
# Sample from posterior
samples = posterior.sample((10000,), x=observation)
```
%% Cell type:markdown id: tags:
### Now let us plot the posterior distribution
Because we have multiple parameters, it is useful to plot not only the marginal for each param, but also the joint posteriors for pairs of params. The SBI toolbox has a useful utility for this.
%% Cell type:code id: tags:
``` python
from sbi import utils as utils
fig, axes = utils.pairplot(samples,
fig_size=(15,15),
points=true_p,labels=['s0','d','th','ph','f'],
points_offdiag={'markersize': 16},
points_colors='r');
```
%% Cell type:markdown id: tags:
You should be able to see that the marginal posterior for phi has more than one mode (the predicted signal only depends on the fibre orientation through a squared dot product, so an orientation and its opposite give exactly the same data). You should also note how incredibly fast it is to generate samples compared to MCMC.
Also, the results here may appear disappointing, but the prior for theta/phi is very bad, as it is not uniform on the sphere. Nevermind...
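For reference, here is a quick sketch (not used above) of how one could draw angles that *are* uniform on the sphere. Roughly speaking, to use something like this as a prior in ```sbi``` you would need to wrap it, together with the other parameters, into an object exposing ```sample``` and ```log_prob``` like a torch distribution.
%% Cell type:code id: tags:
``` python
import math
import torch

def sample_uniform_sphere(n):
    # u, v ~ Uniform(0,1); theta = arccos(1 - 2u) has the sin(theta) density
    # required for orientations that are uniform over the sphere,
    # and phi is uniform on [0, 2*pi)
    u = torch.rand(n)
    v = torch.rand(n)
    theta = torch.acos(1.0 - 2.0*u)
    phi = 2.0*math.pi*v
    return theta, phi

theta_samples, phi_samples = sample_uniform_sphere(1000)
```
%% Cell type:markdown id: tags: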
*The End.*
......
......@@ -213,8 +213,8 @@ import sbi.utils as utils
from sbi.inference.base import infer
# read in bvals and bvecs
-bvals = torch.tensor(np.loadtxt('/Users/saad/data/dtidata/bvals'),dtype=float)/1000.0
-bvecs = torch.tensor(np.loadtxt('/Users/saad/data/dtidata/bvecs'),dtype=float)
+bvals = torch.tensor(np.loadtxt('./data/bvals'),dtype=float)/1000.0
+bvecs = torch.tensor(np.loadtxt('./data/bvecs'),dtype=float)
# forward model is ball and stick
def forward(p):
......
[Added data files: `data/bvals` (a single row of b-values, b=0 and b=1000 s/mm²) and `data/bvecs` (three rows giving the x, y, z components of the unit gradient directions).]