Commit 64d74cdd authored by Sean Fitzgibbon's avatar Sean Fitzgibbon
Browse files

Added support for masking using GMM

parent c6508b74
......@@ -28,11 +28,12 @@ moving:
mask:
automask:
thr: 0.0
uthr: 1.0
uthr: 1
file: null
function: null
# function: null
function: gmm_mask
normalise: true
preview: false
preview: true
resolution_level: 4
snapshot: true
storage: mem
......@@ -40,16 +41,17 @@ moving:
# fixed image settings
fixed:
dtype: f4
export: false
export: true
file: ''
mask:
automask:
thr: 0.0
uthr: 1.0
thr: 0
uthr: 1
file: null
function: null
function: gmm_mask
# function: null
normalise: true
preview: false
preview: true
resolution_level: 4
snapshot: true
storage: mem
......@@ -58,6 +60,8 @@ fixed:
preprocessing:
fixed: ["fixed_preprocessing"]
moving: ["moving_preprocessing", "match_fixed_resolution"]
# fixed: null
# moving: ["match_fixed_resolution"]
# registration settings
regparams:
......@@ -76,6 +80,9 @@ regparams:
x0: [0.0, 0.0]
lb: [-5.0, -5.0]
ub: [5.0, 5.0]
# restrict lower and upper bound
# lb: [-1.0, -1.0]
# ub: [1.0, 1.0]
affine:
x0: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]
lb: [0.95, -0.2, -1.0, -0.2, 0.95, -1.0]
......@@ -104,21 +111,28 @@ regparams:
affine:
opt_step: 0.1
scaling: [160, 80, 40, 20, 10, 5]
smoothing: [0, 0, 0]
smoothing: [0, 0, 0, 0, 0, 0, 0]
visualise: false
xtol_abs: [0.001, 0.001, 0.001, 0.001, 0.001, 0.001]
xtol_rel: 0.01
nonlinear:
maxiter: [20, 20, 20, 20, 10, 5]
regweight: 0.4
maxiter: [20, 20, 20, 20, 10, 10]
# more regularised
# regweight: 0.4
regweight: 0.6
scaling: [160, 80, 40, 20, 10, 5]
# scaling: [80, 60, 40, 20, 10, 5]
# can try increasing (1.5, 2, ) number of pixels
# radius of smoothing kernel (number of pixels)
sigma: 1
smoothing: [0, 0, 0, 0, 0, 0]
truncate: 1.5
visualise: false
# lower will lead to longer registration, might not converge if too small
xtol_abs: 0.1
xtol_rel: 0.01
xtol_rel: 0.01 #mm median magnitude
kernel: MK_FULL
......
......@@ -100,6 +100,7 @@ __tirlscript__ = True
import argparse
import json, yaml
from numpy.lib.shape_base import _kron_dispatcher
import logging
import os
import sys
......@@ -129,6 +130,11 @@ from tirl.transformations.linear.scale import TxScale, TxIsoScale
from tirl.transformations.linear.translation import TxTranslation
from tirl.transformations.nonlinear.displacement import TxDisplacementField
import matplotlib
matplotlib.use("TkAgg")
print(matplotlib.get_backend())
import matplotlib.pyplot as plt
import os.path as op
......@@ -163,19 +169,7 @@ def _load_image(p):
if p.file.lower().endswith(".jp2"): # if jpg2k
# # load jp2
# import glymur
# from tirl.scripts.mnd.image import set_mask
# jp2= glymur.Jp2k(p.file)
# img = jp2.read(rlevel=p.resolution_level)
# # adjust resolution by resolution_level
# p.resolution = p.resolution * (2**p.resolution_level)
# # create timg
# timg = TImage.fromarray(img, dtype=p.dtype, tensor_axes=(2,))
# timg.resolution = p.resolution
# load jp2
timg = _load_jpg2k(p.file, p.resolution, p.resolution_level, p.dtype)
# add mask
......@@ -228,8 +222,6 @@ def run(cnf=None, **options):
ext = ts.EXTENSIONS["TImage"]
p.moving.export = os.path.join(p.general.outputdir, f"moving.{ext}")
# TODO: adjust resolution based on rlevel.
moving = _load_image(p.moving)
if p.fixed.export is True:
......@@ -250,8 +242,19 @@ def run(cnf=None, **options):
moving = tirl.scripts.mnd.image.perform_image_operations(
moving, *p.preprocessing.moving, scope=globals(), other=fixed, cnf=p
)
# Initialise registration frame
moving.centralise(weighted=True)
moving.centralise(weighted=False)
moving.snapshot(
os.path.join(p.general.outputdir, f"moving0_centralised.{SNAPSHOT_EXT}"),
overwrite=True,
)
if moving.mask is not None:
TImage.fromarray(moving.mask).snapshot(
os.path.join(p.general.outputdir, f"moving0_mask.{SNAPSHOT_EXT}"),
overwrite=True,
)
# Perform actions on the fixed image prior to registration, unless it was
# loaded from a TImage file.
......@@ -259,10 +262,22 @@ def run(cnf=None, **options):
(ts.EXTENSIONS["TImage"], ts.EXTENSIONS["TIRLObject"])
)
if not isalternative:
fixed = tirl.scripts.mnd.image.perform_image_operations(
fixed, *p.preprocessing.fixed, scope=globals(), other=moving, cnf=p
)
fixed.centralise(weighted=True)
fixed.centralise(weighted=False)
fixed.snapshot(
os.path.join(p.general.outputdir, f"fixed0_centralised.{SNAPSHOT_EXT}"),
overwrite=True,
)
if fixed.mask is not None:
TImage.fromarray(fixed.mask).snapshot(
os.path.join(p.general.outputdir, f"fixed0_mask.{SNAPSHOT_EXT}"),
overwrite=True,
)
# Run the registration routine
try:
......@@ -275,7 +290,7 @@ def run(cnf=None, **options):
logger.fatal("The registration was completed successfully.")
def labkmeans(histo, **kwargs):
def labkmeans(histo, n_clusters=2, **kwargs):
"""
Segments the foreground (tissue) in a histological slide and sets it as the
TImage mask.
......@@ -296,13 +311,16 @@ def labkmeans(histo, **kwargs):
)
lab_a = rgb2lab(imdata)[..., 1]
X = lab_a.reshape(-1, 1)
km = KMeans(n_clusters=2, random_state=0).fit(X)
km = KMeans(n_clusters=n_clusters, random_state=0).fit(X)
kc = km.cluster_centers_
mask = km.labels_.reshape(histo.vshape)
if kc[0] > kc[1]: # make sure that the lower intensity is labeled 0
mask = 1 - mask
# mask = km.labels_.reshape(histo.vshape)
# if kc[0] > kc[1]: # make sure that the lower intensity is labeled 0
# mask = 1 - mask
mask = km.labels_.reshape(histo.vshape) != np.argmin(kc)
histo.order = orig_order
return mask
return mask.astype(int)
def initialise_transformations(fixed, p):
......@@ -491,6 +509,11 @@ def register(fixed, moving, cnf):
else:
logger.info("Affine registration was skipped.")
# mask_and(moving, fixed)
# moving.mask = None
# fixed.mask = None
# Non-linear registration
tx_nonlinear = chain[-1]
tx_nonlinear.domain.chain = fixed.domain.chain[:]
......@@ -545,7 +568,6 @@ def rotation_search2d(fixed, moving, cnf):
tx_rotation.parameters.set_lower_bounds(angle - step / 2)
tx_rotation.parameters.set_upper_bounds(angle + step / 2)
cost = CostMIND(moving, fixed, normalise=True, kernel=MK_FULL)()
# cost = CostMI(moving, fixed, normalise=True, bins=32)()
logger.debug(f"{degrees(angle)} deg: {cost}")
costvals.append([cost, angle])
else:
......@@ -579,7 +601,6 @@ def rotation_search2d(fixed, moving, cnf):
moving.resample(float(q.scale), copy=False)
# Set cost function
cost = CostMIND(moving, fixed, maskmode="and", normalise=True)
# cost = CostMI(moving, fixed, maskmode="and", normalise=True)
# Start optimisation
logger.info(
f"Co-optimising scale and translation at " f"{degrees(angle)} deg..."
......@@ -637,8 +658,17 @@ def rigid2d(fixed, moving, cnf):
lb = lb - np.finfo(lb.dtype).eps
ub = ub + np.finfo(lb.dtype).eps
og.set_bounds(lb, ub)
# binarise mask
if moving_smooth.mask is not None:
moving_smooth.mask = (moving_smooth.mask > 0.5).astype(int)
if fixed_smooth.mask is not None:
fixed_smooth.mask = (fixed_smooth.mask > 0.5).astype(int)
# Set cost function
cost = CostMIND(moving_smooth, fixed_smooth, normalise=True)
cost = CostMIND(moving_smooth, fixed_smooth, normalise=True, maskmode="and")
# Start optimisation
OptNL(
og,
......@@ -681,7 +711,16 @@ def affine2d(fixed, moving, cnf):
moving_smooth = moving.smooth(sm, copy=True)
# Prepare transformation to optimise
tx_affine = fixed_smooth.domain.chain["affine"]
cost = CostMIND(moving_smooth, fixed_smooth, normalise=True)
# binarise mask
if moving_smooth.mask is not None:
moving_smooth.mask = (moving_smooth.mask > 0.5).astype(int)
if fixed_smooth.mask is not None:
fixed_smooth.mask = (fixed_smooth.mask > 0.5).astype(int)
# Set cost function
cost = CostMIND(moving_smooth, fixed_smooth, normalise=True, maskmode="and")
# Start optimisation
OptNL(
tx_affine,
......@@ -722,16 +761,42 @@ def diffreg2d(fixed, moving, cnf):
moving.resample(1 / sc, copy=False)
fixed_smooth = fixed.smooth(sm, copy=True)
moving_smooth = moving.smooth(sm, copy=True)
# binarise mask (if req'd)
if moving_smooth.mask is not None:
moving_smooth.mask = (moving_smooth.mask > 0.5).astype(int)
if fixed_smooth.mask is not None:
fixed_smooth.mask = (fixed_smooth.mask > 0.5).astype(int)
# Prepare transformation to optimise
tx_nonlinear = fixed_smooth.domain.chain[-1]
# Set cost and regulariser
if q.kernel == "MK_FULL":
kernel = MK_FULL
elif q.kernel == "MK_STAR":
kernel = MK_STAR
else:
raise RuntimeError(f"Unknown Cost Kernel: {q.kernel}")
cost = CostMIND(
moving_smooth,
fixed_smooth,
sigma=float(q.sigma),
truncate=float(q.truncate),
kernel=MK_FULL,
# kernel=MK_FULL,
# kernel=MK_STAR,
kernel=kernel,
maskmode="or",
)
# np.save(
# os.path.join(p.general.outputdir, f"cost_sc{sc}.npy"),
# cost.costmap().data,
# )
regularisation = DiffusionRegulariser(tx_nonlinear, weight=float(q.regweight))
# Optimise the non-linear transformation
GNOptimiserDiffusion(
......@@ -747,6 +812,15 @@ def diffreg2d(fixed, moving, cnf):
# Transfer optimised transformations to the non-smoothed images
fixed.domain = fixed_smooth.domain
moving.domain = moving_smooth.domain
# fixed_smooth.snapshot(
# os.path.join(p.general.outputdir, f"fixed_sc{sc}.{SNAPSHOT_EXT}"), overwrite=True,
# )
# moving_smooth.evaluate(fixed_smooth.domain).snapshot(
# os.path.join(p.general.outputdir, f"moving_sc{sc}.{SNAPSHOT_EXT}"), overwrite=True,
# )
else:
# Restore the original resolution of the images
fixed.resample(1, copy=False)
......@@ -761,6 +835,72 @@ def diffreg2d(fixed, moving, cnf):
# included here to prevent too frequent code repetitions.
def gmm_mask(img, gmm_comps=3, dilation_radius=2):
from skimage.measure import regionprops, label
from skimage.color import rgb2gray
from skimage import filters, transform, exposure, morphology
from scipy import ndimage as ndi
from sklearn import mixture
img0 = img.data.astype(np.uint8)
img0 = rgb2gray(img0)
full_sz = img0.shape
# downsample
# img = transform.rescale(img, (1.0/2, 1.0/2), order=0, anti_aliasing=True)
# rescale intensity
p_min, p_max = np.percentile(img0, (1, 99))
img0 = exposure.rescale_intensity(img0, in_range=(p_min, p_max))
# use guassian mixture model to segment tissues
# clf = mixture.BayesianGaussianMixture(
# n_components=gmm_comps,
# covariance_type='full',
# )
clf = mixture.GaussianMixture(
n_components=gmm_comps,
)
clf.fit(img0.reshape((-1, 1)))
pred = clf.predict(img0.reshape((-1, 1)))
pred = pred.reshape(img0.shape)
# note: this assumes a bright background (argmin should be used for dark BG)
target = [np.mean(img0[pred == idx]) for idx in range(gmm_comps)]
mask = pred != np.argmax(target)
# calculate connected components
cc = label(mask, background=0)
# get region properties for each cc and select the largest cc (in terms of area)
p = regionprops(cc, img0)
ridx = np.argmax([p0.area for p0 in p])
p = p[ridx]
# create mask
mask = cc == p.label
# fill holes
mask = ndi.binary_fill_holes(mask)
# dilate
mask = morphology.binary_opening(mask, morphology.disk(10, dtype=bool))
# mask = morphology.binary_dilation(mask, morphology.disk(5, dtype=bool))
# upsample
# mask = transform.resize(mask, full_sz, order=0)
return mask.astype(int)
def rgb2hsv(x):
assert x.shape[-1] == 3, "The input must be RGB."
_x = x / 255
......@@ -907,12 +1047,31 @@ def dilated_object_mask(timg, **kwargs):
return objmask.astype(np.float32)
def mask_and(moving, fixed):
from skimage import transform
moving_mask = moving.evaluate(fixed.domain).mask
fixed_mask = fixed.mask
moving_mask = transform.resize(
moving_mask.astype(bool), fixed_mask.shape, order=0, preserve_range=True
)
mask = np.logical_not((moving_mask > 0) != (fixed_mask > 0))
TImage(mask).preview()
fixed.mask = mask.astype(int)
moving.mask = np.ones(moving.vshape, dtype=int)
@beta_function
def mask_roi_defects(histo, block, p):
def mask_roi_defects(moving, fixed, p):
# Highlight non-matching areas between affine-registered images
tmp = histo.evaluate(block.domain)
tmp = moving.evaluate(fixed.domain)
tmp = np.where(tmp.data > 0.05, 1, 0)
binary = np.where(block.data > 1, 1, 0)
binary = np.where(fixed.data > 1, 1, 0)
totalarea = np.count_nonzero(binary)
# print(totalarea)
area = p.area * totalarea if p.area < 1 else p.area
......@@ -943,9 +1102,9 @@ def mask_roi_defects(histo, block, p):
# Add the identified ROI defects to the block mask
binary = 1 - binary
bmask = block.mask.data
bmask = fixed.mask.data
# bmask[bmask == 0.1] = 0
block.mask = bmask * binary
fixed.mask = bmask * binary
def pad(timg, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment