{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 223,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x125da2160>]"
      ]
     },
     "execution_count": 223,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fit a model to some data\n",
    "# Model is:\n",
    "#    prediction = M0 * exp(-R2*TE)*(1-exp(-R1*TR))\n",
    "#    where M0,R1,R2 are unknown parameters and TE/TR are experimental parameters\n",
    "# Units: TE is in ms (so R2 is in 1/ms) and TR is in seconds (so R1 is in 1/s).\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.optimize import minimize\n",
    "\n",
    "# Fix the random seed so Restart & Run All reproduces the same noisy data\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "\n",
    "TEs = np.array([10,40,50,60,80]) # TE values in ms\n",
    "TRs = np.array([.8,1,1.5,2])     # TR in seconds (different unit from TE, see note above)\n",
    "\n",
    "# All combinations of TEs/TRs (TE varies slowest, TR fastest)\n",
    "comb    = np.array([(x,y) for x in TEs for y in TRs])\n",
    "TEs,TRs = comb[:,0],comb[:,1]\n",
    "\n",
    "# function for our model\n",
    "def forward(p):\n",
    "    \"\"\"Model prediction for p = (M0, R1, R2) at every (TE, TR) pair.\n",
    "\n",
    "    Uses the module-level TEs/TRs arrays defined above and returns one\n",
    "    predicted signal per (TE, TR) combination.\n",
    "    \"\"\"\n",
    "    M0,R1,R2 = p\n",
    "    return M0*np.exp(-R2*TEs)*(1-np.exp(-R1*TRs))\n",
    "\n",
    "# simulate data using model\n",
    "true_p    = [100,1/.8,1/50]  # [M0, R1 (1/s), R2 (1/ms)]\n",
    "data      = forward(true_p)\n",
    "snr       = 50\n",
    "noise_std = true_p[0]/snr    # noise level relative to the true M0\n",
    "noise     = np.random.randn(data.size)*noise_std\n",
    "data      = data + noise\n",
    "\n",
    "plt.plot(data)\n",
    "plt.xlabel('measurement index (TE-major, TR-minor)')\n",
    "plt.ylabel('signal')\n",
    "plt.title('Simulated noisy data');\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 224,
   "metadata": {},
   "outputs": [],
   "source": [
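    "# Now for the fitting.\n",
    "# The cost function needs the model's first (Jacobian) and second (Hessian)\n",
    "# derivatives with respect to the parameters (M0, R1, R2).\n",
    "\n",
    "# It is always a good idea to calculate the gradient analytically\n",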
    "def forward_deriv(p):\n",
    "    \"\"\"First derivatives of forward(p) w.r.t. (M0, R1, R2).\n",
    "\n",
    "    Returns an array with one row per parameter: row k holds df/dp[k]\n",
    "    evaluated at every (TE, TR) pair in the module-level TEs/TRs.\n",
    "    \"\"\"\n",
    "    M0, R1, R2 = p\n",
    "    E1 = np.exp(-R1*TRs)\n",
    "    E2 = np.exp(-R2*TEs)\n",
    "    one_minus_E1 = 1 - E1\n",
    "\n",
    "    # f = M0*E2*(1-E1), with dE1/dR1 = -TR*E1 and dE2/dR2 = -TE*E2\n",
    "    grad_M0 = E2*one_minus_E1\n",
    "    grad_R1 = M0*E2*(TRs*E1)             # = M0*E2*(-dE1/dR1)\n",
    "    grad_R2 = M0*(-TEs*E2)*one_minus_E1  # = M0*(dE2/dR2)*(1-E1)\n",
    "    return np.array([grad_M0, grad_R1, grad_R2])\n",
    "\n",
    "def forward_deriv2(p):\n",
    "    \"\"\"Second derivatives of forward(p) w.r.t. (M0, R1, R2).\n",
    "\n",
    "    Returns a 3x3 nested array whose [i, j] entry holds d2f/(dp[i] dp[j])\n",
    "    at every (TE, TR) pair in the module-level TEs/TRs. It is symmetric.\n",
    "    \"\"\"\n",
    "    M0, R1, R2 = p\n",
    "    E1 = np.exp(-R1*TRs)\n",
    "    E2 = np.exp(-R2*TEs)\n",
    "    one_minus_E1 = 1 - E1\n",
    "    dE1, dE2 = -TRs*E1, -TEs*E2          # dE1/dR1 and dE2/dR2\n",
    "    d2E1, d2E2 = TRs**2*E1, TEs**2*E2    # d2E1/dR1^2 and d2E2/dR2^2\n",
    "\n",
    "    # f is linear in M0, so d2f/dM0^2 vanishes\n",
    "    h_M0M0 = np.zeros(E1.shape)\n",
    "    h_M0R1 = E2*(-dE1)\n",
    "    h_M0R2 = dE2*one_minus_E1\n",
    "    h_R1R1 = M0*E2*(-d2E1)\n",
    "    h_R1R2 = M0*dE2*(-dE1)\n",
    "    h_R2R2 = M0*d2E2*one_minus_E1\n",
    "\n",
    "    # mixed partials are equal, so the lower triangle reuses the upper one\n",
    "    return np.array([[h_M0M0, h_M0R1, h_M0R2],\n",
    "                     [h_M0R1, h_R1R1, h_R1R2],\n",
    "                     [h_M0R2, h_R1R2, h_R2R2]])\n",
    "\n",
    "\n",
    "# cost function is mean square error divided by 2\n",
    "def cf(p):\n",
    "    \"\"\"Cost: half the mean squared residual between forward(p) and the global data.\"\"\"\n",
    "    residuals = forward(p) - data\n",
    "    return 0.5*np.mean(residuals**2)\n",
    "\n",
    "def cf_grad(p):\n",
    "    \"\"\"Gradient of cf w.r.t. p: mean over measurements of df/dp * residual.\"\"\"\n",
    "    residuals = forward(p) - data\n",
    "    jac = forward_deriv(p)   # one row per parameter, one column per measurement\n",
    "    return (jac*residuals).mean(axis=1)\n",
    "\n",
    "def cf_hess(p):\n",
    "    \"\"\"Hessian of cf w.r.t. p.\n",
    "\n",
    "    Entry [i, j] is mean(d2f/dp_i dp_j * residual + df/dp_i * df/dp_j),\n",
    "    i.e. the full Hessian including the second-derivative term.\n",
    "    \"\"\"\n",
    "    residuals = forward(p) - data\n",
    "    jac = forward_deriv(p)\n",
    "    jac2 = forward_deriv2(p)\n",
    "\n",
    "    n = len(p)\n",
    "    hess = np.empty((n, n))\n",
    "    for i, j in np.ndindex(hess.shape):\n",
    "        hess[i, j] = np.mean(jac2[i, j]*residuals + jac[i]*jac[j])\n",
    "    return hess\n",
    "    \n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fitted = [9.90012687e+01 1.28692542e+00 1.97590584e-02]\n",
      "true   = [100, 1.25, 0.02]\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# get ready to minimize\n",
    "p0 = [200,1/1,1/70] # arbitrary starting guess for [M0, R1, R2], deliberately away from true_p\n",
    "method = 'trust-ncg'  # trust-region Newton-CG: needs both the gradient and the Hessian\n",
    "\n",
    "kw_args = {'x0':p0,'method':method,'jac':cf_grad,'hess':cf_hess}\n",
    "\n",
    "result = minimize(cf,**kw_args)\n",
    "# Don't trust result.x silently: report when the optimizer did not converge\n",
    "if not result.success:\n",
    "    print('WARNING: optimizer did not converge: {}'.format(result.message))\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(data,'o',label='data')\n",
    "plt.plot(forward(result.x),label='fit')\n",
    "plt.xlabel('measurement index (TE-major, TR-minor)')\n",
    "plt.ylabel('signal')\n",
    "plt.legend()\n",
    "print('fitted = {}'.format(result.x))\n",
    "print('true   = {}'.format(true_p))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 200,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}