Commit 0a6990dc authored by ihuszar's avatar ihuszar
Browse files

Reconfigured the transformation package.

parent 186e0df1
......@@ -3,7 +3,7 @@ Tensor Image Registration Library
## Installation
1. Download the TIRL source codeL
1. Download the TIRL source code.
```
mkdir -p ~/Applications/
......
......@@ -14,9 +14,11 @@ cython_modules = cythonize(module_carraycalc,
compiler_directives={'embedsignature': True})
ext_modules = cython_modules
dependencies = ["attrdict", "cython", "dill", "ghalton", "joblib", "matplotlib",
"nibabel", "nlopt", "numpy", "openslide-python", "psutil",
"pillow", "pygments", "scipy", "scikit-image", "tifffile"]
# TODO: Verify the name of the imagecodecs library!
dependencies = ["attrdict", "cython", "dill", "ghalton", "imagecodecs",
"joblib", "matplotlib", "nibabel", "nlopt", "numpy",
"openslide-python", "psutil", "pillow", "pygments", "scipy",
"scikit-image", "tifffile"]
setup(name="tirl",
version=1.0,
......@@ -33,7 +35,11 @@ setup(name="tirl",
"tirl/regularisers",
"tirl/tirlvision",
"tirl/tirlvision/optimisers",
"tirl/transformations"],
"tirl/transformations",
"tirl/transformations/basic",
"tirl/transformations/auxiliary",
"tirl/transformations/linear",
"tirl/transformations/nonlinear"],
scripts=["protocols/tirl_affine2d",
"protocols/tirl_affine2d3d",
"protocols/tirl_affine3d",
......
# Base class transformations
import unittest
class ImportTransformations(unittest.TestCase):
def test_base_class(self):
from tirl.transformations import Transformation
from tirl.transformations import TransformationGroup
def test_base_derivatives(self):
from tirl.transformations.basic import TxIdentity
from tirl.transformations.basic import TxEmbed
from tirl.transformations.basic import TxReduce
def test_linear_tx_package(self):
from tirl.transformations.linear import TxLinear
from tirl.transformations.linear import TxTranslation
from tirl.transformations.linear import TxRotation
from tirl.transformations.linear import TxRotation2D
from tirl.transformations.linear import TxRotation3D
from tirl.transformations.linear import TxRotationEulerAngles
from tirl.transformations.linear import TxRotationAxisAngle
from tirl.transformations.linear import TxRotationQuaternion
from tirl.transformations.linear import TxScale
from tirl.transformations.linear import TxShear3D
from tirl.transformations.linear import TxOrthogonalProjection
from tirl.transformations.linear import TxPerspectiveProjection
from tirl.transformations.linear import TxRigid
from tirl.transformations.linear import TxAffine
from tirl.transformations.linear import TxParametricAffine
from tirl.transformations.linear import TxFLIRTAffine
def test_auxiliary_tx_package(self):
from tirl.transformations.auxiliary import TxAdapter
from tirl.transformations.auxiliary import TxRotationField
from tirl.transformations.auxiliary import TxScaleField
def test_nonlinear_package(self):
from tirl.transformations.nonlinear import TxNonLinear
from tirl.transformations.nonlinear import TxPolynomial
from tirl.transformations.nonlinear import TxDisplacementField
from tirl.transformations.nonlinear import TxRbfDisplacementField
if __name__ == "__main__":
unittest.main()
......@@ -38,6 +38,7 @@ def load(fname):
def expose_package_contents(baseclass, pkg, path, globals=None):
from importlib import import_module
from inspect import isclass
from glob import glob
module_name_pattern = "{}/*.py".format(path[0])
modules = [os.path.split(m)[-1] for m in glob(module_name_pattern)]
......@@ -49,12 +50,11 @@ def expose_package_contents(baseclass, pkg, path, globals=None):
for class_name in dir(imported_module):
if not class_name.startswith("__"):
c = getattr(imported_module, class_name)
if type(c) is type:
if issubclass(c, baseclass):
if globals:
globals.update({class_name: c})
else:
globals().update({class_name: c})
if isclass(c) and issubclass(c, baseclass):
if globals:
globals.update({class_name: c})
else:
globals().update({class_name: c})
# from tirl import transformations
......
This diff is collapsed.
......@@ -42,7 +42,7 @@ from tirl import exceptions as te
from tirl.tirlobject import TIRLObject
from tirl.parallelism import map_chunk_coordinates
from tirl.parallelism import copy_chunk_coordinates
from tirl.transformations.basic import Transformation, TransformationGroup
from tirl.transformations import Transformation, TransformationGroup
from tirl.transformations.basic import TxIdentity, TxEmbed, TxReduce
from tirl.transformations.linear import TxTranslation, TxScale
from tirl.transformations.auxiliary import TxRotationField
......
......@@ -8,7 +8,7 @@
import os
# Temporary working directory (use absolute path!)
TWD = "/mnt/nvme/temp"
TWD = "/Users/inhuszar/temp"
if not os.path.isdir(TWD):
os.makedirs(TWD)
......@@ -126,4 +126,4 @@ HALTON_SEED = 1
# Visualisation #
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
ENABLE_VISUALISATION = True
MPL_BACKEND = "tkagg"
MPL_BACKEND = "macosx"
......@@ -34,6 +34,7 @@ from scipy.ndimage.filters import gaussian_filter
from time import time
from tirl.utils import change_memmap_dtype
from tirl.utils import isreal
from tirl.domain import Domain
from tirl.buffer import Buffer
from tirl import settings as ts
......@@ -70,9 +71,363 @@ PREVIEW_MODES = ("rgb", "hsv", "composite", "quiver")
# operation-defining methods from the first-degree subclass and implement a
# from<subclass_1> constructor, respectively.
# TODO: What is the best way to integrate ResolutionManager with TField?
# IMPLEMENTATION
class ResolutionManager(TIRLObject):
"""
Plug-in container that manages the low- and high-resolution field data of a
TField instance.
"""
# TODO: Add _load/dump methods.
def __init__(self, source):
"""
Initialisation of ResolutionManager. The field data that is available
from the host TField at the time of initialisation will be considered
high-resolution data, from which other samplings of the data will be
derived.
:param source: template field
:type source: TField
"""
# Call superclass initialisation
super(ResolutionManager, self).__init__()
# Set high-resolution buffers
if not hasattr(source, "__tfield__"):
raise TypeError("Expected TField-like object as source, got {}."
.format(source.__class__.__name__))
self._data = source._data # buffer
self._dtype = source.dtype
self._domain = source.domain
self._interpolator = source.interpolator
self._order = source.order
self._storage = source.storage
self._kwargs = source.kwargs
self._srcname = source.name
self._taxes = source.taxes
self._tshape = source.tshape
self._vaxes = source.vaxes
self._vshape = source.vshape
def copy(self):
template = self.template.copy()
return ResolutionManager(template)
def get(self, scale, *scales, update_chain=True,
presmooth=ts.TIMAGE_PRESMOOTH):
"""
Creates resampled version of the template TField, which has the same
physical coordinates and shares transformations with the original
instance. Note that non-linear transformations may still be replaced if
update_chain=True (default).
:param scale:
Global scaling factor (relative to the high-resolution data).
If further values are specified, this is interpreted as the
scaling factor for the first axis.
:type scale: int
:param scales:
Scaling factors along higher dimensions.
:type scales: tuple[int]
:param update_chain:
If True (default), non-linear transformations that are linked with
the TField domain will remain linked, which implies that these
will be resampled via their regrid() method. If False, all
non-linear transformations will be detached and left unchanged.
In either case the identity of all transformation objects is
preserved (no copying will take place).
:type update_chain: bool
:param presmooth:
Applies Gaussian presmoothing before downsampling (no smoothing is
performed for upsampling). This ensures that all domain points at
the lower resolution have proper support in the higher-resolution
image, therefore all original information is (at least indirectly)
represented.
:type presmooth: bool
:returns: resampled TField.
:rtype: TField
"""
new_domain, data = \
self._get_scaled_data(scale, *scales, presmooth=presmooth)
# Update linked transformations and reattach the transformation chain
new_domain.transformations = self.domain.transformations
if update_chain:
new_domain = self._update_domain_transformations(
old_domain=self.domain, new_domain=new_domain)
# Update the name of the image
scales = (scale,) + tuple(scales) if scales else (scale,)
name = self.imgname + "_res{}".format("x".join(scales))
# Create a new resampled Field instance
return TImage(
data, tensor_axes=self.taxes, dtype=self.dtype, mask=mask,
name=name, domain=new_domain, order=self.order,
interpolator=self.interpolator.copy(), storage=self.storage,
**self.kwargs.copy())
def _get_scaled_data(self, scale, *scales, presmooth=ts.TIMAGE_PRESMOOTH):
"""
Generates a scaled version of the high-resolution field data buffer
by resampling it on a new domain. Calculates scale offset compensation
for the new domain to preserve the physical coordinates of the template
TField.
:param scale:
Global scaling factor (relative to the high-resolution data).
If further values are specified, this is interpreted as the
scaling factor for the first axis.
:type scale: int
:param scales:
Scaling factors along higher dimensions.
:type scales: tuple[int]
:param presmooth:
Applies Gaussian presmoothing before downsampling (no smoothing is
performed for upsampling). This ensures that all domain points at
the lower resolution have proper support in the higher-resolution
image, therefore all original information is (at least indirectly)
represented.
:type presmooth: bool
:returns: domain, resampled field data buffer
:rtype: tuple[Domain, Buffer, Union[Buffer, NoneType]]
"""
scales = (scale,) + tuple(scales) if scales else (scale,)
if not all(isreal(f) and f > 0 for f in scales):
raise ValueError("TField scaling factors must be numeric and "
"positive.")
new_shape = tuple(round(f * d) for f, d in zip(scales, self.vshape))
# Restore high-resolution versions of the data, the domain,
# the interpolator and set the scale to precisely 1, if all scale
# factors are close to 1.
if new_shape == self.vshape:
return self._domain, self._data
# Create target domain with appropriate scale-offset transformations to
# compensate for the sampling difference. This will ensure that the
# physical coordinates remain unchanged.
offset_tx = TxScale(*tuple(1. / s for s in scales))
offset = [offset_tx] + self.domain.offset
# Purge domain kwargs
ndkwargs = {k: v for (k, v) in self.domain.kwargs.items()
if k not in self.domain.RESERVED_KWARGS}
new_domain = Domain(
new_shape, offset=offset, storage=self.domain.storage,
dtype=self.domain.dtype,
instance_mem_limit=self.domain.instance_mem_limit,
n_threads=self.domain.threads, **ndkwargs)
# Resample the field data
imgdata = TField.fromarray(
arr=self._data.data, tensor_axes=self.taxes, copy=False,
domain=self.domain[:0], order=None, dtype=None,
interpolator=self._interpolator.__class__, name=None,
storage=self._storage, **self._kwargs)
data = self._resample(imgdata, domain=new_domain, presmooth=presmooth)
return new_domain, data
def _resample(self, source, domain, presmooth=True):
"""
Resamples buffer with optional presmoothing.
"""
# Convert input data to a transformationless TField instance
# Transformations are removed to confine the evaluation to the voxels.
img = TField(source, domain=source.domain[:0])
# Presmoothing creates a proper Gaussian support for each point of the
# new domain in the old domain, so in theory all original information is
# represented in the downsampled data.
if presmooth:
scales = np.divide(domain.shape, source.domain.shape)
coeff = 1. / (2. * np.sqrt(2 * np.log(2)))
t_sigmas = (0,) * len(self.taxes)
v_sigmas = tuple(1 / s * coeff if s < 1 else 0 for s in scales)
sigmas = t_sigmas + v_sigmas # operator uses TENSOR_MAJOR order
kernelsize = ts.TIMAGE_PRESMOOTH_KERNELSIZE_NSIGMA
neighbourhood = max(ceil(kernelsize * max(sigmas)), 1)
smop = SpatialOperator(
gaussian_filter, radius=neighbourhood, name="smooth",
sigma=sigmas, truncate=kernelsize)
img = smop(img)
# Resample image data on a new voxel domain (no transformations)
img = img.evaluate(domain[:0])
assert isinstance(img, TField)
# Return buffer
return img._data
@staticmethod
def _update_domain_transformations(old_domain, new_domain):
"""
Regrids dynamically linked non-linear transformations, so that they can
remain linked to the TField after changing the shape of the TField
domain during resampling.
"""
new_domain = new_domain[:0]
new_domain.transformations = old_domain.transformations
for i, tx in enumerate(old_domain.transformations):
if hasattr(tx, "domain"):
# If the transformation was dynamically linked to
# the current domain, re-grid that transformation.
if tx.domain == old_domain[:i]:
# Erase transformations on the tx domain
tx.metaparameters["domain"] = old_domain[:0]
# Resample among voxel domains
new_tx = tx.regrid(new_domain[:0])
# Add back all new antecedent transformations
new_tx.metaparameters["domain"] = new_domain[:i]
# Add the current transformation itself to the new domain
new_domain.transformations[i] = new_tx
return new_domain
@property
def template(self):
"""
Returns the high-resolution image template that is used by the current
Resolution Manager instance to generate resampled field data.
:returns: high-resolution TField template
:rtype: TField
"""
return TImage(
source=self.data, tensor_axes=self.taxes, dtype=self.dtype,
mask=self.mask, name=self.imgname, domain=self.domain,
order=self.order, interpolator=self.interpolator,
storage=self.storage, header=self.header, **self.kwargs)
@property
def dtype(self):
return self._dtype
@property
def order(self):
return self._order
@property
def storage(self):
return self._storage
@property
def header(self):
return self._header
@property
def kwargs(self):
return self._kwargs
@property
def imgname(self):
return self._imgname
@property
def vshape(self):
return self._vshape
@property
def vaxes(self):
return self._vaxes
@property
def tshape(self):
return self._tshape
@property
def taxes(self):
return self._taxes
@property
def domain(self):
return self._domain
@domain.setter
def domain(self, d):
"""
Domain setter method. For advanced users only. The input must be a
Domain instance that matches the voxel shape of the TField data at the
current resolution.
"""
if isinstance(d, Domain):
if self.domain.shape == d.shape:
self._domain = d
else:
raise te.DomainError(
"Cannot assign domain with shape {} to a {} with voxel "
"shape {}.".format(d.shape, self.__class__.__name__,
self.vshape))
else:
raise TypeError("Expected a Domain instance, got {}."
.format(d.__class__.__name__))
@property
def data(self):
"""
High-resolution TField data.
:returns: image data
:rtype: np.ndarray
"""
return self._data.data
@property
def interpolator(self):
return self._interpolator
@interpolator.setter
def interpolator(self, ip):
"""
Sets the interpolator that is used to resample the high-resolution
field data.
"""
# Create interpolator based on the Domain type (compact vs. non-compact)
if ip is None:
# Choose default interpolator according to the domain type
if self.domain.flags["iscompact"]:
ip = locate(ts.DEFAULT_COMPACT_INTERPOLATOR)
else:
ip = locate(ts.DEFAULT_NONCOMPACT_INTERPOLATOR)
# Create interpolator instance from class specification
if issubclass(type(ip), type) and issubclass(ip, Interpolator):
ip = ip(source=self.data, n_threads=-1)
ip.tensor_axes = self.taxes
# Adapt existing interpolator instance
elif isinstance(ip, Interpolator):
# If interpolator is given, set the source to the TField data
if ip.tensor_axes == self.taxes:
ip._values = self.data
else:
raise ValueError(
"The tensor axis specification of the interpolator {} "
"does not match {} with tensor axis/axes {}"
.format(ip.tensor_axes, self.__class__.__name__,
self.taxes))
else:
raise TypeError("Expected Interpolator type, got {}"
.format(ip.__class__.__name__))
# Update interpolator
self._interpolator = ip
class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
# Set the array priority higher than that of np.matrix with 10.0
......@@ -1625,7 +1980,7 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
the advantage that the output shape is defined explicitly,
avoiding rounding errors.
2. Specifying down/upscaling factors (below and above 1,
respectively) for each dimension of the TImage, or one scaling
respectively) for each dimension of the TField, or one scaling
factor for all dimensions. The main advantage of this method is
its simplicity, but may result in slightly different shapes than
expected due to rounding. Definitions of scale are relative to
......@@ -1640,11 +1995,11 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
Additional arguments that specify the scale factor for consecutive
voxel dimensions of the TField. If more than one scaling factor is
specified, the number of scaling factors must match the number of
spatial dimesions of the TField.
spatial dimensions of the TField.
:notes:
If the TImage does not have high-resolution image (and mask) data,
the requested scaling is defined relative to the current domain, and
If the TFIeld does not have high-resolution field data, the
requested scaling is defined relative to the current domain, and
data on the current domain will be set as the high-resolution
reference for further resampling operations, until it is released
by the 'reset_scale' method.
......@@ -1697,7 +2052,7 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
else:
raise te.ArgumentError(
"The number of scaling factors must match the number "
"of spatial (voxel) dimensions of the TImage ({})."
"of spatial (voxel) dimensions of the TField ({})."
.format(self.vdim))
else:
scales = (target,) * self.vdim
......@@ -1726,7 +2081,7 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
"Invalid target specification for resampling.")
# Now that the new domain is ready and the scaling factors are known,
# resample the image (and the mask) data.
# resample the field data.
# Nothing to do
if np.allclose(relative_scales, 1):
......
......@@ -103,7 +103,7 @@ class ResolutionManager(TIRLObject):
self._order = source.order
self._storage = source.storage
self._kwargs = source.kwargs
self._imgname = source.name
self._srcname = source.name
self._taxes = source.taxes
self._tshape = source.tshape
self._vaxes = source.vaxes
......@@ -229,7 +229,7 @@ class ResolutionManager(TIRLObject):
data = self._resample(imgdata, domain=new_domain, presmooth=presmooth)
if self._mask is not None:
maskdata = TField.fromarray(
arr=self._data.mask, tensor_axes=(), copy=False,
arr=self._mask.data, tensor_axes=(), copy=False,
domain=self.domain[:0], order=None, dtype=None,
interpolator=self._interpolator.__class__, name=None,
storage=self._storage, **self._kwargs)
......@@ -335,7 +335,7 @@ class ResolutionManager(TIRLObject):
@property
def imgname(self):
return self._imgname
return self._srcname
@property
def vshape(self):
......@@ -1603,7 +1603,14 @@ class TImage(TField):
obj._mask = self._mask.copy()
else:
obj._mask = None
obj.resmgr = ResolutionManager(obj) # copies the high-res template image!
# If the current instance has a higher-resolution image attached,
# pass the highres image by reference. No need to copy, because the
# current instance is not supposed to modify the high-resolution image.
# While there might be exceptions to this, this is a design choice to
# make copy() calls as fast as reasonably possible, as copy calls are
# much more common than the few exceptions mentioned before.
if self.resmgr is not None:
obj.resmgr = ResolutionManager(self.highres)
return obj
......
# from tirl import expose_package_contents
# from tirl.transformations.basic import Transformation
#
# # Expose all transformation types at the module level
# # Note: this routine imports the Transformation base class, and every
# # subclass thereof from all modules within the "transformations" package.
#
# expose_package_contents(baseclass=Transformation, pkg="tirl.transformations",
# path=__path__)
from tirl.transformations.transformation import Transformation
from tirl.transformations.transformation import TransformationGroup
# Note: these transformations are intended only for internal use!
from tirl import expose_package_contents
from tirl.transformations.transformation import Transformation
# The default libraries must be imported first.
from tirl.transformations.auxiliary.adapter import TxAdapter
from tirl.transformations.auxiliary.rotfield import TxRotationField
from tirl.transformations.auxiliary.scalefield import TxScaleField