Skip to content
Snippets Groups Projects
Commit c131e3e7 authored by Paul McCarthy's avatar Paul McCarthy :mountain_bicyclist:
Browse files

Merge branch 'master' into 'master'

Made MATLAB consistent with python, added plotting to partial fourier

See merge request !27
parents 1c0fd89d a47d89cc
No related branches found
No related tags found
No related merge requests found
...@@ -18,7 +18,7 @@ subplot(2,1,2); ...@@ -18,7 +18,7 @@ subplot(2,1,2);
plot(pulseq(:,3)); plot(pulseq(:,3));
ylabel('Gradient'); ylabel('Gradient');
% Integrate ODE %% Integrate ODE
T1 = 1500; T1 = 1500;
T2 = 50; T2 = 50;
t0 = 0; t0 = 0;
...@@ -27,20 +27,20 @@ dt = 0.005; ...@@ -27,20 +27,20 @@ dt = 0.005;
M0 = [0; 0; 1]; M0 = [0; 0; 1];
[t, M] = ode45(@(t,M)bloch_ode(t, M, T1, T2), linspace(t0, t1, (t1-t0)/dt), M0); [t, M] = ode45(@(t,M)bloch_ode(t, M, T1, T2), linspace(t0, t1, (t1-t0)/dt), M0);
% Plot Results %% Plot Results
% create figure % create figure
figure();hold on; figure();hold on;
% plot x, y and z components of Magnetisation % plot x, y and z components of Magnetisation
plot(t, M(:,1)); plot(t, M(:,1), 'linewidth', 2);
plot(t, M(:,2)); plot(t, M(:,2), 'linewidth', 2);
plot(t, M(:,3)); plot(t, M(:,3), 'linewidth', 2);
% add legend and grid % add legend and grid
legend({'Mx','My','Mz'}); legend({'Mx','My','Mz'});
grid on; grid on;
% define the bloch equation %% define the bloch equation
function dM = bloch_ode(t, M, T1, T2) function dM = bloch_ode(t, M, T1, T2)
% get effective B-field at time t % get effective B-field at time t
B = B_eff(t); B = B_eff(t);
...@@ -51,7 +51,7 @@ function dM = bloch_ode(t, M, T1, T2) ...@@ -51,7 +51,7 @@ function dM = bloch_ode(t, M, T1, T2)
M(1)*B(2) - M(2)*B(1) - (M(3)-1)/T1]; M(1)*B(2) - M(2)*B(1) - (M(3)-1)/T1];
end end
% define effective B-field %% define effective B-field
function b = B_eff(t) function b = B_eff(t)
% Do nothing for 0.25 ms % Do nothing for 0.25 ms
if t < 0.25 if t < 0.25
......
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
Imports Imports
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import h5py import h5py
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Load data # Load data
Load complex image data from MATLAB mat-file (v7.3 or later), which is actually an HDF5 format Load complex image data from MATLAB mat-file (v7.3 or later), which is actually an HDF5 format
Complex data is loaded as a (real, imag) tuple, so it neds to be explicitly converted to complex double Complex data is loaded as a (real, imag) tuple, so it neds to be explicitly converted to complex double
In this section: In this section:
- using h5py module - using h5py module
- np.transpose - np.transpose
- 1j as imaginary constant - 1j as imaginary constant
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# get hdf5 object for the mat-file # get hdf5 object for the mat-file
h = h5py.File('data.mat','r') h = h5py.File('data.mat','r')
# get img variable from the mat-file # get img variable from the mat-file
dat = h.get('img') dat = h.get('img')
# turn array of (real, imag) tuples into an array of complex doubles # turn array of (real, imag) tuples into an array of complex doubles
# transpose to keep data in same orientation as MATLAB # transpose to keep data in same orientation as MATLAB
img = np.transpose(dat['real'] + 1j*dat['imag']) img = np.transpose(dat['real'] + 1j*dat['imag'])
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# 6/8 Partial Fourier sampling # 6/8 Partial Fourier sampling
Fourier transform the image to get k-space data, and add complex Gaussian noise Fourier transform the image to get k-space data, and add complex Gaussian noise
To simulate 6/8 Partial Fourier sampling, zero out the bottom 1/4 of k-space To simulate 6/8 Partial Fourier sampling, zero out the bottom 1/4 of k-space
In this section: In this section:
- np.random.randn - np.random.randn
- np.fft - np.fft
- 0-based indexing - 0-based indexing
- image plotting
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# generate normally-distributed complex noise # generate normally-distributed complex noise
n = np.random.randn(96,96) + 1j*np.random.randn(96,96) n = np.random.randn(96,96) + 1j*np.random.randn(96,96)
# Fourier transform the image and add noise # Fourier transform the image and add noise
y = np.fft.fftshift(np.fft.fft2(img), axes=0) + n y = np.fft.fftshift(np.fft.fft2(img), axes=0) + n
# set bottom 24/96 lines to 0 # set bottom 24/96 lines to 0
y[72:,:] = 0 y[72:,:] = 0
# show sampling
_, ax = plt.subplots()
ax.imshow(np.log(np.abs(np.fft.fftshift(y, axes=1))))
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Estimate phase # Estimate phase
Filter the k-space data and extract a low-resolution phase estimate Filter the k-space data and extract a low-resolution phase estimate
Filtering can help reduce ringing in the phase image Filtering can help reduce ringing in the phase image
In this section: In this section:
- np.pad - np.pad
- np.hanning - np.hanning
- reshaping 1D array to 2D array using np.newaxis (or None) - reshaping 1D array to 2D array using np.newaxis (or None)
- subplots with titles
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# create zero-padded hanning filter for ky-filtering # create zero-padded hanning filter for ky-filtering
filt = np.pad(np.hanning(48),24,'constant') filt = np.pad(np.hanning(48),24,'constant')
# reshape 1D array into 2D array # reshape 1D array into 2D array
filt = filt[:,np.newaxis] filt = filt[:,np.newaxis]
# or # or
# filt = filt[:,None] # filt = filt[:,None]
# generate low-res image with inverse Fourier transform # generate low-res image with inverse Fourier transform
low = np.fft.ifft2(np.fft.ifftshift(y*filt, axes=0)) low = np.fft.ifft2(np.fft.ifftshift(y*filt, axes=0))
# get phase image # get phase image
phs = np.exp(1j*np.angle(low)) phs = np.exp(1j*np.angle(low))
# show phase estimate alongside true phase
_, ax = plt.subplots(1,2)
ax[0].imshow(np.angle(img))
ax[0].set_title('True image phase')
ax[1].imshow(np.angle(phs))
ax[1].set_title('Estimated phase')
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# POCS reconstruction # POCS reconstruction
Perform the projection-onto-convex-sets (POCS) partial Fourier reconstruction method. Perform the projection-onto-convex-sets (POCS) partial Fourier reconstruction method.
POCS is an iterative scheme estimates the reconstructed image as any element in the intersection of the following two (convex) sets: POCS is an iterative scheme estimates the reconstructed image as any element in the intersection of the following two (convex) sets:
1. Set of images consistent with the measured data 1. Set of images consistent with the measured data
2. Set of images that are non-negative real 2. Set of images that are non-negative real
This requires prior knowledge of the image phase (hence the estimate above), and it works because although we have less than a full k-space of measurements, we're now only estimating half the number of free parameters (real values only, instead of real + imag), and we're no longer under-determined. Equivalently, consider the fact that real-valued images have conjugate symmetric k-spaces, so we only require half of k-space to reconstruct our image. This requires prior knowledge of the image phase (hence the estimate above), and it works because although we have less than a full k-space of measurements, we're now only estimating half the number of free parameters (real values only, instead of real + imag), and we're no longer under-determined. Equivalently, consider the fact that real-valued images have conjugate symmetric k-spaces, so we only require half of k-space to reconstruct our image.
In this section: In this section:
- np.zeros - np.zeros
- range() builtin - range() builtin
- point-wise multiplication (*) - point-wise multiplication (*)
- np.fft operations default to last axis, not first - np.fft operations default to last axis, not first
- np.maximum vs np.max - np.maximum vs np.max
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# initialise image estimate to be zeros # initialise image estimate to be zeros
est = np.zeros((96,96)) est = np.zeros((96,96))
# set the number of iterations # set the number of iterations
iters = 10 iters = 10
# each iteration cycles between projections # each iteration cycles between projections
for i in range(iters): for i in range(iters):
# projection onto data-consistent set: # projection onto data-consistent set:
# use a-priori phase to get complex image # use a-priori phase to get complex image
est = est*phs est = est*phs
# Fourier transform to get k-space # Fourier transform to get k-space
est = np.fft.fftshift(np.fft.fft2(est), axes=0) est = np.fft.fftshift(np.fft.fft2(est), axes=0)
# replace data with measured lines # replace data with measured lines
est[:72,:] = y[:72,:] est[:72,:] = y[:72,:]
# inverse Fourier transform to get back to image space # inverse Fourier transform to get back to image space
est = np.fft.ifft2(np.fft.ifftshift(est, axes=0)) est = np.fft.ifft2(np.fft.ifftshift(est, axes=0))
# projection onto non-negative reals: # projection onto non-negative reals:
# remove a-priori phase # remove a-priori phase
est = est*np.conj(phs) est = est*np.conj(phs)
# get real part # get real part
est = np.real(est) est = np.real(est)
# ensure output is non-negative # ensure output is non-negative
est = np.maximum(est, 0) est = np.maximum(est, 0)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Display error and plot reconstruction # Display error and plot reconstruction
The POCS reconstruction is compared to a zero-filled reconstruction (i.e., where the missing data is zeroed prior to inverse Fourier Transform) The POCS reconstruction is compared to a zero-filled reconstruction (i.e., where the missing data is zeroed prior to inverse Fourier Transform)
In this section: In this section:
- print formatted strings to standard output - print formatted strings to standard output
- plotting, with min/max scales - 2D subplots with min/max scales, figure size
- np.sum sums over all elements by default - np.sum sums over all elements by default
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# compute zero-filled recon # compute zero-filled recon
zf = np.fft.ifft2(np.fft.ifftshift(y, axes=0)) zf = np.fft.ifft2(np.fft.ifftshift(y, axes=0))
# compute rmse for zero-filled and POCS recon # compute rmse for zero-filled and POCS recon
err_zf = np.sqrt(np.sum(np.abs(zf - img)**2)) err_zf = np.sqrt(np.sum(np.abs(zf - img)**2))
err_pocs = np.sqrt(np.sum(np.abs(est*phs - img)**2)) err_pocs = np.sqrt(np.sum(np.abs(est*phs - img)**2))
# print errors # print errors
print(f'RMSE for zero-filled recon: {err_zf}') print(f'RMSE for zero-filled recon: {err_zf}')
print(f'RMSE for POCS recon: {err_pocs}') print(f'RMSE for POCS recon: {err_pocs}')
# plot both recons side-by-side # plot both recons side-by-side
_, ax = plt.subplots(1,2,figsize=(16,16)) _, ax = plt.subplots(2,2,figsize=(16,16))
# plot zero-filled # plot zero-filled
ax[0].imshow(np.abs(zf), vmin=0, vmax=1) ax[0,0].imshow(np.abs(zf), vmin=0, vmax=1)
ax[0].set_title('Zero-filled') ax[0,0].set_title('Zero-filled')
ax[1,0].plot(np.abs(zf[:,47]))
# plot POCS # plot POCS
ax[1].imshow(est, vmin=0, vmax=1) ax[0,1].imshow(est, vmin=0, vmax=1)
ax[1].set_title('POCS recon') ax[0,1].set_title('POCS recon')
ax[1,1].plot(np.abs(est[:,47]))
``` ```
......
...@@ -15,6 +15,10 @@ y = fftshift(fft2(img),1) + n; ...@@ -15,6 +15,10 @@ y = fftshift(fft2(img),1) + n;
% set bottom 24/96 lines to 0 % set bottom 24/96 lines to 0
y(73:end,:) = 0; y(73:end,:) = 0;
% show sampling
figure();
imshow(log(abs(fftshift(y,2))), [], 'colormap', jet)
%% Estimate phase %% Estimate phase
% create zero-padded hanning filter for ky-filtering % create zero-padded hanning filter for ky-filtering
filt = padarray(hann(48),24); filt = padarray(hann(48),24);
...@@ -25,6 +29,16 @@ low = ifft2(ifftshift(y.*filt,1)); ...@@ -25,6 +29,16 @@ low = ifft2(ifftshift(y.*filt,1));
% get phase image % get phase image
phs = exp(1j*angle(low)); phs = exp(1j*angle(low));
% show phase estimate alongside true phase
figure();
subplot(1,2,1);
imshow(angle(img), [-pi,pi], 'colormap', hsv);
title('True image phase');
subplot(1,2,2);
imshow(angle(phs), [-pi,pi], 'colormap', hsv)
title('Estimated phase');
%% POCS reconstruction %% POCS reconstruction
% initialise image estimate to be zeros % initialise image estimate to be zeros
est = zeros(96); est = zeros(96);
...@@ -74,11 +88,15 @@ fprintf(1, 'RMSE for POCS recon: %f\n', err_pocs); ...@@ -74,11 +88,15 @@ fprintf(1, 'RMSE for POCS recon: %f\n', err_pocs);
figure(); figure();
% plot zero-filled % plot zero-filled
subplot(1,2,1); subplot(2,2,1);
imshow(abs(zf), [0 1]); imshow(abs(zf), [0 1]);
title('Zero-Filled'); title('Zero-Filled');
subplot(2,2,3);
plot(abs(zf(:,48)), 'linewidth', 2);
% plot POCS % plot POCS
subplot(1,2,2); subplot(2,2,2);
imshow(est, [0 1]); imshow(est, [0 1]);
title('POCS recon'); title('POCS recon');
subplot(2,2,4);
plot(abs(est(:,48)), 'linewidth', 2);
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment