From 54c1373b51c264aa4d882f0eb2bbd8c938622382 Mon Sep 17 00:00:00 2001
From: Mark Chiew <mchiew@fmrib.ox.ac.uk>
Date: Fri, 6 Mar 2020 23:04:57 +0000
Subject: [PATCH] Updated bloch/partial_fourier, more comments

---
 talks/matlab_vs_python/bloch/bloch.ipynb      | 168 +++++++++++++++---
 talks/matlab_vs_python/bloch/bloch.m          |  91 +++++++---
 talks/matlab_vs_python/bloch/bloch.py         |  40 -----
 .../partial_fourier/partial_fourier.ipynb     | 152 ++++++++++++++--
 .../partial_fourier/partial_fourier.m         |  76 ++++++--
 .../partial_fourier/partial_fourier.py        |  32 ----
 6 files changed, 412 insertions(+), 147 deletions(-)
 delete mode 100644 talks/matlab_vs_python/bloch/bloch.py
 delete mode 100644 talks/matlab_vs_python/partial_fourier/partial_fourier.py

diff --git a/talks/matlab_vs_python/bloch/bloch.ipynb b/talks/matlab_vs_python/bloch/bloch.ipynb
index a96717a..37d4b8c 100644
--- a/talks/matlab_vs_python/bloch/bloch.ipynb
+++ b/talks/matlab_vs_python/bloch/bloch.ipynb
@@ -2,7 +2,9 @@
  "cells": [
   {
    "cell_type": "markdown",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
     "Imports"
    ]
@@ -20,9 +22,17 @@
   },
   {
    "cell_type": "markdown",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "Define bloch and B_eff functions"
+    "# Define the Bloch equation\n",
+    "\n",
+    "$$\\frac{d\\vec{M}}{dt} = \\vec{M}\\times \\vec{B} - \\frac{M_x + M_y}{T2} - \\frac{M_z - M_0}{T1}$$\n",
+    "\n",
+    "In this section:\n",
+    "- define a function\n",
+    "- numpy functions like np.cross"
    ]
   },
   {
@@ -31,30 +41,78 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def bloch_ode(t,M,T1,T2):\n",
+    "# define bloch equation\n",
+    "def bloch_ode(t, M, T1, T2):\n",
+    "    # get effective B-field at time t\n",
     "    B = B_eff(t)\n",
+    "    # cross product of M and B, add T1 and T2 relaxation terms\n",
     "    return np.array([M[1]*B[2] - M[2]*B[1] - M[0]/T2,\n",
     "                     M[2]*B[0] - M[0]*B[2] - M[1]/T2,\n",
     "                     M[0]*B[1] - M[1]*B[0] - (M[2]-1)/T1])\n",
+    "    # alternatively\n",
+    "    #return np.cross(M,B) - np.array([M[0]/T2, M[1]/T2, (M[2]-1)/T1])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define the pulse sequence \n",
+    "\n",
+    "We work in the rotating frame, so we only need the amplitude envelope of the RF pulse\n",
+    "\n",
+    "Typically, B1 excitation fields point in the x- and/or y-directions  \n",
+    "Gradient fields point in the z-direction\n",
+    "\n",
+    "In this simple example, a simple sinc-pulse excitation pulse is applied for 1 ms along the x-axis  \n",
+    "then a gradient is turned on for 1.5 ms after that.\n",
     "\n",
+    "In this section:\n",
+    "- constants such as np.pi\n",
+    "- functions like np.sinc"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define effective B-field\n",
     "def B_eff(t):\n",
+    "    # Do nothing for 0.25 ms\n",
     "    if t < 0.25:\n",
     "        return np.array([0, 0, 0])\n",
+    "    # Sinc RF along x-axis and slice-select gradient on for 1.00 ms\n",
     "    elif t < 1.25:\n",
-    "        return np.array([1.8*np.sinc(t-0.75), 0, 0])\n",
+    "        return np.array([np.pi*np.sinc((t-0.75)*4), 0, np.pi])\n",
+    "    # Do nothing for 0.25 ms\n",
     "    elif t < 1.50:\n",
     "        return np.array([0, 0, 0])\n",
+    "    # Slice refocusing gradient on for 1.50 ms\n",
+    "    # Half the area of the slice-select gradient lobe\n",
     "    elif t < 3.00:\n",
-    "        return np.array([0, 0, 2*np.pi])\n",
+    "        return np.array([0, 0, -(1/3)*np.pi])\n",
+    "    # Pulse sequence finished\n",
     "    else:\n",
     "        return np.array([0, 0, 0])"
    ]
   },
   {
    "cell_type": "markdown",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "Integrate ODE"
+    "# Plot the pulse sequence\n",
+    "\n",
+    "In this section:\n",
+    "- unpacking return values\n",
+    "- unwanted return values\n",
+    "- list comprehension\n",
+    "- basic plotting"
    ]
   },
   {
@@ -63,23 +121,36 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "t = np.array([0])\n",
-    "M = np.array([[0, 0, 1]])\n",
-    "dt= 0.005\n",
-    "r = ode(bloch_ode)\n",
-    "r.set_integrator('dopri5')\n",
-    "r.set_initial_value(M[0],t[0])\n",
-    "r.set_f_params(1500, 50)\n",
-    "while r.successful() and r.t < 5:\n",
-    "    t = np.append(t,r.t+dt)\n",
-    "    M = np.append(M, np.array([r.integrate(t[-1])]),axis=0)"
+    "# Create 2 vertical subplots\n",
+    "_, ax = plt.subplots(2, 1, figsize=(12,12))\n",
+    "\n",
+    "# Get pulse sequence B-fields from 0 - 5 ms\n",
+    "pulseq = [B_eff(t) for t in np.linspace(0,5,1000)]\n",
+    "pulseq = np.array(pulseq)\n",
+    "\n",
+    "# Plot RF\n",
+    "ax[0].plot(pulseq[:,0])\n",
+    "ax[0].set_ylabel('B1')\n",
+    "\n",
+    "# Plot gradient\n",
+    "ax[1].plot(pulseq[:,2])\n",
+    "ax[1].set_ylabel('Gradient')"
    ]
   },
   {
    "cell_type": "markdown",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "Plot Results"
+    "# Integrate ODE  \n",
+    "\n",
+    "This uses a Runge-Kutta variant called the \"Dormand-Prince method\"\n",
+    "\n",
+    "In this section:\n",
+    "- list of arrays\n",
+    "- ode solvers\n",
+    "- list appending"
    ]
   },
   {
@@ -88,11 +159,48 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "_, ax = plt.subplots(figsize=(12,12))\n",
-    "ax.plot(t,M[:,0], label='Mx')\n",
-    "ax.plot(t,M[:,1], label='My')\n",
-    "ax.plot(t,M[:,2], label='Mz')\n",
-    "ax.legend()"
+    "# Set the initial conditions\n",
+    "# time (t) = 0\n",
+    "# equilibrium magnetization (M) = (0, 0, 1)\n",
+    "t = [0]\n",
+    "M = [np.array([0, 0, 1])]\n",
+    "\n",
+    "# Set integrator time-step\n",
+    "dt= 0.005\n",
+    "\n",
+    "# Set up ODE integrator object\n",
+    "r = ode(bloch_ode)\n",
+    "\n",
+    "# Choose the integrator method\n",
+    "r.set_integrator('dopri5')\n",
+    "\n",
+    "# Pass in initial values\n",
+    "r.set_initial_value(M[0], t[0])\n",
+    "\n",
+    "# Set T1 and T2\n",
+    "T1, T2 = 1500, 50\n",
+    "r.set_f_params(T1, T2)\n",
+    "\n",
+    "# Integrate by looping over time, moving dt by step size each iteration\n",
+    "# Append new time point and Magnetisation vector at every step to t and M\n",
+    "while r.successful() and r.t < 5:\n",
+    "    t.append(r.t + dt)\n",
+    "    M.append(r.integrate(t[-1]))\n",
+    "\n",
+    "# Convert M to 2-D numpy array from list of arrays\n",
+    "M = np.array(M)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Plot Results\n",
+    "\n",
+    "In this section:\n",
+    "- more plotting"
    ]
   },
   {
@@ -100,7 +208,19 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "# Create single axis\n",
+    "_, ax = plt.subplots(figsize=(12,12))\n",
+    "\n",
+    "# Plot x, y and z components of Magnetisation\n",
+    "ax.plot(t, M[:,0], label='Mx')\n",
+    "ax.plot(t, M[:,1], label='My')\n",
+    "ax.plot(t, M[:,2], label='Mz')\n",
+    "\n",
+    "# Add legend and grid\n",
+    "ax.legend()\n",
+    "ax.grid()"
+   ]
   }
  ],
  "metadata": {
@@ -119,9 +239,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.3"
+   "version": "3.7.3-final"
   }
  },
  "nbformat": 4,
  "nbformat_minor": 4
-}
+}
\ No newline at end of file
diff --git a/talks/matlab_vs_python/bloch/bloch.m b/talks/matlab_vs_python/bloch/bloch.m
index 324fcc8..64fdc0a 100644
--- a/talks/matlab_vs_python/bloch/bloch.m
+++ b/talks/matlab_vs_python/bloch/bloch.m
@@ -1,32 +1,73 @@
-%% Imports
-
-%% Integrate ODE
-[t, M] = ode45(@(t,M)bloch_ode(t,M,1500,50),linspace(0,5,1000),[0;0;1]);
-
-%% Plot Results
-clf();hold on;
-plot(t,M(:,1));
-plot(t,M(:,2));
-plot(t,M(:,3));
-
-%% Define bloch and b_eff functions
-function dM = bloch_ode(t,M,T1,T2)
-    B   =   B_eff(t);                               % B-effective
-    dM  =  [M(2)*B(3) - M(3)*B(2) - M(1)/T2;        % dMx/dt
-            M(3)*B(1) - M(1)*B(3) - M(2)/T2;        % dMy/dt
-            M(1)*B(2) - M(2)*B(1) - (M(3)-1)/T1];   % dMz/dt
+% Plot the pulse sequence
+% create figure
+figure(); 
+
+% get pulse sequence B-fields from 0-5 ms
+pulseq = zeros(1000,3);
+for i = 1:1000
+    pulseq(i,:) = B_eff(i*0.005); 
 end
 
+% plot RF
+subplot(2,1,1);
+plot(pulseq(:,1));
+ylabel('B1');
+
+% plot gradient
+subplot(2,1,2);
+plot(pulseq(:,3));
+ylabel('Gradient');
+
+% Integrate ODE
+T1 = 1500;
+T2 = 50;
+t0 = 0;
+t1 = 5;
+dt = 0.005;
+M0 = [0; 0; 1];
+[t, M] = ode45(@(t,M)bloch_ode(t, M, T1, T2), linspace(t0, t1, (t1-t0)/dt), M0);
+
+% Plot Results
+% create figure
+figure();hold on;
+
+% plot x, y and z components of Magnetisation
+plot(t, M(:,1));
+plot(t, M(:,2));
+plot(t, M(:,3));
+
+% add legend and grid
+legend({'Mx','My','Mz'});
+grid on;
+
+% define the bloch equation
+function dM = bloch_ode(t, M, T1, T2)
+    % get effective B-field at time t
+    B   =   B_eff(t);                               
+
+    % cross product of M and B, add T1 and T2 relaxation terms
+    dM  =  [M(2)*B(3) - M(3)*B(2) - M(1)/T2;        
+            M(3)*B(1) - M(1)*B(3) - M(2)/T2;        
+            M(1)*B(2) - M(2)*B(1) - (M(3)-1)/T1];   
+end
+
+% define effective B-field
 function b = B_eff(t)
-    if t < 0.25                 % No B-field
+    % Do nothing for 0.25 ms
+    if t < 0.25             
         b = [0, 0, 0];
-    elseif t < 1.25             % 1-ms excitation around x-axis
-        b = [1.8*sinc(t-0.75), 0, 0];
-    elseif t < 1.50             % No B-field
+    % Sinc RF along x-axis and slice-select gradient on for 1.00 ms
+    elseif t < 1.25             
+        b = [pi*sinc((t-0.75)*4), 0, pi];
+    % Do nothing for 0.25 ms
+    elseif t < 1.50             
         b = [0, 0, 0];
-    elseif t < 3.00             % Gradient in y-direction
-        b = [0, 0, 2*pi];
-    else                        % No B-field
+    % Slice refocusing gradient on for 1.50 ms
+    % Half the area of the slice-select gradient lobe
+    elseif t < 3.00             
+        b = [0, 0, -(1/3)*pi];
+    % pulse sequence finished
+    else                       
         b = [0, 0, 0];
     end
-end
\ No newline at end of file
+end
diff --git a/talks/matlab_vs_python/bloch/bloch.py b/talks/matlab_vs_python/bloch/bloch.py
deleted file mode 100644
index 873798d..0000000
--- a/talks/matlab_vs_python/bloch/bloch.py
+++ /dev/null
@@ -1,40 +0,0 @@
-#%% Imports
-import numpy as np
-from scipy.integrate import ode
-import matplotlib.pyplot as plt
-
-#%% Define bloch and B_eff functions
-def bloch_ode(t,M,T1,T2):
-    B = B_eff(t)
-    return np.array([M[1]*B[2] - M[2]*B[1] - M[0]/T2,
-                     M[2]*B[0] - M[0]*B[2] - M[1]/T2,
-                     M[0]*B[1] - M[1]*B[0] - (M[2]-1)/T1])
-
-def B_eff(t):
-    if t < 0.25:
-        return np.array([0, 0, 0])
-    elif t < 1.25:
-        return np.array([1.8*np.sinc(t-0.75), 0, 0])
-    elif t < 1.50:
-        return np.array([0, 0, 0])
-    elif t < 3.00:
-        return np.array([0, 0, 2*np.pi])
-    else:
-        return np.array([0, 0, 0])
-
-#%% Integrate ODE
-t = np.array([0])
-M = np.array([[0, 0, 1]])
-dt= 0.005
-r = ode(bloch_ode)\
-   .set_integrator('dopri5')\
-   .set_initial_value(M[0],t[0])\
-   .set_f_params(1500, 50)
-while r.successful() and r.t < 5:
-    t = np.append(t,r.t+dt)
-    M = np.append(M, np.array([r.integrate(t[-1])]),axis=0)
-
-#%% Plot Results
-plt.plot(t,M[:,0])
-plt.plot(t,M[:,1])
-plt.plot(t,M[:,2])
\ No newline at end of file
diff --git a/talks/matlab_vs_python/partial_fourier/partial_fourier.ipynb b/talks/matlab_vs_python/partial_fourier/partial_fourier.ipynb
index ea2bb3e..0a91899 100644
--- a/talks/matlab_vs_python/partial_fourier/partial_fourier.ipynb
+++ b/talks/matlab_vs_python/partial_fourier/partial_fourier.ipynb
@@ -2,7 +2,9 @@
  "cells": [
   {
    "cell_type": "markdown",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
     "Imports"
    ]
@@ -20,9 +22,20 @@
   },
   {
    "cell_type": "markdown",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "Load Data"
+    "# Load data\n",
+    "\n",
+    "Load complex image data from MATLAB mat-file (v7.3 or later), which is actually an HDF5 format\n",
+    "\n",
+    "Complex data is loaded as a (real, imag) tuple, so it neds to be explicitly converted to complex double\n",
+    "\n",
+    "In this section:\n",
+    "- using h5py module\n",
+    "- np.transpose\n",
+    "- 1j as imaginary constant\n"
    ]
   },
   {
@@ -31,15 +44,33 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# get hdf5 object for the mat-file\n",
     "h = h5py.File('data.mat','r')\n",
-    "img = np.transpose(h.get('img')['real']+1j*h.get('img')['imag'])"
+    "\n",
+    "# get img variable from the mat-file\n",
+    "dat = h.get('img')\n",
+    "\n",
+    "# turn array of (real, imag) tuples into an array of complex doubles\n",
+    "# transpose to keep data in same orientation as MATLAB\n",
+    "img = np.transpose(dat['real'] + 1j*dat['imag'])"
    ]
   },
   {
    "cell_type": "markdown",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "6/8 Partial Fourier sampling"
+    "# 6/8 Partial Fourier sampling\n",
+    "\n",
+    "Fourier transform the image to get k-space data, and add complex Gaussian noise\n",
+    "\n",
+    "To simulate 6/8 Partial Fourier sampling, zero out the bottom 1/4 of k-space\n",
+    "\n",
+    "In this section:\n",
+    "- np.random.randn\n",
+    "- np.fft\n",
+    "- 0-based indexing"
    ]
   },
   {
@@ -48,16 +79,32 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# generate normally-distributed complex noise\n",
     "n = np.random.randn(96,96) + 1j*np.random.randn(96,96)\n",
-    "y = np.fft.fftshift(np.fft.fft2(img),axes=0) + n\n",
+    "\n",
+    "# Fourier transform the image and add noise\n",
+    "y = np.fft.fftshift(np.fft.fft2(img), axes=0) + n\n",
+    "\n",
+    "# set bottom 24/96 lines to 0\n",
     "y[72:,:] = 0"
    ]
   },
   {
    "cell_type": "markdown",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "Estimate phase"
+    "# Estimate phase\n",
+    "\n",
+    "Filter the k-space data and extract a low-resolution phase estimate\n",
+    "\n",
+    "Filtering can help reduce ringing in the phase image\n",
+    "\n",
+    "In this section:\n",
+    "- np.pad\n",
+    "- np.hanning\n",
+    "- reshaping 1D array to 2D array using np.newaxis (or None)"
    ]
   },
   {
@@ -66,15 +113,43 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "pad = np.pad(np.hanning(48),24,'constant')[:,None]\n",
-    "phs = np.exp(1j*np.angle(np.fft.ifft2(np.fft.ifftshift(y*pad,axes=0))))"
+    "# create zero-padded hanning filter for ky-filtering\n",
+    "filt = np.pad(np.hanning(48),24,'constant')\n",
+    "\n",
+    "# reshape 1D array into 2D array\n",
+    "filt = filt[:,np.newaxis]\n",
+    "# or\n",
+    "# filt = filt[:,None]\n",
+    "\n",
+    "# generate low-res image with inverse Fourier transform\n",
+    "low = np.fft.ifft2(np.fft.ifftshift(y*filt, axes=0))\n",
+    "\n",
+    "# get phase image\n",
+    "phs = np.exp(1j*np.angle(low))"
    ]
   },
   {
    "cell_type": "markdown",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "POCS reconstruction"
+    "# POCS reconstruction\n",
+    "\n",
+    "Perform the projection-onto-convex-sets (POCS) partial Fourier reconstruction method.\n",
+    "\n",
+    "POCS is an iterative scheme estimates the reconstructed image as any element in the intersection of the following two (convex) sets:\n",
+    "1. Set of images consistent with the measured data\n",
+    "2. Set of images that are non-negative real\n",
+    "\n",
+    "This requires prior knowledge of the image phase (hence the estimate above), and it works because although we have less than a full k-space of measurements, we're now only estimating half the number of free parameters (real values only, instead of real + imag), and we're no longer under-determined. Equivalently, consider the fact that real-valued images have conjugate symmetric k-spaces, so we only require half of k-space to reconstruct our image.\n",
+    "\n",
+    "In this section:\n",
+    "- np.zeros\n",
+    "- range() builtin\n",
+    "- point-wise multiplication (*)\n",
+    "- np.fft operations default to last axis, not first\n",
+    "- np.maximum vs np.max\n"
    ]
   },
   {
@@ -83,19 +158,52 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# initialise image estimate to be zeros\n",
     "est = np.zeros((96,96))\n",
+    "\n",
+    "# set the number of iterations \n",
     "iters = 10\n",
+    "\n",
+    "# each iteration cycles between projections\n",
     "for i in range(iters):\n",
-    "    est = np.fft.fftshift(np.fft.fft2(est*phs),axes=0)\n",
+    "# projection onto data-consistent set:\n",
+    "    # use a-priori phase to get complex image\n",
+    "    est = est*phs\n",
+    "    \n",
+    "    # Fourier transform to get k-space\n",
+    "    est = np.fft.fftshift(np.fft.fft2(est), axes=0)\n",
+    "    \n",
+    "    # replace data with measured lines\n",
     "    est[:72,:] = y[:72,:]\n",
-    "    est = np.maximum(np.real(np.fft.ifft2(np.fft.ifftshift(est,axes=0))*np.conj(phs)),0)"
+    "    \n",
+    "    # inverse Fourier transform to get back to image space\n",
+    "    est = np.fft.ifft2(np.fft.ifftshift(est, axes=0))\n",
+    "\n",
+    "# projection onto non-negative reals:\n",
+    "    # remove a-priori phase\n",
+    "    est = est*np.conj(phs)\n",
+    "    \n",
+    "    # get real part\n",
+    "    est = np.real(est)\n",
+    "\n",
+    "    # ensure output is non-negative\n",
+    "    est = np.maximum(est, 0)"
    ]
   },
   {
    "cell_type": "markdown",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "Plot reconstruction"
+    "# Display error and plot reconstruction\n",
+    "\n",
+    "The POCS reconstruction is compared to a zero-filled reconstruction (i.e., where the missing data is zeroed prior to inverse Fourier Transform)\n",
+    "\n",
+    "In this section:\n",
+    "- print formatted strings to standard output\n",
+    "- plotting, with min/max scales\n",
+    "- np.sum sums over all elements by default"
    ]
   },
   {
@@ -104,9 +212,25 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# compute zero-filled recon\n",
+    "zf = np.fft.ifft2(np.fft.ifftshift(y, axes=0))\n",
+    "\n",
+    "# compute rmse for zero-filled and POCS recon\n",
+    "err_zf = np.sqrt(np.sum(np.abs(zf - img)**2))\n",
+    "err_pocs = np.sqrt(np.sum(np.abs(est*phs - img)**2))\n",
+    "\n",
+    "# print errors\n",
+    "print(f'RMSE for zero-filled recon: {err_zf}')\n",
+    "print(f'RMSE for POCS recon: {err_pocs}')\n",
+    "\n",
+    "# plot both recons side-by-side\n",
     "_, ax = plt.subplots(1,2,figsize=(16,16))\n",
-    "ax[0].imshow(np.abs(np.fft.ifft2(np.fft.ifftshift(y,axes=0))), vmin=0, vmax=1)\n",
+    "\n",
+    "# plot zero-filled\n",
+    "ax[0].imshow(np.abs(zf), vmin=0, vmax=1)\n",
     "ax[0].set_title('Zero-filled')\n",
+    "\n",
+    "# plot POCS\n",
     "ax[1].imshow(est, vmin=0, vmax=1)\n",
     "ax[1].set_title('POCS recon')"
    ]
@@ -128,9 +252,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.3"
+   "version": "3.7.3-final"
   }
  },
  "nbformat": 4,
  "nbformat_minor": 4
-}
+}
\ No newline at end of file
diff --git a/talks/matlab_vs_python/partial_fourier/partial_fourier.m b/talks/matlab_vs_python/partial_fourier/partial_fourier.m
index b364ab7..81e1cc0 100644
--- a/talks/matlab_vs_python/partial_fourier/partial_fourier.m
+++ b/talks/matlab_vs_python/partial_fourier/partial_fourier.m
@@ -1,32 +1,84 @@
-%% Imports
-
 %% Load Data
+% get matfile object
 h = matfile('data.mat');
+
+% get img variable from the mat-file
 img = h.img;
 
 %% 6/8 Partial Fourier sampling
+% generate normally-distributed complex noise
 n = randn(96) + 1j*randn(96);
-y = fftshift(fft2(img),1) + 0*n;
-%y(73:end,:) = 0;
+
+% Fourier transform the image and add noise
+y = fftshift(fft2(img),1) + n;
+
+% set bottom 24/96 lines to 0
+y(73:end,:) = 0;
 
 %% Estimate phase
-pad = padarray(hann(48),24);
-phs = exp(1j*angle(ifft2(ifftshift(y.*pad,1))));
+% create zero-padded hanning filter for ky-filtering
+filt = padarray(hann(48),24);
+
+% generate low-res image with inverse Fourier transform
+low = ifft2(ifftshift(y.*filt,1));
+
+% get phase image
+phs = exp(1j*angle(low));
 
 %% POCS reconstruction
+% initialise image estimate to be zeros
 est = zeros(96);
+
+% set number of iterations
 iters = 10;
+
+% each iteration cycles between projections
 for i = 1:iters
-    est = fftshift(fft2(est.*phs),1);
+% projection onto data-consistent set:
+    % use a-priori phase to get complex image
+    est = est.*phs;
+
+    % Fourier transform to get k-space
+    est = fftshift(fft2(est), 1);
+
+    % replace data with measured lines
     est(1:72,:) = y(1:72,:);
-    est = max(real(ifft2(ifftshift(est,1)).*conj(phs)),0);   
+
+    % inverse Fourier transform to get back to image space
+    est = ifft2(ifftshift(est, 1));
+
+% projection onto non-negative reals
+    % remove a-priori phase
+    est = est.*conj(phs);
+
+    % get real part
+    est = real(est);
+
+    % ensure output is non-negative
+    est = max(est, 0);
 end
 
-%% Plot reconstruction
+%% Display error and plot reconstruction
+% compute zero-filled recon
+zf = ifft2(ifftshift(y, 1));
+
+% compute rmse for zero-filled and POCS recon
+err_zf = norm(zf(:) - img(:));
+err_pocs = norm(est(:).*phs(:) - img(:));
+
+% print errors
+fprintf(1, 'RMSE for zero-filled recon: %f\n', err_zf);
+fprintf(1, 'RMSE for POCS recon: %f\n', err_pocs);
+
+% plot both recons side-by-side
 figure();
+
+% plot zero-filled
 subplot(1,2,1);
-imshow(abs(ifft2(ifftshift(y,1))),[0 1],'colormap',jet);
+imshow(abs(zf), [0 1]);
 title('Zero-Filled');
+
+% plot POCS
 subplot(1,2,2);
-imshow(est,[0 1],'colormap',jet);
-title('POCS recon');
\ No newline at end of file
+imshow(est, [0 1]);
+title('POCS recon');
diff --git a/talks/matlab_vs_python/partial_fourier/partial_fourier.py b/talks/matlab_vs_python/partial_fourier/partial_fourier.py
deleted file mode 100644
index 2ceb7ec..0000000
--- a/talks/matlab_vs_python/partial_fourier/partial_fourier.py
+++ /dev/null
@@ -1,32 +0,0 @@
-#%% Imports
-import h5py
-import matplotlib.pyplot as plt
-import numpy as np
-
-#%% Load Data
-h = h5py.File('data.mat','r')
-img = np.transpose(h.get('img')['real']+1j*h.get('img')['imag'])
-
-#%% 6/8 Partial Fourier sampling
-n = np.random.randn(96,96) + 1j*np.random.randn(96,96)
-y = np.fft.fftshift(np.fft.fft2(img),axes=0) + n
-y[72:,:] = 0
-
-#%% Estimate phase
-pad = np.pad(np.hanning(48),24,'constant')[:,None]
-phs = np.exp(1j*np.angle(np.fft.ifft2(np.fft.ifftshift(y*pad,axes=0))))
-
-#%% POCS reconstruction
-est = np.zeros((96,96))
-iters = 10
-for i in range(iters):
-    est = np.fft.fftshift(np.fft.fft2(est*phs),axes=0)
-    est[:72,:] = y[:72,:]
-    est = np.maximum(np.real(np.fft.ifft2(np.fft.ifftshift(est,axes=0))*np.conj(phs)),0)
-
-#%% Plot reconstruction
-_, ax = plt.subplots(1,2)
-ax[0].imshow(np.abs(np.fft.ifft2(np.fft.ifftshift(y,axes=0))),vmin=0,vmax=1)
-ax[0].set_title('Zero-filled')
-ax[1].imshow(est, vmin=0, vmax=1)
-ax[1].set_title('POCS recon')
\ No newline at end of file
-- 
GitLab