partial_fourier.ipynb

Imports

import h5py
import matplotlib.pyplot as plt
import numpy as np

Load data

Load complex image data from a MATLAB mat-file (v7.3 or later), which is actually an HDF5 file

Complex data is loaded as an array of (real, imag) records, so it needs to be explicitly converted to complex double

In this section:

  • using h5py module
  • np.transpose
  • 1j as imaginary constant
# get hdf5 object for the mat-file
h = h5py.File('data.mat','r')

# get img variable from the mat-file
dat = h.get('img')

# turn array of (real, imag) tuples into an array of complex doubles
# transpose to keep data in same orientation as MATLAB
img = np.transpose(dat['real'] + 1j*dat['imag'])
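The (real, imag) conversion can be tried without data.mat. This is a minimal sketch that assumes, as the cell above does, that h5py returns a structured array with fields named real and imag; compound and dat are hypothetical stand-ins built directly in NumPy. It also shows a view-based alternative, which is only valid because both fields are contiguous float64 values.

```python
import numpy as np

# Hypothetical stand-in for a complex MATLAB v7.3 variable read through h5py:
# a structured array with 'real' and 'imag' fields, with dimensions stored in
# reverse order (MATLAB is column-major, NumPy's default layout is row-major).
compound = np.dtype([('real', '<f8'), ('imag', '<f8')])
dat = np.zeros((3, 2), dtype=compound)       # MATLAB would report this as 2x3
dat['real'] = np.arange(6.0).reshape(3, 2)
dat['imag'] = -np.arange(6.0).reshape(3, 2)

# combine the fields into complex128 and undo the dimension reversal
img = np.transpose(dat['real'] + 1j * dat['imag'])
print(img.dtype, img.shape)                  # complex128 (2, 3)

# alternative: each 16-byte record is (float64, float64), the same layout as
# complex128, so the records can be reinterpreted without copying
img_view = dat.view(np.complex128).T
print(np.array_equal(img, img_view))         # True
```

The view is a NumPy array method, so with real data it would need an in-memory array (for example h['img'][()]) rather than the h5py dataset object.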

6/8 Partial Fourier sampling

Fourier transform the image to get k-space data, and add complex Gaussian noise

To simulate 6/8 partial Fourier sampling, zero out the bottom 1/4 of k-space (the last 24 of the 96 ky lines)

In this section:

  • np.random.randn
  • np.fft
  • 0-based indexing
# generate normally-distributed complex noise
n = np.random.randn(96,96) + 1j*np.random.randn(96,96)

# Fourier transform the image and add noise
y = np.fft.fftshift(np.fft.fft2(img), axes=0) + n

# set bottom 24/96 lines to 0
y[72:,:] = 0
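To make the 0-based indexing concrete, this sketch maps each row of the fftshift-ed k-space to its signed ky index for the notebook's 96 x 96 matrix (N and n_acq are names introduced here for illustration). It shows that y[:72] keeps ky = -48..23, so the acquired data extend 24 lines past the centre, and that shifting only axes=0 leaves the kx columns unshifted.

```python
import numpy as np

N = 96       # matrix size used in the notebook
n_acq = 72   # acquired ky rows; rows 72..95 are zeroed

# signed ky index of each row before shifting: 0..47, then -48..-1
k = np.round(np.fft.fftfreq(N) * N).astype(int)
# after fftshift the rows run ky = -48..47, so DC (ky = 0) sits at row N // 2
ky = np.fft.fftshift(k)
print(ky[N // 2])                             # 0

# with 0-based slicing, y[:72] keeps ky = -48..23 and y[72:] drops ky = 24..47
print(ky[:n_acq].min(), ky[:n_acq].max())     # -48 23
print(ky[n_acq:].min(), ky[n_acq:].max())     # 24 47
print(n_acq / N)                              # 0.75

# check with a constant image: its largest coefficient is DC, which lands at
# row 48, column 0 (columns stay unshifted because only axes=0 is passed)
K = np.fft.fftshift(np.fft.fft2(np.ones((N, N))), axes=0)
dc = np.unravel_index(np.argmax(np.abs(K)), K.shape)
```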

Estimate phase

Filter the k-space data and extract a low-resolution phase estimate

Filtering with a smooth window (here, a Hann window along ky) can help reduce ringing in the phase image

In this section:

  • np.pad
  • np.hanning
  • reshaping 1D array to 2D array using np.newaxis (or None)
# create zero-padded hanning filter for ky-filtering
filt = np.pad(np.hanning(48),24,'constant')

# reshape 1D array into 2D array
filt = filt[:,np.newaxis]
# or
# filt = filt[:,None]

# generate low-res image with inverse Fourier transform
low = np.fft.ifft2(np.fft.ifftshift(y*filt, axes=0))

# get phase image
phs = np.exp(1j*np.angle(low))
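The filter construction and the broadcasting step can be checked in isolation. This sketch reuses the notebook's filter definition but replaces y with a hypothetical synthetic k-space (true_img is a smooth magnitude with a linear phase, not the notebook's data). It confirms that the padded window is zero on the unacquired rows 72..95, that a (96, 1) column scales every column of the k-space, and that phs has unit magnitude.

```python
import numpy as np

N = 96

# zero-padded Hann window over the central 48 ky rows, as in the notebook
filt = np.pad(np.hanning(48), 24, 'constant')[:, np.newaxis]   # shape (96, 1)

# rows 0..23 and 72..95 are pure padding, so unacquired rows get zero weight
print(filt.shape)                                              # (96, 1)
print(np.count_nonzero(filt[:24]), np.count_nonzero(filt[72:]))  # 0 0

# hypothetical stand-in for y: smooth magnitude times a linear phase, 6/8 sampled
x = np.linspace(-1, 1, N)
true_img = np.exp(-(x[:, None]**2 + x[None, :]**2) / 0.2) * np.exp(1j * np.pi * x[None, :])
y = np.fft.fftshift(np.fft.fft2(true_img), axes=0)
y[72:, :] = 0

# broadcasting: the (96, 1) column multiplies every column of the (96, 96) k-space
filtered = y * filt
print(filtered.shape)                                          # (96, 96)
low = np.fft.ifft2(np.fft.ifftshift(filtered, axes=0))
phs = np.exp(1j * np.angle(low))
print(np.allclose(np.abs(phs), 1.0))                           # True
```

Because the window only spans rows 24..71 (ky = -24..23), the phase estimate is built from the symmetrically sampled centre of k-space and never sees the zeroed lines.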

POCS reconstruction

Perform the projection-onto-convex-sets (POCS) partial Fourier reconstruction method.

POCS is an iterative scheme that estimates the reconstructed image as an element in the intersection of the following two (convex) sets:

  1. Set of images consistent with the measured data
  2. Set of images that, after removing the estimated phase, are real and non-negative

This requires prior knowledge of the image phase (hence the estimate above). It works because, although we have less than a full k-space of measurements, fixing the phase means we only estimate half as many free parameters (real values only, instead of real + imaginary), so the problem is no longer under-determined. Equivalently, real-valued images have conjugate-symmetric k-space, S(-k) = S*(k), so only half of k-space is needed to reconstruct the image. The lines acquired beyond half of k-space give a symmetrically sampled central region (48 lines here), which is what the low-resolution phase estimate is built from.
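The conjugate-symmetry argument can be verified numerically. This sketch uses an arbitrary real 8 x 8 test image (not the notebook's data) and a helper flip_k, introduced here, that maps each unshifted FFT index k to (-k) mod N. It checks X[k] = conj(X[-k]) and then rebuilds three discarded ky rows from their conjugate partners, recovering the image exactly from 5 of 8 rows (rows 0 and 4, DC and Nyquist, are their own partners).

```python
import numpy as np

def flip_k(A):
    """Return B with B[k1, k2] = A[(-k1) % N1, (-k2) % N2] (unshifted FFT indexing)."""
    return np.roll(A[::-1, ::-1], 1, axis=(0, 1))

rng = np.random.default_rng(0)
x = rng.random((8, 8))        # arbitrary real-valued test image
X = np.fft.fft2(x)

# conjugate (Hermitian) symmetry of the k-space of a real image: X[k] = conj(X[-k])
print(np.allclose(X, np.conj(flip_k(X))))                       # True

# discard rows 5..7 (ky = -3..-1), then refill them from their partners, rows 3..1
Y = X.copy()
Y[5:, :] = 0
Y[5:, :] = np.conj(flip_k(Y))[5:, :]
x_rec = np.fft.ifft2(Y)
print(np.allclose(x_rec.real, x), np.allclose(x_rec.imag, 0))   # True True
```

With a complex image this exact refill does not apply, which is why POCS first removes the estimated phase so the remaining unknown is real.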

In this section:

  • np.zeros
  • range() builtin
  • point-wise multiplication (*)
  • np.fft.fft operates on the last axis by default (MATLAB's fft operates on the first), and np.fft.fftshift shifts all axes unless axes is given
  • np.maximum vs np.max
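The axis defaults and the np.maximum / np.max distinction listed above can be seen on a tiny array. This sketch uses a made-up 3 x 4 array; the np.roll comparisons rely on fftshift rolling each axis by n // 2.

```python
import numpy as np

x = np.arange(12.0).reshape(3, 4)

# np.fft.fft works along the last axis by default (MATLAB's fft works down columns)
print(np.allclose(np.fft.fft(x), np.fft.fft(x, axis=-1)))   # True
print(np.allclose(np.fft.fft(x), np.fft.fft(x, axis=0)))    # False

# np.fft.fftshift shifts every axis by default; axes=0 restricts it to rows,
# which is how the notebook centres ky only
print(np.array_equal(np.fft.fftshift(x), np.roll(x, (1, 2), axis=(0, 1))))  # True
print(np.array_equal(np.fft.fftshift(x, axes=0), np.roll(x, 1, axis=0)))    # True

# np.maximum is element-wise (here: clip negatives to 0); np.max reduces to one value
a = np.array([-1.5, 0.0, 2.0])
print(np.maximum(a, 0))   # [0. 0. 2.]
print(np.max(a))          # 2.0
```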
# initialise image estimate to be zeros
est = np.zeros((96,96))

# set the number of iterations
iters = 10

# each iteration cycles between projections
for i in range(iters):
    # projection onto data-consistent set:
    # use a-priori phase to get complex image
    est = est*phs

    # Fourier transform to get k-space
    est = np.fft.fftshift(np.fft.fft2(est), axes=0)

    # replace data with measured lines
    est[:72,:] = y[:72,:]

    # inverse Fourier transform to get back to image space
    est = np.fft.ifft2(np.fft.ifftshift(est, axes=0))

    # projection onto non-negative reals:
    # remove a-priori phase
    est = est*np.conj(phs)

    # get real part
    est = np.real(est)

    # ensure output is non-negative
    est = np.maximum(est, 0)
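The same loop can be packaged as a reusable function and tested on noise-free synthetic data. In this sketch, pocs_pf, mag, phs and the 32 x 32 size are hypothetical choices made for illustration, and the phase is assumed known exactly. Because the true image lies in both convex sets, each projection cannot move the estimate further from it, so the POCS error should come out below the zero-filled error.

```python
import numpy as np

def pocs_pf(y, phs, n_acq, iters=10):
    """POCS partial Fourier reconstruction mirroring the notebook's loop.

    y     : k-space, fftshift-ed along axis 0, with rows n_acq: unacquired
    phs   : unit-magnitude phase estimate, same shape as the image
    n_acq : number of acquired ky rows (acquired rows are y[:n_acq])
    """
    est = np.zeros(y.shape)
    for _ in range(iters):
        # projection onto data-consistent set: keep acquired rows, fill the rest from est
        k = np.fft.fftshift(np.fft.fft2(est * phs), axes=0)
        k[:n_acq, :] = y[:n_acq, :]
        est = np.fft.ifft2(np.fft.ifftshift(k, axes=0))
        # projection onto non-negative reals after removing the phase
        est = np.maximum(np.real(est * np.conj(phs)), 0)
    return est

# hypothetical noise-free test: non-negative magnitude times a known smooth phase
N, n_acq = 32, 24                       # 6/8 of 32 rows acquired
x = np.linspace(-1, 1, N)
mag = np.exp(-(x[:, None]**2 + x[None, :]**2) / 0.1)
phs = np.exp(1j * 0.8 * np.pi * x[:, None])
img = mag * phs

y = np.fft.fftshift(np.fft.fft2(img), axes=0)
y[n_acq:, :] = 0

zf = np.fft.ifft2(np.fft.ifftshift(y, axes=0))
est = pocs_pf(y, phs, n_acq)

err_zf = np.linalg.norm(zf - img)
err_pocs = np.linalg.norm(est * phs - img)
print(err_pocs < err_zf)   # True
```

With an estimated rather than exact phase, and with noise, the true image is no longer guaranteed to lie in both sets, so this monotone behaviour should be read as a sanity check rather than a general guarantee.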

Display error and plot reconstruction

The POCS reconstruction is compared to a zero-filled reconstruction (i.e., one where the missing k-space lines are set to zero before the inverse Fourier transform)

In this section:

  • print formatted strings to standard output
  • plotting, with min/max scales
  • np.sum sums over all elements by default
# compute zero-filled recon
zf = np.fft.ifft2(np.fft.ifftshift(y, axes=0))

# compute RMSE for zero-filled and POCS recon (square root of the mean squared error over all pixels)
err_zf = np.sqrt(np.sum(np.abs(zf - img)**2) / img.size)
err_pocs = np.sqrt(np.sum(np.abs(est*phs - img)**2) / img.size)

# print errors
print(f'RMSE for zero-filled recon: {err_zf}')
print(f'RMSE for POCS recon: {err_pocs}')

# plot both recons side-by-side
_, ax = plt.subplots(1,2,figsize=(16,8))

# plot zero-filled
ax[0].imshow(np.abs(zf), vmin=0, vmax=1)
ax[0].set_title('Zero-filled')

# plot POCS
ax[1].imshow(est, vmin=0, vmax=1)
ax[1].set_title('POCS recon')
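The error metric in the cell above is built from np.sum, which reduces over all elements when no axis is given. This small sketch, on a hypothetical 2 x 2 error image, shows that default and how a root-sum-square (the Frobenius norm) differs from a root-mean-square error by a factor of sqrt(number of pixels).

```python
import numpy as np

# hypothetical 2x2 error image: one pixel off by 3+4j (magnitude 5), the rest exact
err = np.array([[3 + 4j, 0], [0, 0]])

sq = np.abs(err) ** 2             # squared magnitudes: [[25, 0], [0, 0]]
print(np.sum(sq))                 # 25.0  (no axis given: sums over all elements)
print(np.sum(sq, axis=0))         # column sums only

rss = np.sqrt(np.sum(sq))         # root-sum-square, equal to np.linalg.norm(err)
rmse = np.sqrt(np.mean(sq))       # root-mean-square: rss / sqrt(number of pixels)
print(rss, rmse)                  # 5.0 2.5
```

Either metric ranks the two reconstructions the same way on a fixed image size; the mean-based version is the one whose value does not grow with the number of pixels.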