Commit 47a21f26 authored by ihuszar

Slice-to-volume protocol for BigMac dataset.

parent 7a44410b
#!/usr/bin/env python
import tirl
import numpy as np
from tirl.costs.mind import CostMIND
from skimage.segmentation import slic
from skimage.future import graph
from skimage.exposure import rescale_intensity
import matplotlib
matplotlib.use("macosx")
import matplotlib.pyplot as plt
from skimage.measure import regionprops
N_POINTS = 16
MASK_UTHR = 1
MASK_LTHR = 0.1
img = tirl.load("/Users/inhuszar/bigmac/reg/stage2.timg")
vol = tirl.load("/Users/inhuszar/bigmac/reg/vol.timg")
mri = vol.evaluate(img.domain)
mask = np.logical_and(img.data > MASK_LTHR, img.data <= MASK_UTHR)
cost = CostMIND(mri, img, precompute=False, normalise=True)
cm = cost.costmap().reduce_tensors(copy=True)
im = rescale_intensity(cm.data, out_range=np.uint8).astype(np.float64) \
* mask.astype(np.int64)
test_cases = (1,) #2 ** np.arange(2, 8)
shape_measure = []
distance_measure = []
for n_segments in test_cases:
    # Segment the cost function map given the number of segments
    im = rescale_intensity(cm.data, out_range=np.uint8).astype(np.float64) \
        * mask.astype(np.int64)
    # NOTE: the segment count is fixed at 150 here; the loop variable
    # n_segments is unused while the sweep above is commented out.
    segments = slic(im, n_segments=150, sigma=5, compactness=50) \
        * mask.astype(np.int64)
    g = graph.rag_mean_color(cm.data, segments)
    segments = graph.cut_normalized(segments, g, thresh=0.1, in_place=False)
    N = np.unique(segments).size - 1
    print(N)
    regions = regionprops(segments, cm.data)
    regions = sorted(regions, key=lambda r: np.std(r.intensity_image))[::-1]
    # Measure the average segment size (mean equivalent diameter)
    convexity = np.mean([r.equivalent_diameter for r in regions])
    shape_measure.append((N, convexity))
    # Estimate the characteristic spacing between segment centroids from the
    # extent of their bounding box and the number of segments
    centroids = [r.centroid for r in regions]
    xi = np.stack(centroids, axis=0)
    edges = np.max(xi, axis=0) - np.min(xi, axis=0)
    edges = edges[np.nonzero(edges)]
    epsilon = np.power(np.prod(edges) / len(regions), 1. / edges.size)
    distance_measure.append((N, epsilon))
    # Create a new figure for the current segment map
    mapfig, mapax = plt.subplots(1, 1)
    label = np.zeros_like(segments)
    im = mapax.imshow(label, vmin=segments.min(), vmax=segments.max(), zorder=1)
    for region in regions:
        label[segments == region.label] = region.label
        im.set_data(label)
        y, x = region.centroid
        mapax.plot(x, y, "ro", zorder=2)
        mapfig.canvas.draw_idle()
        plt.pause(0.25)
    else:
        pass
    plt.pause(1)
else:
    # Plot the measured properties versus the number of segments
    measfig, measax = plt.subplots(1, 1)
    measax.plot(*tuple(zip(*shape_measure)), "b-")
    measax.plot(*tuple(zip(*distance_measure)), "r-")
    plt.show()
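# Hypothetical follow-up (not part of the committed script): the segment
# centroids computed above are the natural candidates for warp control
# points. A minimal sketch of converting them into the (n_points, 2) integer
# voxel-coordinate array that the registration pipeline expects:
#
#     control_points = np.round(
#         np.stack([r.centroid for r in regions], axis=0)).astype(int)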
@@ -5,14 +5,14 @@
"author": "Istvan N Huszar, Amy FD Howard"
},
"slice": {
"file": "/Users/inhuszar/bigmac/histo/H092x/mosaic_colour_mrires.tif",
"alternative": "/Users/inhuszar/bigmac/reg/stage4.timg",
"file": "/Users/inhuszar/bigmac/histo/H092x/mosaic_colour.tif",
"alternative": null,
"storage": "mem",
"mask": {
"file": null,
"normalise": false
},
"resolution": 0.3,
"resolution": 0.0045,
"preview": false,
"actions": ["resample_image", "lab_b"]
},
@@ -30,14 +30,13 @@
},
"general": {
"verbosity": "debug",
"system": "macos",
"system": "macosx",
"logfile": "/Users/inhuszar/bigmac/reg/tirl_slice_to_volume.log",
"outdir": "/Users/inhuszar/bigmac/reg",
"stages": [5, 2, 3, 4],
"stage_index": 8,
"stages": [1, 2, 3, 4, 5, 2, 3, 4],
"stage_index": 3,
"snapshot_ext": "png",
"randomseed": 1,
"warnings": false
"warnings": false
},
"regparams": {
"init": {
@@ -56,16 +55,22 @@
"visualise": true,
"verbose": 4,
"slab": {
"centre": [0, -4.3, 0],
"centre": [0, -4, 0],
"normal": [0, 1, 0],
"thickness": 1.5,
"inits": 5
"thickness": 3,
"inits": 7
},
"iterations": 1,
"iterations": 2,
"scaling": [8, 4, 4, 2, 2, 1, 1],
"smoothing": [0, 1, 0, 1, 0, 1, 0],
"constrained": true,
"opt_step": 0.1,
"stage_1a": {
"xtol_abs": [0.01, 0.01, 0.01, 0.01, 0.01]
},
"stage_1b": {
"xtol_abs": [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]
},
"x0": {
"scale2d": [1.0, 1.0],
"rot2d": -90,
@@ -77,16 +82,16 @@
"dx0": {
"scale2d_lower_delta": [0.2, 0.2],
"scale2d_upper_delta": [0.2, 0.2],
"rot2d_lower_delta": 30.0,
"rot2d_upper_delta": 30.0,
"trans2d_lower_delta": [10.0, 10.0],
"trans2d_upper_delta": [10.0, 10.0],
"rot2d_lower_delta": 10.0,
"rot2d_upper_delta": 10.0,
"trans2d_lower_delta": [15.0, 15.0],
"trans2d_upper_delta": [15.0, 15.0],
"warp_lower_dxy": 2.5,
"warp_lower_dz": 2.5,
"warp_upper_dxy": 2.5,
"warp_upper_dz": 2.5,
"rot3d_lower_delta": [20.0, 20.0, 20.0],
"rot3d_upper_delta": [20.0, 20.0, 20.0],
"rot3d_lower_delta": [15.0, 15.0, 15.0],
"rot3d_upper_delta": [15.0, 15.0, 15.0],
"affine3d_lower_delta": [0.1, 0.1, 0.1, 2.5, 0.1, 0.1, 0.1, 2.5, 0.1, 0.1, 0.1, 2.5],
"affine3d_upper_delta": [0.1, 0.1, 0.1, 2.5, 0.1, 0.1, 0.1, 2.5, 0.1, 0.1, 0.1, 2.5],
"trans3d_lower_delta": [0.3, 0.3, 0.3],
@@ -116,7 +121,7 @@
"image": true,
"snapshot": true
},
"visualise": true,
"visualise": false,
"verbose": 4,
"scaling": [4, 2, 1, 1],
"smoothing": [0, 0, 1, 0],
@@ -129,33 +134,46 @@
"export": {
"image": true,
"snapshot": true,
"halton": true
"cpoints": true,
"mask": true
},
"mask": {
"lthr": 0.1,
"uthr": 1
},
"n_points": 16,
"visualise": true,
"adaptive": {
"max_points": 32
},
"visualise": false,
"verbose": 4,
"smoothing": [2, 1, 0],
"lower_dz": 3,
"upper_dz": 3,
"lower_dz": 1.5,
"upper_dz": 1.5,
"reg_weight": 0,
"opt_step": 0.2,
"xtol_abs": 0.1,
"randomseed": 1
"opt_step": 0.4,
"xtol_abs": 0.1
},
"stage_4": {
"export": {
"image": true,
"snapshot": true,
"halton": true
"cpoints": true,
"mask": true
},
"mask": {
"lthr": 0.1,
"uthr": 1
},
"adaptive": {
"max_points": 32
},
"n_points": 16,
"visualise": false,
"verbose": 4,
"smoothing": [2],
"smoothing": [2, 1, 0],
"lower_dxy": 3,
"lower_dz": 3,
"lower_dz": 1,
"upper_dxy": 3,
"upper_dz": 3,
"upper_dz": 1,
"reg_weight": 0,
"opt_step": 0.2,
"xtol_abs": 0.1
@@ -164,13 +182,13 @@
"export": {
"image": true,
"snapshot": true,
"mask": true
"mask": true
},
"visualise": false,
"mask": {
"lthr": 0.1,
"uthr": 1
}
"mask": {
"lthr": 0.1,
"uthr": 1
}
}
}
}
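For reference, a minimal, hypothetical sketch of how such a configuration can be consumed (the filename below is an assumption; the script itself wraps the parsed parameters in an AttrMap so that nested keys are reachable as attributes):

import json
from attrdict import AttrMap

with open("slice_to_volume_config.json") as f:  # hypothetical filename
    p = AttrMap(json.load(f))

print(p.general.stages)              # [1, 2, 3, 4, 5, 2, 3, 4]
print(p.regparams.stage_3.n_points)  # 16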
@@ -4,7 +4,7 @@
# This script is based on the Tensor Image Registration Library, which is part
# of the FMRIB Software Library (FSL).
# Author: Istvan N. Huszar, M.D. <istvan.huszar@dtc.ox.ac.uk>
# Date: 17 Oct 2019
# Date: 24 Feb 2020
# DEPENDENCIES
@@ -12,9 +12,7 @@
import os
import sys
import json
import ghalton
import logging
import warnings
import argparse
import numpy as np
from time import time
@@ -22,7 +20,6 @@ from operator import mul
from functools import reduce
from attrdict import AttrMap
from collections import namedtuple
from skimage.filters import threshold_otsu
# TIRL IMPORTS
@@ -35,9 +32,6 @@ from tirl.domain import Domain
from tirl.tfield import TField
from tirl.timage import TImage
# Image operator
from tirl.operations.tensor import TensorOperator
# Linear transformations
from tirl.transformations.basic.identity import TxIdentity
from tirl.transformations.basic.embedding import TxEmbed
@@ -58,103 +52,22 @@ from tirl.regularisers.membrane_energy import RegMembraneEnergy
# Optimiser
from tirl.optimisers.optnl import OptNL
from tirl.optimisers.optnl import OptMethods as om
from tirl.optimisers.optimiser import OptimisationGroup
# DEFINITIONS
# DEFINITIONS
# Transformation chain
Chain = namedtuple(
"Chain", ["tx_scale2d", "tx_rot2d", "tx_trans2d", "tx_embed",
"tx_warp", "tx_rot3d", "tx_affine3d", "tx_trans3d"])
# Default number of control points used for the through-plane curvature
# estimation step.
N_CONTROL_POINTS = 36
# DEFAULT OPTIMISATION PARAMETERS
# Note: this is a duplicate of the concomitant configuration file. Changing
# default parameters in the script will not affect default settings, unless an
# option has been removed from the configuration file. Beware that if this
# happens inadvertently, default parameter values will be sourced from here
# without further notification.
# Registration stages to run
STAGES = [1, 2, 3, 4, 5]
# Initial transformation parameters
DEFAULT_INITIALISATION = {
"scale2d": [1.0, 1.0], # scale: 100% vertical, 100% horizontal
"rot2d": -90, # -90 degree in-plane rotation
"trans2d": [0, 0], # no in-plane translation
"rot3d": [0, 0, 90], # 90 degree rotation of the 2D plane in 3D space
"affine3d": [1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0], # identity; no shear/scale/rotation/translation
"trans3d": [0, 0, 0] # no 3D translation
}
# Initial lower and upper bounds for each transformation parameter. Depending
# on the optimiser these may be either:
# - hard bounds (values that fall outside of the specified range are
# inaccessible), or
# - soft bounds (values outside the range are possible, but harder to reach
# due to renormalisation to the expected range).
DEFAULT_BOUND_DELTAS = {
"scale2d_lower_delta": [0.4, 0.4],
"scale2d_upper_delta": [0.4, 0.4],
"rot2d_lower_delta": 20,
"rot2d_upper_delta": 20,
"trans2d_lower_delta": [10, 10],
"trans2d_upper_delta": [10, 10],
"warp_lower_dxy": 2.5,
"warp_lower_dz": 2.5,
"warp_upper_dxy": 2.5,
"warp_upper_dz": 2.5,
"rot3d_lower_delta": [10., 10., 10.],
"rot3d_upper_delta": [10., 10., 10.],
"affine3d_lower_delta": [0.1, 0.1, 0.1, 2.5] * 3,
"affine3d_upper_delta": [0.1, 0.1, 0.1, 2.5] * 3,
"trans3d_lower_delta": [2.5, 2.5, 2.5],
"trans3d_upper_delta": [2.5, 2.5, 2.5]
}
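# Worked example (hedged; the exact bound convention lives in the optimiser
# and is not shown in this diff): assuming the bounds are formed as
# x0 - lower_delta and x0 + upper_delta, the default in-plane rotation search
# interval would be
#     rot2d: x0 = -90, lower_delta = upper_delta = 20
#     bounds = (-90 - 20, -90 + 20) = (-110, -70) degrees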
# Absolute updates to the lower and upper bounds of the transformation
# parameters. When multiple optimisations are performed sequentially, these
# allow parameter bounds to evolve with the result through a series of
# optimisations.
UPDATE_BOUND_DELTAS = {
"scale2d_lower_delta" : [0.1, 0.1],
"scale2d_upper_delta" : [0.1, 0.1],
"rot2d_lower_delta" : 5,
"rot2d_upper_delta" : 5,
"trans2d_lower_delta" : [2.5, 2.5],
"trans2d_upper_delta" : [2.5, 2.5],
"warp_lower_dxy": 2.,
"warp_lower_dz": 1.,
"warp_upper_dxy": 2.,
"warp_upper_dz": 1.,
"rot3d_lower_delta" : [5., 5., 5.],
"rot3d_upper_delta" : [5., 5., 5.],
"affine3d_lower_delta": [0.1, 0.1, 0.1, 1.] * 3,
"affine3d_upper_delta": [0.1, 0.1, 0.1, 1.] * 3,
"trans3d_lower_delta" : [1., 1., 1.],
"trans3d_upper_delta" : [1., 1., 1.]
}
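# Worked example (same assumption as above): if a first optimisation pass ends
# at rot2d = -95 degrees, re-deriving the bounds with the updated delta of 5
# would narrow the next search interval to (-95 - 5, -95 + 5) = (-100, -90),
# so the bounds track the evolving estimate between successive optimisations.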
# NumPy print formatting
np.set_printoptions(precision=4)
# IMPLEMENTATION
# This is the implementation of the slice-to-volume image registration pipeline.
# The pipeline consists of the following functions:
#   run:      top-level entry point; parses the configuration, loads the
#             input images, and performs the pre-registration actions
#   register: runs the configured registration stages on the 2D slice and
#             the 3D volume
def run(cnf=None, **options):
"""
@@ -212,8 +125,7 @@ def run(cnf=None, **options):
# Load input images and perform pre-registration actions if the first
# stage is stage 1.
stages = p.general.stages or STAGES
if stages[0] == 1:
if p.general.stages[0] == 1:
# Load slice (2D image)
img = load_slice(p.slice.file, resolution=p.slice.resolution)
@@ -408,13 +320,10 @@ def register(img, vol, parameters):
"""
rp = AttrMap(parameters)
# Decide which stages to run
stages = p.general.stages or STAGES
# Check prerequisites first (better to fail here than later)
assert isinstance(img, TImage)
assert isinstance(vol, TImage)
for stage_no in stages:
for stage_no in p.general.stages:
dict(rp).get(f"stage_{stage_no}")
# Initialise output directory
@@ -435,30 +344,20 @@ def register(img, vol, parameters):
json.dump(dict(p), f, indent=4)
# Run registration stages
for i, stage_no in enumerate(stages):
# The non-linear transformation should be added before running
# stages 3 and 4
if stage_no == 3:
img = update_registration_frame(
img, warpaxes=(2,), stparams=rp.stage_3)
elif stage_no == 4:
img = update_registration_frame(
img, warpaxes=(0, 1, 2), stparams=rp.stage_4)
# Run current stage
run_stage(img, vol, stage_no, i)
for stage_no in p.general.stages:
p.general.stage_index = int(p.general.stage_index) + 1
run_stage(img, vol, stage_no)
return img
def run_stage(img, vol, stage_no, i):
def run_stage(img, vol, stage_no):
"""
Runs a registration stage specified by its number.
"""
logger.info(f"STAGE {stage_no}...")
stage_index = int(p.general.stage_index) + i
stage_index = p.general.stage_index
st = time()
func = f"stage{stage_no}"
stparams = p.regparams.get(f"stage_{stage_no}")
@@ -488,18 +387,18 @@ def run_stage(img, vol, stage_no, i):
vol.evaluate(img.domain).snapshot(os.path.join(
p.general.outdir, f"{stage_index}_stage{stage_no}{ext}"), True)
# Export snapshot with Halton points (where applicable)
haltonspec = stparams["export"].get("halton", None)
if haltonspec:
# Export snapshot with control points (where applicable)
cpointspec = stparams["export"].get("cpoints", None)
if cpointspec:
tx_warp = img.domain.get_transformation("warp")
cpoints = tx_warp.domain.get_voxel_coordinates()
if isinstance(haltonspec, str):
fp, fn = os.path.split(haltonspec)
if isinstance(cpointspec, str):
fp, fn = os.path.split(cpointspec)
filename = os.path.join(fp, f"{stage_index}_{fn}")
else:
filename = os.path.join(
p.general.outdir, f"{stage_index}_control_points.png")
_save_halton_image(img, cpoints, filename)
_save_control_point_image(img, cpoints, filename)
# Visualise end-stage alignment
if stparams["visualise"]:
@@ -543,9 +442,8 @@ def initialise_registration_frame(img):
img.domain.transformations.extend(postwarp)
def update_registration_frame(img, warpaxes, stparams):
def update_registration_frame(img, cost, warpaxes, stparams):
randomseed = p.general.randomseed or ts.HALTON_SEED
tx_warp, i = img.domain.get_transformation("warp", index=True)
# If the warp transformation is a genuine TxRbfDisplacementField
@@ -558,15 +456,15 @@ def update_registration_frame(img, warpaxes, stparams):
return img
# Preserve the parameters and metaparameters of the previous warp
new_field = TField(tx_warp.domain, tensor_shape=len(warpaxes))
new_field = TField(tx_warp.domain, tensor_shape=len(warpaxes),
order=ts.VOXEL_MAJOR)
for j, ax in enumerate(old_axes):
try:
k = warpaxes.index(ax)
except ValueError:
break
else:
# TODO: WHY k on the RHS, not j???
new_field.tensors[k][...] = old_field.tensors[k].data
new_field.tensors[k][...] = old_field.tensors[j].data
# If the warp transformation is a TxIdentity placeholder
else:
@@ -576,51 +474,103 @@
# Create TField from scratch
else:
n_points = stparams.get("n_points", N_CONTROL_POINTS)
try:
# Convert points to y, x from x, y!
points = np.loadtxt(stparams.get("points", None))[:, ::-1]
except Exception:
points = calculate_control_points(n_points, img, randomseed)
points = calculate_adaptive_control_points(img, cost, stparams)
points = np.round(points).astype(np.int)
rbf_domain = Domain(points, *img.domain.transformations[:i],
offset=img.domain.offset, name="rbf_domain")
new_field = TField(rbf_domain, tensor_shape=len(warpaxes))
new_field = TField(rbf_domain, tensor_shape=len(warpaxes),
order=ts.VOXEL_MAJOR)
new_field.interpolator.model = "gaussian"
# Update the TxRbfDisplacementField transformation in the chain
new_tx = TxRbfDisplacementField(
new_field, axes=warpaxes, name="warp", model="gaussian")
new_tx = TxRbfDisplacementField(new_field, axes=warpaxes, name="warp")
img.domain.transformations[i] = new_tx
return img
def calculate_control_points(n_points, img, randomseed):
def calculate_adaptive_control_points(img, cost, stparams):
"""
Returns a table of optimised voxel coordinates for the control points of
the warp transformation.
:param n_points:
Number of control points.
:type n_points: int
:param img:
Image to be registered.
Uses superpixel segmentation to generate an approximate number of tiles,
distributed evenly and compactly across the ROI of the target image, as
defined by a mask.
:param img: target 2D image
:type img: TImage
:param cost: cost function object used to derive the voxelwise cost map
:type cost: Any
:param stparams: stage-specific parameters
:type stparams: Any
:returns:
(n_points, 2) table of optimised voxel coordinates for control point
placement.
:returns: (n_points, 2) voxel coordinate array of control points
:rtype: np.ndarray
"""
tensormean = TensorOperator(np.mean, axis=tuple(range(1, 1 + img.tdim)))
flatimg = tensormean(img.asTField).data
# TODO: This is not compatible with HDD arrays
th = threshold_otsu(flatimg)
flatimg = np.squeeze(np.where(flatimg < th, 0, 1))
bbox = _get_bbox(flatimg)
points = _halton(n_points, randomseed, **bbox)[:, ::-1] # revert x and y
return points
from skimage.exposure import rescale_intensity
from skimage.measure import regionprops
from skimage.segmentation import slic
# from skimage.future import graph
t = AttrMap(stparams)
# Create cost function map and a ROI mask for parcellation
mask = np.logical_and(img.data > t.mask.lthr, img.data <= t.mask.uthr)
maskspec = t.export.mask
if maskspec:
if isinstance(maskspec, str):
fp, fn = os.path.split(maskspec)
filename = os.path.join(fp, f"{maskspec}_{fn}")
else:
filename = os.path.join(
p.general.outdir,
f"{p.general.stage_index}_mask.{p.general.snapshot_ext}")
TImage.fromarray(mask.astype(np.int64)).snapshot(filename)
try:
cm = cost.costmap().reduce_tensors(copy=True)
except:
# If the cost function does not support voxelwise mapping
logger.debug("Costmap not available. Parcellation is based on target "
"mask.")
cm = TImage.fromarray(
np.asarray(img.data > t.mask.lthr, dtype=np.float64))