Commit c131e3e7 authored by Paul McCarthy's avatar Paul McCarthy 🚵
Browse files

Merge branch 'master' into 'master'

Made MATLAB consistent with python, added plotting to partial fourier

See merge request fsl/pytreat-practicals-2020!27
parents 1c0fd89d a47d89cc
...@@ -18,7 +18,7 @@ subplot(2,1,2); ...@@ -18,7 +18,7 @@ subplot(2,1,2);
plot(pulseq(:,3)); plot(pulseq(:,3));
ylabel('Gradient'); ylabel('Gradient');
% Integrate ODE %% Integrate ODE
T1 = 1500; T1 = 1500;
T2 = 50; T2 = 50;
t0 = 0; t0 = 0;
...@@ -27,20 +27,20 @@ dt = 0.005; ...@@ -27,20 +27,20 @@ dt = 0.005;
M0 = [0; 0; 1]; M0 = [0; 0; 1];
[t, M] = ode45(@(t,M)bloch_ode(t, M, T1, T2), linspace(t0, t1, (t1-t0)/dt), M0); [t, M] = ode45(@(t,M)bloch_ode(t, M, T1, T2), linspace(t0, t1, (t1-t0)/dt), M0);
% Plot Results %% Plot Results
% create figure % create figure
figure();hold on; figure();hold on;
% plot x, y and z components of Magnetisation % plot x, y and z components of Magnetisation
plot(t, M(:,1)); plot(t, M(:,1), 'linewidth', 2);
plot(t, M(:,2)); plot(t, M(:,2), 'linewidth', 2);
plot(t, M(:,3)); plot(t, M(:,3), 'linewidth', 2);
% add legend and grid % add legend and grid
legend({'Mx','My','Mz'}); legend({'Mx','My','Mz'});
grid on; grid on;
% define the bloch equation %% define the bloch equation
function dM = bloch_ode(t, M, T1, T2) function dM = bloch_ode(t, M, T1, T2)
% get effective B-field at time t % get effective B-field at time t
B = B_eff(t); B = B_eff(t);
...@@ -51,7 +51,7 @@ function dM = bloch_ode(t, M, T1, T2) ...@@ -51,7 +51,7 @@ function dM = bloch_ode(t, M, T1, T2)
M(1)*B(2) - M(2)*B(1) - (M(3)-1)/T1]; M(1)*B(2) - M(2)*B(1) - (M(3)-1)/T1];
end end
% define effective B-field %% define effective B-field
function b = B_eff(t) function b = B_eff(t)
% Do nothing for 0.25 ms % Do nothing for 0.25 ms
if t < 0.25 if t < 0.25
......
...@@ -64,7 +64,8 @@ ...@@ -64,7 +64,8 @@
"In this section:\n", "In this section:\n",
"- np.random.randn\n", "- np.random.randn\n",
"- np.fft\n", "- np.fft\n",
"- 0-based indexing" "- 0-based indexing\n",
"- image plotting"
] ]
}, },
{ {
...@@ -80,7 +81,11 @@ ...@@ -80,7 +81,11 @@
"y = np.fft.fftshift(np.fft.fft2(img), axes=0) + n\n", "y = np.fft.fftshift(np.fft.fft2(img), axes=0) + n\n",
"\n", "\n",
"# set bottom 24/96 lines to 0\n", "# set bottom 24/96 lines to 0\n",
"y[72:,:] = 0" "y[72:,:] = 0\n",
"\n",
"# show sampling\n",
"_, ax = plt.subplots()\n",
"ax.imshow(np.log(np.abs(np.fft.fftshift(y, axes=1))))"
] ]
}, },
{ {
...@@ -96,7 +101,8 @@ ...@@ -96,7 +101,8 @@
"In this section:\n", "In this section:\n",
"- np.pad\n", "- np.pad\n",
"- np.hanning\n", "- np.hanning\n",
"- reshaping 1D array to 2D array using np.newaxis (or None)" "- reshaping 1D array to 2D array using np.newaxis (or None)\n",
"- subplots with titles"
] ]
}, },
{ {
...@@ -117,7 +123,14 @@ ...@@ -117,7 +123,14 @@
"low = np.fft.ifft2(np.fft.ifftshift(y*filt, axes=0))\n", "low = np.fft.ifft2(np.fft.ifftshift(y*filt, axes=0))\n",
"\n", "\n",
"# get phase image\n", "# get phase image\n",
"phs = np.exp(1j*np.angle(low))" "phs = np.exp(1j*np.angle(low))\n",
"\n",
"# show phase estimate alongside true phase\n",
"_, ax = plt.subplots(1,2)\n",
"ax[0].imshow(np.angle(img))\n",
"ax[0].set_title('True image phase')\n",
"ax[1].imshow(np.angle(phs))\n",
"ax[1].set_title('Estimated phase')"
] ]
}, },
{ {
...@@ -190,7 +203,7 @@ ...@@ -190,7 +203,7 @@
"\n", "\n",
"In this section:\n", "In this section:\n",
"- print formatted strings to standard output\n", "- print formatted strings to standard output\n",
"- plotting, with min/max scales\n", "- 2D subplots with min/max scales, figure size\n",
"- np.sum sums over all elements by default" "- np.sum sums over all elements by default"
] ]
}, },
...@@ -212,15 +225,17 @@ ...@@ -212,15 +225,17 @@
"print(f'RMSE for POCS recon: {err_pocs}')\n", "print(f'RMSE for POCS recon: {err_pocs}')\n",
"\n", "\n",
"# plot both recons side-by-side\n", "# plot both recons side-by-side\n",
"_, ax = plt.subplots(1,2,figsize=(16,16))\n", "_, ax = plt.subplots(2,2,figsize=(16,16))\n",
"\n", "\n",
"# plot zero-filled\n", "# plot zero-filled\n",
"ax[0].imshow(np.abs(zf), vmin=0, vmax=1)\n", "ax[0,0].imshow(np.abs(zf), vmin=0, vmax=1)\n",
"ax[0].set_title('Zero-filled')\n", "ax[0,0].set_title('Zero-filled')\n",
"ax[1,0].plot(np.abs(zf[:,47]))\n",
"\n", "\n",
"# plot POCS\n", "# plot POCS\n",
"ax[1].imshow(est, vmin=0, vmax=1)\n", "ax[0,1].imshow(est, vmin=0, vmax=1)\n",
"ax[1].set_title('POCS recon')" "ax[0,1].set_title('POCS recon')\n",
"ax[1,1].plot(np.abs(est[:,47]))"
] ]
} }
], ],
......
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
Imports Imports
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import h5py import h5py
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Load data # Load data
Load complex image data from MATLAB mat-file (v7.3 or later), which is actually an HDF5 format Load complex image data from MATLAB mat-file (v7.3 or later), which is actually an HDF5 format
Complex data is loaded as a (real, imag) tuple, so it neds to be explicitly converted to complex double Complex data is loaded as a (real, imag) tuple, so it neds to be explicitly converted to complex double
In this section: In this section:
- using h5py module - using h5py module
- np.transpose - np.transpose
- 1j as imaginary constant - 1j as imaginary constant
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# get hdf5 object for the mat-file # get hdf5 object for the mat-file
h = h5py.File('data.mat','r') h = h5py.File('data.mat','r')
# get img variable from the mat-file # get img variable from the mat-file
dat = h.get('img') dat = h.get('img')
# turn array of (real, imag) tuples into an array of complex doubles # turn array of (real, imag) tuples into an array of complex doubles
# transpose to keep data in same orientation as MATLAB # transpose to keep data in same orientation as MATLAB
img = np.transpose(dat['real'] + 1j*dat['imag']) img = np.transpose(dat['real'] + 1j*dat['imag'])
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# 6/8 Partial Fourier sampling # 6/8 Partial Fourier sampling
Fourier transform the image to get k-space data, and add complex Gaussian noise Fourier transform the image to get k-space data, and add complex Gaussian noise
To simulate 6/8 Partial Fourier sampling, zero out the bottom 1/4 of k-space To simulate 6/8 Partial Fourier sampling, zero out the bottom 1/4 of k-space
In this section: In this section:
- np.random.randn - np.random.randn
- np.fft - np.fft
- 0-based indexing - 0-based indexing
- image plotting
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# generate normally-distributed complex noise # generate normally-distributed complex noise
n = np.random.randn(96,96) + 1j*np.random.randn(96,96) n = np.random.randn(96,96) + 1j*np.random.randn(96,96)
# Fourier transform the image and add noise # Fourier transform the image and add noise
y = np.fft.fftshift(np.fft.fft2(img), axes=0) + n y = np.fft.fftshift(np.fft.fft2(img), axes=0) + n
# set bottom 24/96 lines to 0 # set bottom 24/96 lines to 0
y[72:,:] = 0 y[72:,:] = 0
# show sampling
_, ax = plt.subplots()
ax.imshow(np.log(np.abs(np.fft.fftshift(y, axes=1))))
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Estimate phase # Estimate phase
Filter the k-space data and extract a low-resolution phase estimate Filter the k-space data and extract a low-resolution phase estimate
Filtering can help reduce ringing in the phase image Filtering can help reduce ringing in the phase image
In this section: In this section:
- np.pad - np.pad
- np.hanning - np.hanning
- reshaping 1D array to 2D array using np.newaxis (or None) - reshaping 1D array to 2D array using np.newaxis (or None)
- subplots with titles
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# create zero-padded hanning filter for ky-filtering # create zero-padded hanning filter for ky-filtering
filt = np.pad(np.hanning(48),24,'constant') filt = np.pad(np.hanning(48),24,'constant')
# reshape 1D array into 2D array # reshape 1D array into 2D array
filt = filt[:,np.newaxis] filt = filt[:,np.newaxis]
# or # or
# filt = filt[:,None] # filt = filt[:,None]
# generate low-res image with inverse Fourier transform # generate low-res image with inverse Fourier transform
low = np.fft.ifft2(np.fft.ifftshift(y*filt, axes=0)) low = np.fft.ifft2(np.fft.ifftshift(y*filt, axes=0))
# get phase image # get phase image
phs = np.exp(1j*np.angle(low)) phs = np.exp(1j*np.angle(low))
# show phase estimate alongside true phase
_, ax = plt.subplots(1,2)
ax[0].imshow(np.angle(img))
ax[0].set_title('True image phase')
ax[1].imshow(np.angle(phs))
ax[1].set_title('Estimated phase')
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# POCS reconstruction # POCS reconstruction
Perform the projection-onto-convex-sets (POCS) partial Fourier reconstruction method. Perform the projection-onto-convex-sets (POCS) partial Fourier reconstruction method.
POCS is an iterative scheme estimates the reconstructed image as any element in the intersection of the following two (convex) sets: POCS is an iterative scheme estimates the reconstructed image as any element in the intersection of the following two (convex) sets:
1. Set of images consistent with the measured data 1. Set of images consistent with the measured data
2. Set of images that are non-negative real 2. Set of images that are non-negative real
This requires prior knowledge of the image phase (hence the estimate above), and it works because although we have less than a full k-space of measurements, we're now only estimating half the number of free parameters (real values only, instead of real + imag), and we're no longer under-determined. Equivalently, consider the fact that real-valued images have conjugate symmetric k-spaces, so we only require half of k-space to reconstruct our image. This requires prior knowledge of the image phase (hence the estimate above), and it works because although we have less than a full k-space of measurements, we're now only estimating half the number of free parameters (real values only, instead of real + imag), and we're no longer under-determined. Equivalently, consider the fact that real-valued images have conjugate symmetric k-spaces, so we only require half of k-space to reconstruct our image.
In this section: In this section:
- np.zeros - np.zeros
- range() builtin - range() builtin
- point-wise multiplication (*) - point-wise multiplication (*)
- np.fft operations default to last axis, not first - np.fft operations default to last axis, not first
- np.maximum vs np.max - np.maximum vs np.max
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# initialise image estimate to be zeros # initialise image estimate to be zeros
est = np.zeros((96,96)) est = np.zeros((96,96))
# set the number of iterations # set the number of iterations
iters = 10 iters = 10
# each iteration cycles between projections # each iteration cycles between projections
for i in range(iters): for i in range(iters):
# projection onto data-consistent set: # projection onto data-consistent set:
# use a-priori phase to get complex image # use a-priori phase to get complex image
est = est*phs est = est*phs
# Fourier transform to get k-space # Fourier transform to get k-space
est = np.fft.fftshift(np.fft.fft2(est), axes=0) est = np.fft.fftshift(np.fft.fft2(est), axes=0)
# replace data with measured lines # replace data with measured lines
est[:72,:] = y[:72,:] est[:72,:] = y[:72,:]
# inverse Fourier transform to get back to image space # inverse Fourier transform to get back to image space
est = np.fft.ifft2(np.fft.ifftshift(est, axes=0)) est = np.fft.ifft2(np.fft.ifftshift(est, axes=0))
# projection onto non-negative reals: # projection onto non-negative reals:
# remove a-priori phase # remove a-priori phase
est = est*np.conj(phs) est = est*np.conj(phs)
# get real part # get real part
est = np.real(est) est = np.real(est)
# ensure output is non-negative # ensure output is non-negative
est = np.maximum(est, 0) est = np.maximum(est, 0)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Display error and plot reconstruction # Display error and plot reconstruction
The POCS reconstruction is compared to a zero-filled reconstruction (i.e., where the missing data is zeroed prior to inverse Fourier Transform) The POCS reconstruction is compared to a zero-filled reconstruction (i.e., where the missing data is zeroed prior to inverse Fourier Transform)
In this section: In this section:
- print formatted strings to standard output - print formatted strings to standard output
- plotting, with min/max scales - 2D subplots with min/max scales, figure size
- np.sum sums over all elements by default - np.sum sums over all elements by default
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# compute zero-filled recon # compute zero-filled recon
zf = np.fft.ifft2(np.fft.ifftshift(y, axes=0)) zf = np.fft.ifft2(np.fft.ifftshift(y, axes=0))
# compute rmse for zero-filled and POCS recon # compute rmse for zero-filled and POCS recon
err_zf = np.sqrt(np.sum(np.abs(zf - img)**2)) err_zf = np.sqrt(np.sum(np.abs(zf - img)**2))
err_pocs = np.sqrt(np.sum(np.abs(est*phs - img)**2)) err_pocs = np.sqrt(np.sum(np.abs(est*phs - img)**2))
# print errors # print errors
print(f'RMSE for zero-filled recon: {err_zf}') print(f'RMSE for zero-filled recon: {err_zf}')
print(f'RMSE for POCS recon: {err_pocs}') print(f'RMSE for POCS recon: {err_pocs}')
# plot both recons side-by-side # plot both recons side-by-side
_, ax = plt.subplots(1,2,figsize=(16,16)) _, ax = plt.subplots(2,2,figsize=(16,16))
# plot zero-filled # plot zero-filled
ax[0].imshow(np.abs(zf), vmin=0, vmax=1) ax[0,0].imshow(np.abs(zf), vmin=0, vmax=1)
ax[0].set_title('Zero-filled') ax[0,0].set_title('Zero-filled')
ax[1,0].plot(np.abs(zf[:,47]))
# plot POCS # plot POCS
ax[1].imshow(est, vmin=0, vmax=1) ax[0,1].imshow(est, vmin=0, vmax=1)
ax[1].set_title('POCS recon') ax[0,1].set_title('POCS recon')
ax[1,1].plot(np.abs(est[:,47]))
``` ```
......
...@@ -15,6 +15,10 @@ y = fftshift(fft2(img),1) + n; ...@@ -15,6 +15,10 @@ y = fftshift(fft2(img),1) + n;
% set bottom 24/96 lines to 0 % set bottom 24/96 lines to 0
y(73:end,:) = 0; y(73:end,:) = 0;
% show sampling
figure();
imshow(log(abs(fftshift(y,2))), [], 'colormap', jet)
%% Estimate phase %% Estimate phase
% create zero-padded hanning filter for ky-filtering % create zero-padded hanning filter for ky-filtering
filt = padarray(hann(48),24); filt = padarray(hann(48),24);
...@@ -25,6 +29,16 @@ low = ifft2(ifftshift(y.*filt,1)); ...@@ -25,6 +29,16 @@ low = ifft2(ifftshift(y.*filt,1));
% get phase image % get phase image
phs = exp(1j*angle(low)); phs = exp(1j*angle(low));
% show phase estimate alongside true phase
figure();
subplot(1,2,1);
imshow(angle(img), [-pi,pi], 'colormap', hsv);
title('True image phase');
subplot(1,2,2);
imshow(angle(phs), [-pi,pi], 'colormap', hsv)
title('Estimated phase');
%% POCS reconstruction %% POCS reconstruction
% initialise image estimate to be zeros % initialise image estimate to be zeros
est = zeros(96); est = zeros(96);
...@@ -74,11 +88,15 @@ fprintf(1, 'RMSE for POCS recon: %f\n', err_pocs); ...@@ -74,11 +88,15 @@ fprintf(1, 'RMSE for POCS recon: %f\n', err_pocs);
figure(); figure();
% plot zero-filled % plot zero-filled
subplot(1,2,1); subplot(2,2,1);
imshow(abs(zf), [0 1]); imshow(abs(zf), [0 1]);
title('Zero-Filled'); title('Zero-Filled');
subplot(2,2,3);
plot(abs(zf(:,48)), 'linewidth', 2);
% plot POCS % plot POCS
subplot(1,2,2); subplot(2,2,2);
imshow(est, [0 1]); imshow(est, [0 1]);
title('POCS recon'); title('POCS recon');
subplot(2,2,4);
plot(abs(est(:,48)), 'linewidth', 2);
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment