Commit 6baf55e4 authored by inhuszar's avatar inhuszar
Browse files

Avoided chain duplication on continued registration.

parent d564f1ae
......@@ -204,7 +204,15 @@ def initialise_transformations(fixed, p):
"""
q = p.regparams
try:
ix0 = fixed.domain.get_transformation("centralise", index=True)[1]
except:
ix0 = max(len(fixed.domain.chain) - 1, 0)
# Scale
try:
tx_scale = fixed.domain.get_transformation("scale")
except:
lb = np.asarray(q.init.scale.lb)
ub = np.asarray(q.init.scale.ub)
if p.general.isotropic:
......@@ -213,8 +221,12 @@ def initialise_transformations(fixed, p):
float(q.init.scale.x0), dim=2, bounds=bounds, name="scale")
else:
tx_scale = TxScale(*q.init.scale.x0, bounds=(lb, ub), name="scale")
# fixed.domain.chain.append(tx_scale)
# Rotation
try:
tx_rotation = fixed.domain.get_transformation("rotation")
except:
if str(q.init.rotation.mode).lower() == "deg":
lb = radians(float(q.init.rotation.lb))
ub = radians(float(q.init.rotation.ub))
......@@ -224,25 +236,38 @@ def initialise_transformations(fixed, p):
tx_rotation = TxRotation2D(
float(q.init.rotation.x0), mode=q.init.rotation.mode,
bounds=(lb, ub), name="rotation")
# fixed.domain.chain.append(tx_rotation)
# Translation
try:
tx_trans = fixed.domain.get_transformation("translation")
except:
lb = np.asarray(q.init.translation.lb)
ub = np.asarray(q.init.translation.ub)
tx_trans = TxTranslation(
*q.init.translation.x0, bounds=(lb, ub), name="translation")
# fixed.domain.chain.append(tx_trans)
# Affine
try:
tx_affine, ix = fixed.domain.get_transformation("affine", index=True)
ix = len(fixed.domain.chain) - 1
except:
x0 = np.asarray(q.init.affine.x0).reshape((2, 3))
lb = np.asarray(q.init.affine.lb)
ub = np.asarray(q.init.affine.ub)
tx_affine = TxAffine(x0, bounds=(lb, ub), name="affine")
# fixed.domain.chain.append(tx_affine)
# Append linear transformations to the domain of the fixed image
linear_chain = Chain(tx_rotation, tx_scale, tx_trans, tx_affine)
domain = fixed.domain[:]
domain = fixed.domain[:ix0 + 1]
domain.chain.extend(linear_chain)
# Nonlinear
try:
tx_nonlinear = fixed.domain.get_transformation("nonlinear")
except:
x0 = float(q.init.nonlinear.x0) * np.ones((2, *fixed.vshape))
if q.init.nonlinear.lb is None:
lb = None
......@@ -253,12 +278,35 @@ def initialise_transformations(fixed, p):
else:
ub = float(q.init.nonlinear.ub) * np.ones_like(x0)
field = TField.fromarray(
x0, tensor_axes=(0,), copy=False, domain=domain[:], order=TENSOR_MAJOR)
x0, tensor_axes=(0,), copy=False, domain=domain[:],
order=TENSOR_MAJOR)
tx_nonlinear = TxDisplacementField(
field, bounds=(lb, ub), name="nonlinear", mode=NL_REL)
# fixed.domain.chain.append(tx_nonlinear)
# Nonlinear 2
try:
tx_nonlinear2 = fixed.domain.get_transformation("nonlinear2")
except:
x0 = float(q.init.nonlinear.x0) * np.ones((2, *fixed.vshape))
if q.init.nonlinear.lb is None:
lb = None
else:
lb = float(q.init.nonlinear.lb) * np.ones_like(x0)
if q.init.nonlinear.ub is None:
ub = None
else:
ub = float(q.init.nonlinear.ub) * np.ones_like(x0)
field2 = TField.fromarray(
x0, tensor_axes=(0,), copy=False, domain=domain[:],
order=TENSOR_MAJOR)
tx_nonlinear2 = TxDisplacementField(
field2, bounds=(lb, ub), name="nonlinear2", mode=NL_REL)
# fixed.domain.chain.append(tx_nonlinear2)
# Return the full transformation chain
return Chain(*linear_chain, tx_nonlinear)
return fixed.domain.chain[:ix0 + 1] + \
Chain(*linear_chain, tx_nonlinear, tx_nonlinear2)
def register(fixed, moving, cnf):
......@@ -286,8 +334,12 @@ def register(fixed, moving, cnf):
# rotation -> scale -> translation -> affine -> nonlinear
logger.info("Initialising transformation chain...")
chain = initialise_transformations(fixed, p)
fixed.domain = fixed.domain[:0]
logger.info("Transformation chain has been initialised.")
# Set the first part of the chain
fixed.domain.chain.extend(chain[:-3])
# Generate output: initial alignment
fixed.save(os.path.join(
p.general.outputdir, "fixed.timg"), overwrite=True)
......@@ -298,9 +350,6 @@ def register(fixed, moving, cnf):
moving.snapshot(os.path.join(
p.general.outputdir, f"moving.{SNAPSHOT_EXT}"), overwrite=True)
# Set the first part of the chain
fixed.domain.chain.extend(chain[:-2])
# Rotation search
if "rotation" in p.general.stages:
logger.info("Starting rotation search...")
......@@ -330,7 +379,7 @@ def register(fixed, moving, cnf):
logger.info("Rigid registration was skipped.")
# Affine registration
fixed.domain.chain.append(chain[-2])
fixed.domain.chain.append(chain[-3])
if "affine" in p.general.stages:
logger.info("Starting affine registration...")
affine2d(fixed, moving, p)
......@@ -345,7 +394,7 @@ def register(fixed, moving, cnf):
logger.info("Affine registration was skipped.")
# Non-linear registration
tx_nonlinear = chain[-1]
tx_nonlinear = chain[-2]
tx_nonlinear.domain.chain = fixed.domain.chain[:]
fixed.domain.chain.append(tx_nonlinear)
if "nonlinear" in p.general.stages:
......@@ -361,6 +410,23 @@ def register(fixed, moving, cnf):
else:
logger.info("Non-linear registration was skipped.")
# Non-linear registration 2
tx_nonlinear2 = chain[-1]
tx_nonlinear2.domain.chain = fixed.domain.chain[:-1]
fixed.domain.chain.append(tx_nonlinear2)
if "nonlinear2" in p.general.stages:
logger.info("Starting the 2nd non-linear registration...")
diffreg2d(fixed, moving, p)
logger.info("Completed the 2nd non-linear registration.")
# Generate output
fixed.save(os.path.join(
p.general.outputdir, "fixed5_nonlinear2.timg"), overwrite=True)
moving.evaluate(fixed.domain).snapshot(os.path.join(
p.general.outputdir, f"moving5_nonlinear2.{SNAPSHOT_EXT}"),
overwrite=True)
else:
logger.info("The 2nd non-linear registration was skipped.")
# try:
# os.remove(os.path.join(p.general.outdir, "histology.timg"))
# except:
......@@ -552,16 +618,18 @@ def diffreg2d(fixed, moving, cnf):
for i, (sc, sm) in enumerate(zip(q.scaling, q.smoothing)):
logger.debug(f"Scale: {sc}, smoothing: {sm} px...")
# Prepare images for the current iteration
fixed.resample(1 / sc, copy=False)
moving.resample(1 / sc, copy=False)
fixed.resample(1 / sc, copy=False, update_chain=False)
moving.resample(1 / sc, copy=False, update_chain=False)
fixed_smooth = fixed.smooth(sm, copy=True)
moving_smooth = moving.smooth(sm, copy=True)
# Prepare transformation to optimise
tx_nonlinear = fixed_smooth.domain.chain[-1]
tx_nonlinear = tx_nonlinear.regrid(fixed_smooth.domain[:-2])
fixed_smooth.domain.chain[-1] = tx_nonlinear
# Set cost and regulariser
cost = CostMIND(moving_smooth, fixed_smooth, sigma=float(q.sigma),
truncate=float(q.truncate), kernel=MK_FULL,
ignore_masked_edges=True)
ignore_masked_edges=True, ignore_masked_regions=True)
regularisation = DiffusionRegulariser(
tx_nonlinear, weight=float(q.regweight))
# Optimise the non-linear transformation
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment