Mark Chiew authored
Imports
import h5py
import matplotlib.pyplot as plt
import numpy as np
Load data
Load complex image data from a MATLAB mat-file (v7.3 or later), which is stored in HDF5 format
Complex data is loaded as an array of (real, imag) pairs, so it needs to be explicitly converted to complex double
In this section:
- using h5py module
- np.transpose
- 1j as imaginary constant
# get hdf5 object for the mat-file
h = h5py.File('data.mat','r')
# get img variable from the mat-file
dat = h.get('img')
# turn array of (real, imag) tuples into an array of complex doubles
# transpose to keep data in same orientation as MATLAB
img = np.transpose(dat['real'] + 1j*dat['imag'])
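The conversion can be tried without data.mat. The following is a minimal sketch that assumes the dataset behaves like a NumPy structured array with 'real' and 'imag' fields, mirroring the field access used above. The name fake_dat and its values are made up for illustration.

```python
import numpy as np

# Hypothetical stand-in for the dataset read from the mat-file:
# a 2x3 structured array with 'real' and 'imag' fields
fake_dat = np.zeros((2, 3), dtype=[('real', 'f8'), ('imag', 'f8')])
fake_dat['real'] = [[1, 2, 3], [4, 5, 6]]
fake_dat['imag'] = [[-1, -2, -3], [-4, -5, -6]]

# combine the two fields into complex doubles, then transpose
img_demo = np.transpose(fake_dat['real'] + 1j*fake_dat['imag'])

print(img_demo.dtype)   # complex128
print(img_demo.shape)   # (3, 2)
```

Multiplying the imaginary field by 1j promotes the result to complex128, and the transpose swaps the two axes.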
6/8 Partial Fourier sampling
Fourier transform the image to get k-space data, and add complex Gaussian noise
To simulate 6/8 Partial Fourier sampling, zero out the bottom 1/4 of k-space
In this section:
- np.random.randn
- np.fft
- 0-based indexing
# generate normally-distributed complex noise
n = np.random.randn(96,96) + 1j*np.random.randn(96,96)
# Fourier transform the image and add noise
y = np.fft.fftshift(np.fft.fft2(img), axes=0) + n
# set bottom 24/96 lines to 0
y[72:,:] = 0
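The line counts used above follow from the partial Fourier factor: 6/8 of 96 lines are kept and the remaining 24 are zeroed. The sketch below works through that arithmetic and checks where the zero-frequency line ends up after the shift along axis 0. The names n_lines, fraction, mask and dc_row are illustrative and not part of the original code.

```python
import numpy as np

n_lines = 96          # number of ky lines, matching the 96x96 image
fraction = 6 / 8      # partial Fourier factor

# number of acquired lines, and number of lines set to zero
n_keep = int(round(fraction * n_lines))
n_zero = n_lines - n_keep
print(n_keep, n_zero)  # 72 24

# with fftshift along axis 0, the zero-frequency (DC) line moves
# from row 0 to row n_lines // 2
k = np.fft.fftshift(np.fft.fft2(np.ones((n_lines, n_lines))), axes=0)
dc_row = int(np.argmax(np.abs(k[:, 0])))
print(dc_row)  # 48

# boolean mask of acquired rows: rows 0..71 kept, rows 72..95 zeroed
mask = np.zeros(n_lines, dtype=bool)
mask[:n_keep] = True
```

Because the DC line sits at row 48, the acquired rows 0..71 cover the centre of k-space plus all of one side, and the zeroed rows 72..95 are the outer part of the other side.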
Estimate phase
Filter the k-space data and extract a low-resolution phase estimate
Filtering can help reduce ringing in the phase image
In this section:
- np.pad
- np.hanning
- reshaping 1D array to 2D array using np.newaxis (or None)
# create zero-padded hanning filter for ky-filtering
filt = np.pad(np.hanning(48),24,'constant')
# reshape 1D array into 2D array
filt = filt[:,np.newaxis]
# or
# filt = filt[:,None]
# generate low-res image with inverse Fourier transform
low = np.fft.ifft2(np.fft.ifftshift(y*filt, axes=0))
# get phase image
phs = np.exp(1j*np.angle(low))
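The array shapes in the filtering step can be checked in isolation. This is a small sketch on placeholder data: y_demo is a hypothetical stand-in for the k-space array, and the other names are chosen for illustration only.

```python
import numpy as np

# 48-point Hanning window, zero-padded by 24 samples on each side
filt_1d = np.pad(np.hanning(48), 24, 'constant')
print(filt_1d.shape)   # (96,)

# add a trailing axis so the window varies along ky (axis 0)
filt_2d = filt_1d[:, np.newaxis]
print(filt_2d.shape)   # (96, 1)

# broadcasting applies the same ky weight to every column
y_demo = np.ones((96, 96), dtype=complex)   # hypothetical k-space stand-in
weighted = y_demo * filt_2d
print(weighted.shape)  # (96, 96)

# the window is zero outside rows 24..71, so the zeroed rows 72..95
# of the partial Fourier data do not contribute to the phase estimate
print(np.count_nonzero(filt_1d[:24]), np.count_nonzero(filt_1d[72:]))  # 0 0
```

The window only passes rows for which the mirrored row was also acquired, which is why the resulting phase estimate is low-resolution.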
POCS reconstruction
Perform the projection-onto-convex-sets (POCS) partial Fourier reconstruction method.
POCS is an iterative scheme that estimates the reconstructed image as any element in the intersection of the following two (convex) sets:
- Set of images consistent with the measured data
- Set of images that are non-negative real once the image phase is removed
This requires prior knowledge of the image phase (hence the estimate above). It works because, although we have less than a full k-space of measurements, we are now estimating only half the number of free parameters (real values only, instead of real + imag), so the problem is no longer under-determined. Equivalently, real-valued images have conjugate-symmetric k-spaces, so only half of k-space is required to reconstruct the image.
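The conjugate symmetry property can be checked numerically. The sketch below uses a small random real-valued array as a hypothetical test image and compares its 2D Fourier transform at -k with the complex conjugate of its value at k.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((8, 8))      # hypothetical real-valued test image
F = np.fft.fft2(x)

# build F(-kx, -ky): reverse both axes, then roll by one sample so that
# index 0 (zero frequency) stays in place
F_reflected = np.roll(np.flip(F), shift=(1, 1), axis=(0, 1))

# for a real image, F(-k) equals the complex conjugate of F(k)
print(np.allclose(F_reflected, np.conj(F)))  # True
```

Each sample in one half of k-space therefore determines its mirror sample in the other half, provided the image is real-valued.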
In this section:
- np.zeros
- range() builtin
- point-wise multiplication (*)
- np.fft.fft operates on the last axis by default, not the first (np.fft.fft2 uses the last two axes)
- np.maximum vs np.max
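The last two items in the list can be illustrated on small arrays. The arrays a and b below are made-up examples.

```python
import numpy as np

a = np.array([[-1.0, 2.0], [3.0, -4.0]])

# np.maximum compares element-wise (here against 0) and keeps the shape
clipped = np.maximum(a, 0)
print(clipped.tolist())   # [[0.0, 2.0], [3.0, 0.0]]

# np.max reduces the whole array to its single largest value
largest = np.max(a)
print(largest)            # 3.0

# np.fft.fft transforms along the last axis unless axis is given
b = np.arange(6.0).reshape(2, 3)
same = np.allclose(np.fft.fft(b), np.fft.fft(b, axis=-1))
print(same)               # True
```

For the non-negativity projection, the element-wise np.maximum is the one needed, since the output must have the same shape as the image.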
# initialise image estimate to be zeros
est = np.zeros((96,96))
# set the number of iterations
iters = 10
# each iteration cycles between projections
for i in range(iters):
    # projection onto data-consistent set:
    # use a-priori phase to get complex image
    est = est*phs
    # Fourier transform to get k-space
    est = np.fft.fftshift(np.fft.fft2(est), axes=0)
    # replace data with measured lines
    est[:72,:] = y[:72,:]
    # inverse Fourier transform to get back to image space
    est = np.fft.ifft2(np.fft.ifftshift(est, axes=0))
    # projection onto non-negative reals:
    # remove a-priori phase
    est = est*np.conj(phs)
    # get real part
    est = np.real(est)
    # ensure output is non-negative
    est = np.maximum(est, 0)
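The loop can also be exercised on synthetic data, without data.mat. The following is a sketch under simplifying assumptions rather than the notebook's own pipeline: the function name pocs_demo is hypothetical, the test object is a random non-negative magnitude with a made-up smooth phase, no noise is added, and the true phase is passed in directly instead of a low-resolution estimate.

```python
import numpy as np

def pocs_demo(y, phs, n_meas, iters=10):
    """Sketch of the POCS loop. y is k-space shifted along axis 0,
    with the first n_meas rows measured; phs is a unit-magnitude phase."""
    est = np.zeros(y.shape)
    for _ in range(iters):
        # projection onto the data-consistent set
        k = np.fft.fftshift(np.fft.fft2(est * phs), axes=0)
        k[:n_meas, :] = y[:n_meas, :]
        est = np.fft.ifft2(np.fft.ifftshift(k, axes=0))
        # projection onto non-negative reals after removing the phase
        est = np.maximum(np.real(est * np.conj(phs)), 0)
    return est

# synthetic object: non-negative magnitude with a smooth phase (assumed)
rng = np.random.default_rng(0)
mag = rng.random((96, 96))
r = np.linspace(-1, 1, 96)
true_phs = np.exp(1j * np.pi / 4 * np.add.outer(r, r))
img_demo = mag * true_phs

# noise-free 6/8 partial Fourier data
y_demo = np.fft.fftshift(np.fft.fft2(img_demo), axes=0)
y_demo[72:, :] = 0

# zero-filled and POCS reconstructions
zf = np.fft.ifft2(np.fft.ifftshift(y_demo, axes=0))
est = pocs_demo(y_demo, true_phs, n_meas=72)

err_zf = np.linalg.norm(zf - img_demo)
err_pocs = np.linalg.norm(est * true_phs - img_demo)
print(err_pocs < err_zf)
```

In this idealised setting the true image lies in both sets, so each projection cannot move the estimate further from it, and the POCS error should come out below the zero-filled error. With noisy data or an imperfect phase estimate that guarantee no longer holds exactly.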
Display error and plot reconstruction
The POCS reconstruction is compared to a zero-filled reconstruction (i.e., where the missing data is zeroed prior to the inverse Fourier transform)
In this section:
- print formatted strings to standard output
- plotting, with min/max scales
- np.sum sums over all elements by default
# compute zero-filled recon
zf = np.fft.ifft2(np.fft.ifftshift(y, axes=0))
# compute rmse for zero-filled and POCS recon
# (divide the summed squared error by the number of pixels to get the mean)
err_zf = np.sqrt(np.sum(np.abs(zf - img)**2) / img.size)
err_pocs = np.sqrt(np.sum(np.abs(est*phs - img)**2) / img.size)
# print errors
print(f'RMSE for zero-filled recon: {err_zf}')
print(f'RMSE for POCS recon: {err_pocs}')
# plot both recons side-by-side
_, ax = plt.subplots(1,2,figsize=(16,16))
# plot zero-filled
ax[0].imshow(np.abs(zf), vmin=0, vmax=1)
ax[0].set_title('Zero-filled')
# plot POCS
ax[1].imshow(est, vmin=0, vmax=1)
ax[1].set_title('POCS recon')
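For reference, a root-mean-square error divides the summed squared error by the number of elements before taking the square root. The helper below is a hypothetical sketch (the name rmse and the two-element arrays are illustrative) that contrasts it with the root of the plain sum.

```python
import numpy as np

def rmse(a, b):
    """Root-mean-square error between two (possibly complex) arrays."""
    return np.sqrt(np.mean(np.abs(a - b)**2))

a = np.array([3.0 + 4.0j, 0.0])
b = np.zeros(2)

# mean of squared magnitudes is 25/2, so the RMSE is sqrt(12.5)
err_mean = rmse(a, b)
print(err_mean)

# root of the summed squared error, without dividing by the element count
err_sum = np.sqrt(np.sum(np.abs(a - b)**2))
print(err_sum)   # 5.0
```

The two measures differ by a factor of the square root of the number of elements, so they rank reconstructions of the same size identically.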