Commit 6baf55e4 authored by inhuszar's avatar inhuszar
Browse files

Avoided chain duplication on continued registration.

parent d564f1ae
...@@ -204,61 +204,109 @@ def initialise_transformations(fixed, p): ...@@ -204,61 +204,109 @@ def initialise_transformations(fixed, p):
""" """
q = p.regparams q = p.regparams
try:
ix0 = fixed.domain.get_transformation("centralise", index=True)[1]
except:
ix0 = max(len(fixed.domain.chain) - 1, 0)
# Scale # Scale
lb = np.asarray(q.init.scale.lb) try:
ub = np.asarray(q.init.scale.ub) tx_scale = fixed.domain.get_transformation("scale")
if p.general.isotropic: except:
bounds = (float(lb), float(ub)) lb = np.asarray(q.init.scale.lb)
tx_scale = TxIsoScale( ub = np.asarray(q.init.scale.ub)
float(q.init.scale.x0), dim=2, bounds=bounds, name="scale") if p.general.isotropic:
else: bounds = (float(lb), float(ub))
tx_scale = TxScale(*q.init.scale.x0, bounds=(lb, ub), name="scale") tx_scale = TxIsoScale(
float(q.init.scale.x0), dim=2, bounds=bounds, name="scale")
else:
tx_scale = TxScale(*q.init.scale.x0, bounds=(lb, ub), name="scale")
# fixed.domain.chain.append(tx_scale)
# Rotation # Rotation
if str(q.init.rotation.mode).lower() == "deg": try:
lb = radians(float(q.init.rotation.lb)) tx_rotation = fixed.domain.get_transformation("rotation")
ub = radians(float(q.init.rotation.ub)) except:
else: if str(q.init.rotation.mode).lower() == "deg":
lb = float(q.init.rotation.lb) lb = radians(float(q.init.rotation.lb))
ub = float(q.init.rotation.ub) ub = radians(float(q.init.rotation.ub))
tx_rotation = TxRotation2D( else:
float(q.init.rotation.x0), mode=q.init.rotation.mode, lb = float(q.init.rotation.lb)
bounds=(lb, ub), name="rotation") ub = float(q.init.rotation.ub)
tx_rotation = TxRotation2D(
float(q.init.rotation.x0), mode=q.init.rotation.mode,
bounds=(lb, ub), name="rotation")
# fixed.domain.chain.append(tx_rotation)
# Translation # Translation
lb = np.asarray(q.init.translation.lb) try:
ub = np.asarray(q.init.translation.ub) tx_trans = fixed.domain.get_transformation("translation")
tx_trans = TxTranslation( except:
*q.init.translation.x0, bounds=(lb, ub), name="translation") lb = np.asarray(q.init.translation.lb)
ub = np.asarray(q.init.translation.ub)
tx_trans = TxTranslation(
*q.init.translation.x0, bounds=(lb, ub), name="translation")
# fixed.domain.chain.append(tx_trans)
# Affine # Affine
x0 = np.asarray(q.init.affine.x0).reshape((2, 3)) try:
lb = np.asarray(q.init.affine.lb) tx_affine, ix = fixed.domain.get_transformation("affine", index=True)
ub = np.asarray(q.init.affine.ub) ix = len(fixed.domain.chain) - 1
tx_affine = TxAffine(x0, bounds=(lb, ub), name="affine") except:
x0 = np.asarray(q.init.affine.x0).reshape((2, 3))
lb = np.asarray(q.init.affine.lb)
ub = np.asarray(q.init.affine.ub)
tx_affine = TxAffine(x0, bounds=(lb, ub), name="affine")
# fixed.domain.chain.append(tx_affine)
# Append linear transformations to the domain of the fixed image # Append linear transformations to the domain of the fixed image
linear_chain = Chain(tx_rotation, tx_scale, tx_trans, tx_affine) linear_chain = Chain(tx_rotation, tx_scale, tx_trans, tx_affine)
domain = fixed.domain[:] domain = fixed.domain[:ix0 + 1]
domain.chain.extend(linear_chain) domain.chain.extend(linear_chain)
# Nonlinear # Nonlinear
x0 = float(q.init.nonlinear.x0) * np.ones((2, *fixed.vshape)) try:
if q.init.nonlinear.lb is None: tx_nonlinear = fixed.domain.get_transformation("nonlinear")
lb = None except:
else: x0 = float(q.init.nonlinear.x0) * np.ones((2, *fixed.vshape))
lb = float(q.init.nonlinear.lb) * np.ones_like(x0) if q.init.nonlinear.lb is None:
if q.init.nonlinear.ub is None: lb = None
ub = None else:
else: lb = float(q.init.nonlinear.lb) * np.ones_like(x0)
ub = float(q.init.nonlinear.ub) * np.ones_like(x0) if q.init.nonlinear.ub is None:
field = TField.fromarray( ub = None
x0, tensor_axes=(0,), copy=False, domain=domain[:], order=TENSOR_MAJOR) else:
tx_nonlinear = TxDisplacementField( ub = float(q.init.nonlinear.ub) * np.ones_like(x0)
field, bounds=(lb, ub), name="nonlinear", mode=NL_REL) field = TField.fromarray(
x0, tensor_axes=(0,), copy=False, domain=domain[:],
order=TENSOR_MAJOR)
tx_nonlinear = TxDisplacementField(
field, bounds=(lb, ub), name="nonlinear", mode=NL_REL)
# fixed.domain.chain.append(tx_nonlinear)
# Nonlinear 2
try:
tx_nonlinear2 = fixed.domain.get_transformation("nonlinear2")
except:
x0 = float(q.init.nonlinear.x0) * np.ones((2, *fixed.vshape))
if q.init.nonlinear.lb is None:
lb = None
else:
lb = float(q.init.nonlinear.lb) * np.ones_like(x0)
if q.init.nonlinear.ub is None:
ub = None
else:
ub = float(q.init.nonlinear.ub) * np.ones_like(x0)
field2 = TField.fromarray(
x0, tensor_axes=(0,), copy=False, domain=domain[:],
order=TENSOR_MAJOR)
tx_nonlinear2 = TxDisplacementField(
field2, bounds=(lb, ub), name="nonlinear2", mode=NL_REL)
# fixed.domain.chain.append(tx_nonlinear2)
# Return the full transformation chain # Return the full transformation chain
return Chain(*linear_chain, tx_nonlinear) return fixed.domain.chain[:ix0 + 1] + \
Chain(*linear_chain, tx_nonlinear, tx_nonlinear2)
def register(fixed, moving, cnf): def register(fixed, moving, cnf):
...@@ -286,8 +334,12 @@ def register(fixed, moving, cnf): ...@@ -286,8 +334,12 @@ def register(fixed, moving, cnf):
# rotation -> scale -> translation -> affine -> nonlinear # rotation -> scale -> translation -> affine -> nonlinear
logger.info("Initialising transformation chain...") logger.info("Initialising transformation chain...")
chain = initialise_transformations(fixed, p) chain = initialise_transformations(fixed, p)
fixed.domain = fixed.domain[:0]
logger.info("Transformation chain has been initialised.") logger.info("Transformation chain has been initialised.")
# Set the first part of the chain
fixed.domain.chain.extend(chain[:-3])
# Generate output: initial alignment # Generate output: initial alignment
fixed.save(os.path.join( fixed.save(os.path.join(
p.general.outputdir, "fixed.timg"), overwrite=True) p.general.outputdir, "fixed.timg"), overwrite=True)
...@@ -298,9 +350,6 @@ def register(fixed, moving, cnf): ...@@ -298,9 +350,6 @@ def register(fixed, moving, cnf):
moving.snapshot(os.path.join( moving.snapshot(os.path.join(
p.general.outputdir, f"moving.{SNAPSHOT_EXT}"), overwrite=True) p.general.outputdir, f"moving.{SNAPSHOT_EXT}"), overwrite=True)
# Set the first part of the chain
fixed.domain.chain.extend(chain[:-2])
# Rotation search # Rotation search
if "rotation" in p.general.stages: if "rotation" in p.general.stages:
logger.info("Starting rotation search...") logger.info("Starting rotation search...")
...@@ -330,7 +379,7 @@ def register(fixed, moving, cnf): ...@@ -330,7 +379,7 @@ def register(fixed, moving, cnf):
logger.info("Rigid registration was skipped.") logger.info("Rigid registration was skipped.")
# Affine registration # Affine registration
fixed.domain.chain.append(chain[-2]) fixed.domain.chain.append(chain[-3])
if "affine" in p.general.stages: if "affine" in p.general.stages:
logger.info("Starting affine registration...") logger.info("Starting affine registration...")
affine2d(fixed, moving, p) affine2d(fixed, moving, p)
...@@ -345,7 +394,7 @@ def register(fixed, moving, cnf): ...@@ -345,7 +394,7 @@ def register(fixed, moving, cnf):
logger.info("Affine registration was skipped.") logger.info("Affine registration was skipped.")
# Non-linear registration # Non-linear registration
tx_nonlinear = chain[-1] tx_nonlinear = chain[-2]
tx_nonlinear.domain.chain = fixed.domain.chain[:] tx_nonlinear.domain.chain = fixed.domain.chain[:]
fixed.domain.chain.append(tx_nonlinear) fixed.domain.chain.append(tx_nonlinear)
if "nonlinear" in p.general.stages: if "nonlinear" in p.general.stages:
...@@ -361,6 +410,23 @@ def register(fixed, moving, cnf): ...@@ -361,6 +410,23 @@ def register(fixed, moving, cnf):
else: else:
logger.info("Non-linear registration was skipped.") logger.info("Non-linear registration was skipped.")
# Non-linear registration 2
tx_nonlinear2 = chain[-1]
tx_nonlinear2.domain.chain = fixed.domain.chain[:-1]
fixed.domain.chain.append(tx_nonlinear2)
if "nonlinear2" in p.general.stages:
logger.info("Starting the 2nd non-linear registration...")
diffreg2d(fixed, moving, p)
logger.info("Completed the 2nd non-linear registration.")
# Generate output
fixed.save(os.path.join(
p.general.outputdir, "fixed5_nonlinear2.timg"), overwrite=True)
moving.evaluate(fixed.domain).snapshot(os.path.join(
p.general.outputdir, f"moving5_nonlinear2.{SNAPSHOT_EXT}"),
overwrite=True)
else:
logger.info("The 2nd non-linear registration was skipped.")
# try: # try:
# os.remove(os.path.join(p.general.outdir, "histology.timg")) # os.remove(os.path.join(p.general.outdir, "histology.timg"))
# except: # except:
...@@ -552,16 +618,18 @@ def diffreg2d(fixed, moving, cnf): ...@@ -552,16 +618,18 @@ def diffreg2d(fixed, moving, cnf):
for i, (sc, sm) in enumerate(zip(q.scaling, q.smoothing)): for i, (sc, sm) in enumerate(zip(q.scaling, q.smoothing)):
logger.debug(f"Scale: {sc}, smoothing: {sm} px...") logger.debug(f"Scale: {sc}, smoothing: {sm} px...")
# Prepare images for the current iteration # Prepare images for the current iteration
fixed.resample(1 / sc, copy=False) fixed.resample(1 / sc, copy=False, update_chain=False)
moving.resample(1 / sc, copy=False) moving.resample(1 / sc, copy=False, update_chain=False)
fixed_smooth = fixed.smooth(sm, copy=True) fixed_smooth = fixed.smooth(sm, copy=True)
moving_smooth = moving.smooth(sm, copy=True) moving_smooth = moving.smooth(sm, copy=True)
# Prepare transformation to optimise # Prepare transformation to optimise
tx_nonlinear = fixed_smooth.domain.chain[-1] tx_nonlinear = fixed_smooth.domain.chain[-1]
tx_nonlinear = tx_nonlinear.regrid(fixed_smooth.domain[:-2])
fixed_smooth.domain.chain[-1] = tx_nonlinear
# Set cost and regulariser # Set cost and regulariser
cost = CostMIND(moving_smooth, fixed_smooth, sigma=float(q.sigma), cost = CostMIND(moving_smooth, fixed_smooth, sigma=float(q.sigma),
truncate=float(q.truncate), kernel=MK_FULL, truncate=float(q.truncate), kernel=MK_FULL,
ignore_masked_edges=True) ignore_masked_edges=True, ignore_masked_regions=True)
regularisation = DiffusionRegulariser( regularisation = DiffusionRegulariser(
tx_nonlinear, weight=float(q.regweight)) tx_nonlinear, weight=float(q.regweight))
# Optimise the non-linear transformation # Optimise the non-linear transformation
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment