fit_model.ipynb 54.2 KiB
Saad Jbabdi committed
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Short example of Model fitting\n",
    "\n",
    "Here we fit an MRI-style model to some simulated data.\n",
    "\n",
    "The model is: $\\textrm{Signal} = M_0\\exp\\left[-R_2\\textrm{TE}\\right]\\left(1-\\exp\\left[-R_1\\textrm{TR}\\right]\\right)$\n",
    "\n",
    "The parameters that we will be fitting are $(M_0,R_1,R_2)$. \n",
    "\n",
    "Basic imports:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.optimize import minimize\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section:\n",
    "\n",
    "- defining a numpy array\n",
    "- double list comprehension\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEs = np.array([10,40,50,60,80]) # TE values in ms\n",
    "TRs = np.array([.8,1,1.5,2])     # TR values in seconds (mixing ms and s like this is bad practice)\n",
    "\n",
    "# All combinations of TEs/TRs\n",
    "comb    = np.array([(x,y) for x in TEs for y in TRs])\n",
    "TEs,TRs = comb[:,0],comb[:,1]\n"
   ]
  },
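  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The double list comprehension above runs over `TEs` in the outer loop and `TRs` in the inner loop. As a result, TE varies slowest down the rows of `comb` (5 TEs x 4 TRs = 20 rows). As a sketch, not part of the original analysis, the same grid can also be built with `np.meshgrid`. The names `TE_vals`, `TR_vals`, `comb_lc` and `comb_mg` are illustrative only. They are chosen so that this cell does not overwrite the notebook's `TEs` and `TRs`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "TE_vals = np.array([10,40,50,60,80])  # TE values in ms (as above)\n",
    "TR_vals = np.array([.8,1,1.5,2])      # TR values in s (as above)\n",
    "\n",
    "# double list comprehension: the outer loop (TE) varies slowest\n",
    "comb_lc = np.array([(x,y) for x in TE_vals for y in TR_vals])\n",
    "\n",
    "# same grid via meshgrid; indexing='ij' keeps TE on the first (slow) axis\n",
    "TE_grid,TR_grid = np.meshgrid(TE_vals,TR_vals,indexing='ij')\n",
    "comb_mg = np.column_stack([TE_grid.ravel(),TR_grid.ravel()])\n",
    "\n",
    "print(comb_lc.shape)                     # (20, 2)\n",
    "print(np.array_equal(comb_lc,comb_mg))   # True\n"
   ]
  },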
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Now we define our forward model.\n",
    "\n",
    "In this section:\n",
    "\n",
    "- inline function definition\n",
    "- random number generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x121b13978>]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "# function for our model\n",
    "def forward(p):\n",
    "    M0,R1,R2 = p\n",
    "    return M0*np.exp(-R2*TEs)*(1-np.exp(-R1*TRs))\n",
    "\n",
    "# simulate data using model \n",
    "true_p    = [100,1/.8,1/50]   # M0, R1=1/T1,R2=1/T2\n",
    "data      = forward(true_p)\n",
    "snr       = 50\n",
    "noise_std = true_p[0]/snr\n",
    "noise     = np.random.randn(data.size)*noise_std\n",
    "data      = data + noise\n",
    "\n",
    "# quick plot of the data\n",
    "plt.figure()\n",
    "plt.plot(data)\n"
   ]
  },
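  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`np.random.randn(n)` draws `n` samples from a standard normal distribution. Scaling them by `noise_std` therefore gives zero-mean Gaussian noise with standard deviation `noise_std` (here `100/50 = 2`). The sketch below checks this on a large sample. It uses a seeded `np.random.RandomState` generator so that the result is reproducible; the seed value is arbitrary, and the original cell does not seed. The names `rng`, `std_demo` and `noise_demo` are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "std_demo = 100/50   # M0/snr, the same value as noise_std above\n",
    "\n",
    "# seeded generator (arbitrary seed) so that this cell is reproducible\n",
    "rng = np.random.RandomState(0)\n",
    "noise_demo = rng.randn(100000)*std_demo\n",
    "\n",
    "# the sample mean should be close to 0 and the sample std close to 2\n",
    "print('mean = {:.3f}, std = {:.3f}'.format(noise_demo.mean(),noise_demo.std()))\n"
   ]
  },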
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have the data and our forward model, we are almost ready to begin fitting.\n",
    "\n",
    "We need a cost function to optimise. We will use half the mean squared error (the factor of 1/2 does not change where the minimum is).\n",
    "\n",
    "In this section:\n",
    "\n",
    "- the `**` (power) operator\n",
    "- np.mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cost function is mean square error divided by 2\n",
    "def cf(p):\n",
    "    pred = forward(p)\n",
    "    return np.mean((pred-data)**2)/2.0\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can set up our optimisation.\n",
    "\n",
    "In this section:\n",
    "\n",
    "- scipy minimize\n",
    "- dictionary\n",
    "- keyword arguments\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get ready to minimize\n",
    "p0 = [200,1/1,1/70]  # arbitrary initial guess\n",
    "method = 'powell' # pick a method. scipy has loads!\n",
    "\n",
    "\n",
    "kw_args = {'x0':p0,'method':method}\n",
    "result  = minimize(cf,**kw_args)\n",
    "\n"
   ]
  },
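  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`minimize` returns an `OptimizeResult` object. Besides the fitted parameters in `result.x`, it records several fields worth checking before trusting a fit:\n",
    "\n",
    "- whether the optimiser reports convergence (`success`, `message`)\n",
    "- the final value of the cost function (`fun`)\n",
    "- the number of cost-function evaluations (`nfev`)\n",
    "\n",
    "The sketch below shows these fields on a hypothetical quadratic cost, `quad_cost`, whose minimum is known to be at (3, -1). It also passes Powell-specific settings through an `options` dictionary inside the keyword-argument dict. The tolerance values are arbitrary choices for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.optimize import minimize\n",
    "\n",
    "# hypothetical cost with a known minimum at (3,-1)\n",
    "def quad_cost(p):\n",
    "    return (p[0]-3)**2 + (p[1]+1)**2\n",
    "\n",
    "# keyword arguments in a dict, including method-specific options\n",
    "demo_kw  = {'x0':[0.0,0.0],'method':'powell',\n",
    "            'options':{'xtol':1e-8,'ftol':1e-8}}\n",
    "res_demo = minimize(quad_cost,**demo_kw)\n",
    "\n",
    "print(res_demo.success, res_demo.message)\n",
    "print('x = {}, cost = {:.2e}, evaluations = {}'.format(res_demo.x,float(res_demo.fun),res_demo.nfev))\n"
   ]
  },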
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot the data with the model prediction.\n",
    "\n",
    "In this section:\n",
    "\n",
    "- printing\n",
    "- text formatting\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fitted = [1.02562035e+02 1.16194491e+00 1.98071179e-02]\n",
      "true   = [100, 1.25, 0.02]\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure()\n",
    "plt.plot(data,'o')\n",
    "plt.plot(forward(result.x))\n",
    "print('fitted = {}'.format(result.x))\n",
    "print('true   = {}'.format(true_p))\n"
   ]
  },
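  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Format specifications such as `.4g` (four significant figures) make the fitted and true values easier to compare than the raw array printout. The small sketch below uses f-strings, which are available from Python 3.6, the version this notebook records. It formats the fitted values copied from the output above; `param_names`, `fitted_vals`, `true_vals` and `lines` are illustrative names. Because the simulated noise is not seeded, re-running the notebook gives slightly different fitted values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "param_names = ['M0','R1','R2']\n",
    "fitted_vals = np.array([1.02562035e+02,1.16194491e+00,1.98071179e-02])  # copied from the output above\n",
    "true_vals   = [100,1/.8,1/50]\n",
    "\n",
    "# one line per parameter, each value to 4 significant figures\n",
    "lines = [f'{n}: fitted = {f:.4g}, true = {t:.4g}'\n",
    "         for n,f,t in zip(param_names,fitted_vals,true_vals)]\n",
    "for line in lines:\n",
    "    print(line)\n"
   ]
  },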
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optional: use gradients and Hessians to help with the optimisation\n",
    "\n",
    "In this example the forward model is simple enough that the derivatives of the cost function can be worked out analytically without much effort. Below is an example of how you could define these and use them in the fitting.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# gradient of the forward model\n",
    "def forward_deriv(p):\n",
    "    M0,R1,R2 = p\n",
    "    E1,E2    = np.exp(-R1*TRs),np.exp(-R2*TEs)\n",
    "    dE1      = -TRs*E1\n",
    "    dE2      = -TEs*E2\n",
    "    \n",
    "    # f = M0*E2*(1-E1)\n",
    "    dfdM0 = E2*(1-E1)\n",
    "    dfdR1 = M0*E2*(-dE1)\n",
    "    dfdR2 = M0*dE2*(1-E1)\n",
    "    return np.array([dfdM0,dfdR1,dfdR2])\n",
    "\n",
    "# hessian of the forward model\n",
    "def forward_deriv2(p):\n",
    "    M0,R1,R2 = p\n",
    "    E1,E2    = np.exp(-R1*TRs),np.exp(-R2*TEs)\n",
    "    dE1      = -TRs*E1\n",
    "    dE2      = -TEs*E2\n",
    "    ddE1     = (TRs**2)*E1\n",
    "    ddE2     = (TEs**2)*E2\n",
    "    \n",
    "    dfdM0dM0 = np.zeros(E1.shape)\n",
    "    dfdM0dR1 = E2*(-dE1)\n",
    "    dfdM0dR2 = dE2*(1-E1)\n",
    "\n",
    "    dfdR1dM0 = E2*(-dE1)\n",
    "    dfdR1dR1 = M0*E2*(-ddE1)\n",
    "    dfdR1dR2 = M0*(dE2)*(-dE1)\n",
    " \n",
    "    dfdR2dM0 = dE2*(1-E1)\n",
    "    dfdR2dR1 = M0*dE2*(-dE1)\n",
    "    dfdR2dR2 = M0*ddE2*(1-E1)\n",
    "\n",
    "    return np.array([[dfdM0dM0,dfdM0dR1,dfdM0dR2],\n",
    "                     [dfdR1dM0,dfdR1dR1,dfdR1dR2],\n",
    "                     [dfdR2dM0,dfdR2dR1,dfdR2dR2]])\n",
    "\n",
    "\n",
    "# cost function is mean square error divided by 2\n",
    "def cf(p):\n",
    "    pred = forward(p)\n",
    "    return np.mean((pred-data)**2)/2.0\n",
    "\n",
    "def cf_grad(p):\n",
    "    pred  = forward(p)\n",
    "    deriv = forward_deriv(p)\n",
    "    return np.mean( deriv * (pred-data)[None,:],axis=1)\n",
    "\n",
    "def cf_hess(p):\n",
    "    pred   = forward(p)\n",
    "    deriv  = forward_deriv(p)\n",
    "    deriv2 = forward_deriv2(p)\n",
    "    \n",
    "    H = np.zeros((len(p),len(p)))\n",
    "    for i in range(H.shape[0]):\n",
    "        for j in range(H.shape[1]):\n",
    "            H[i,j] = np.mean(deriv2[i,j]*(pred-data) + deriv[i]*deriv[j])\n",
    "    return H\n",
    "    \n",
    "\n",
    "\n"
   ]
  },
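  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before handing `cf_grad` to the optimiser, it is worth checking it against a finite-difference approximation. A sign or indexing slip in an analytic derivative is easy to make and hard to spot. The self-contained sketch below does this with `scipy.optimize.check_grad`, which returns the 2-norm of the difference between the analytic gradient and a forward-difference approximation. Dividing by the norm of the gradient gives a relative error, which should be small (well below 1e-4 here). The sketch re-implements the same model and gradient under new, illustrative names (`model`, `model_deriv`, `cost`, `cost_grad`) on noise-free simulated data, so the notebook's own variables are left untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.optimize import check_grad\n",
    "\n",
    "# same TE/TR grid as above, in the same order (TE in ms, TR in s)\n",
    "te = np.repeat([10,40,50,60,80],4).astype(float)\n",
    "tr = np.tile([.8,1,1.5,2],5).astype(float)\n",
    "\n",
    "def model(p):\n",
    "    M0,R1,R2 = p\n",
    "    return M0*np.exp(-R2*te)*(1-np.exp(-R1*tr))\n",
    "\n",
    "# analytic derivatives w.r.t. (M0,R1,R2), as in forward_deriv\n",
    "def model_deriv(p):\n",
    "    M0,R1,R2 = p\n",
    "    E1,E2 = np.exp(-R1*tr),np.exp(-R2*te)\n",
    "    return np.array([E2*(1-E1), M0*E2*tr*E1, -M0*te*E2*(1-E1)])\n",
    "\n",
    "sim_data = model([100,1/.8,1/50])  # noise-free data, for this check only\n",
    "\n",
    "def cost(p):\n",
    "    return np.mean((model(p)-sim_data)**2)/2.0\n",
    "\n",
    "def cost_grad(p):\n",
    "    return np.mean(model_deriv(p)*(model(p)-sim_data)[None,:],axis=1)\n",
    "\n",
    "# compare at the initial guess used above\n",
    "p_test  = np.array([200,1/1,1/70])\n",
    "err     = check_grad(cost,cost_grad,p_test)\n",
    "rel_err = err/np.linalg.norm(cost_grad(p_test))\n",
    "print('relative gradient error = {:.1e}'.format(rel_err))\n"
   ]
  },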
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fitted = [1.02576306e+02 1.16153921e+00 1.98044006e-02]\n",
      "true   = [100, 1.25, 0.02]\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# get ready to minimize\n",
    "p0 = [200,1/1,1/70] # same arbitrary initial guess as before\n",
    "method = 'trust-ncg'\n",
    "\n",
    "kw_args = {'x0':p0,'method':method,'jac':cf_grad,'hess':cf_hess}\n",
    "\n",
    "result = minimize(cf,**kw_args)\n",
    "\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(data,'o')\n",
    "plt.plot(forward(result.x))\n",
    "print('fitted = {}'.format(result.x))\n",
    "print('true   = {}'.format(true_p))\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}