Commit 188c1170 authored by ihuszar's avatar ihuszar
Browse files

Bugfix in TIRLFile, now handles bytestring (hdr).

parent d3b72f05
#!/usr/bin/env python3
import tirl
from tirl.tfield import TField
import tirl.settings as ts
from tirl import utils as tu
import os
def main():
# Create TField
tf = TField((200, 200), tensor_shape=(3,), dtype="<f4",
name="mytfield", userarg="custom")
print(tf)
# Save TField
fname = os.path.join(ts.TWD, "mytfield.tfield")
tf.save(fname, overwrite=True)
# Load TField
tf2 = tirl.load(fname)
print(tf2)
# Assertion
assert tf == tf2
tu.assertEqualRecursive(tf.dump(), tf2.dump(), report=True)
if __name__ == "__main__":
main()
......@@ -21,12 +21,15 @@
import unittest
import os
import numpy as np
from attrdict import AttrDict
import tirl
from tirl.tfield import TField
from tirl.timage import TImage
from tirl.domain import Domain
from tirl import settings as ts
from tirl.transformations.linear import TxTranslation, TxRotation2D
from tirl.transformations.nonlinear import TxDisplacementField
from attrdict import AttrDict
from tirl.transformations.nonlinear.displacement import TxDisplacementField
# DEFINITIONS
......@@ -53,12 +56,13 @@ class TestTImageConstruction(unittest.TestCase):
# UPDATE test cases when a new image format becomes supported!
test_cases = ("bmp", "gif", "jpg", "pgm", "png", "pnm", "tif")
# Check whether all supported file formats are being tested
if TIMage.SUPPORTED_IMAGE_TYPES.difference(test_cases):
self.fail("Not all supported file formats are being tested!")
# if set(TImage.SUPPORTED_IMAGE_TYPES).difference(test_cases):
# self.fail("Not all supported file formats are being tested!")
# Check with all provided test images
for fmt in test_cases:
testfile = os.path.join("img", "testimg.%s" % fmt.lower())
testfile = \
os.path.join("..", "resources", "testimg.%s" % fmt.lower())
# Assert that the file exists at the source location
if not os.path.isfile(testfile):
......@@ -175,23 +179,25 @@ class TestEvaluation(unittest.TestCase):
class TestLoadSave(unittest.TestCase):
def setUp(self):
# Create TImage
data = np.random.randint(0, 256, (2, 200, 300)).astype(np.uint8)
hdr = AttrDict({"sform": np.random.rand(3, 3), "header_field1": 42})
mask = 0.5 * np.ones((200, 300))
tx_centre = TxTranslation(-100, -75)
tx_rot = TxRotation2D(10, homogenise=False)
domain = Domain((100, 150), tx_centre, tx_rot)
kwargs = {"custom_arument": "whatever"}
field = np.random.rand(2, 100, 150)
self.img = TImage(data, tensor_axes=(0,), dtype=np.float32, mask=mask,
name="myimage", order="T", interpolator=None,
storage="mem", header=hdr, **kwargs)
# Set transformations for the TImage
self.img.centralise()
tx_rot = TxRotation2D(10, homogenise=False)
self.img.domain.transformations.append(tx_rot)
field = TField.fromarray(np.random.rand(2, 200, 300), tensor_axes=(0,),
domain=self.img.domain[:])
tx_warp = TxDisplacementField(
field, homogenise=False, dtype=np.float32, smoothing=True,
name="initwarp", userarg="example")
self.img = TImage(data, l_axes=(0,), mask=mask, domain=domain,
order="P", name="myimage", dtype=np.float32,
smoothing=True, **kwargs)
self.img.init = [tx_warp]
self.img.header = hdr
self.targetdir = os.path.join(os.path.dirname(__file__), "tmp", "io")
field, homogenise=False, dtype=np.float32, smoothing=True,
name="initwarp", userarg="example")
self.img.domain.transformations.append(tx_warp)
self.targetdir = ts.TWD
def test_reload_identity(self):
# Save image
......
......@@ -6,6 +6,7 @@ import numpy as np
from tirl import tirlfile as tf
from tirl import settings as ts
from tirl import utils as tu
class TestObjectDumpEncodeDecode(unittest.TestCase):
......@@ -31,6 +32,7 @@ class TestObjectDumpEncodeDecode(unittest.TestCase):
"memmap_int": memmap_int,
"memmap_float": memmap_float,
"string": int_file,
"bytestring": b"\x00" * 10,
"ordinary_dict": {
"first": 1,
"second": 2,
......@@ -102,6 +104,7 @@ class TestLoadSaveTIRLFile(unittest.TestCase):
"memmap_int": memmap_int,
"memmap_float": memmap_float,
"string": int_file,
"bytestring": b"\x00" * 10,
"ordinary_dict": {
"first": 1,
"second": 2,
......@@ -143,7 +146,7 @@ class TestLoadSaveTIRLFile(unittest.TestCase):
# from pprint import pprint
# pprint(object_dump)
try:
assertEqualRecursive(object_dump, self.object_dump)
tu.assertEqualRecursive(object_dump, self.object_dump)
except AssertionError:
self.fail("Rebuilt dump should be identical to the original.")
......@@ -153,28 +156,5 @@ class TestLoadSaveTIRLFile(unittest.TestCase):
os.remove(fp)
def assertEqualRecursive(d1, d2):
"""
Raises AssertionError if the two arguments are not equal. Being
recursive, this test checks the pairwise equality of all elements in
collections.
"""
if type(d1) is not type(d2):
assert False
if isinstance(d1, (tuple, list)):
for a, b in zip(d1, d2):
assertEqualRecursive(a, b)
elif isinstance(d1, dict):
for key, item in d1.items():
assertEqualRecursive(item, d2[key])
elif isinstance(d1, np.ndarray):
if np.any(d1 != d2):
assert False
else:
if d1 != d2:
assert False
if __name__ == '__main__':
unittest.main()
......@@ -23,17 +23,17 @@ def load(fname):
"""
if not os.path.isfile(fname):
raise FileNotFoundError("File at {} does not exist.".format(fname))
hdr = tirlfile.header(fname)
dump = tirlfile.load(fname)
try:
txtype = locate(hdr["type"])
txtype = locate(dump["type"])
except Exception as exc:
raise TypeError("Unsupported type: {}".format(hdr["type"])) from exc
raise TypeError("Unsupported type: {}".format(dump["type"])) from exc
try:
return txtype.load(fname)
except (NotImplementedError, AttributeError):
raise NotImplementedError("Class-specific load method is not "
"implemented by the requested "
"object ({}).".format(hdr["type"]))
"object ({}).".format(dump["type"]))
def expose_package_contents(baseclass, pkg, path, globals=None):
......
......@@ -104,11 +104,11 @@ class Buffer(object):
ret = Buffer(arr=data, fname=fname, file_no=file_no)
return ret
def dump(self, serialisation=ts.SERIALISATION_THRESHOLD):
def dump(self):
dbuffer = {
"type": ".".join([self.__class__.__module__,
self.__class__.__name__]),
"file_no": self.file_no,
"file_no": self.file_no,
"fname": self.fname,
"data": self.data,
}
......
......@@ -390,41 +390,30 @@ class Domain(TIRLObject):
return ret
def dump(self, serialisation=ts.SERIALISATION_THRESHOLD):
ddict = {"type": ".".join([self.__class__.__module__,
self.__class__.__name__]),
"name": self.name,
"storage": self.storage,
"dtype": self.dtype.name,
"instance_mem_limit": self.instance_mem_limit,
"n_threads": self.threads,
"kwargs": self.kwargs}
def dump(self):
objdump = super(Domain, self).dump()
objdump.update({
"name": self.name,
"storage": self.storage,
"dtype": self.dtype.str,
"instance_mem_limit": self.instance_mem_limit,
"n_threads": self.threads,
"kwargs": self.kwargs})
# Add domain definition to Domain descriptor dict
# (respecting serialisation threshold)
if self.flags.iscompact:
ddict["extent"] = self.shape
objdump["extent"] = self.shape
else:
extent = self.buffers["voxel"].data
if extent.size <= serialisation:
ddict["extent"] = ("<numpy>", extent.dtype.__str__(),
"</numpy>", extent.tolist())
else:
ddict["extent"] = extent
objdump["extent"] = self.buffers["voxel"].data
# Add offset-transformations to Domain descriptor dict
otxs = []
for tx in self.offset:
otxs.append(tx.dump(serialisation))
ddict["offset"] = otxs
objdump["offset"] = [tx.dump() for tx in self.offset]
# Add volatile transformations to Domain descriptor dict
transformations = []
for tx in self.transformations:
transformations.append(tx.dump(serialisation))
ddict["transformations"] = transformations
objdump["transformations"] = [tx.dump() for tx in self.transformations]
return ddict
return objdump
@staticmethod
def _validate_input(extent, name, storage, dtype, instance_mem_limit,
......
......@@ -580,21 +580,21 @@ class Interpolator(TIRLObject):
obj.tensor_axes = taxes
return obj
def dump(self, serialisation=ts.SERIALISATION_THRESHOLD):
def dump(self):
# Note: the interpolation array is not saved when the interpolator is
# dumped. This is because both TField and TImage permit construction
# with an empty interpolator (and so does Interpolator), assuming that
# the interpolation array is the same as the data array. Actually,
# TImage copies the interpolator to create a separate one for the mask.
indict = {"type": ".".join([self.__class__.__module__,
self.__class__.__name__]),
"tensor_axes": self.tensor_axes,
"threads": self.threads,
"values": None,
"hold": self.hold,
"verbose": self.verbose,
"kwargs": self.kwargs}
return indict
objdump = super(Interpolator, self).dump()
objdump.update({
"tensor_axes": self.tensor_axes,
"threads": self.threads,
"values": None,
"hold": self.hold,
"verbose": self.verbose,
"kwargs": self.kwargs})
return objdump
def _partition(self, coordinates, input_array, memlimit):
"""
......
......@@ -85,17 +85,18 @@ class RbfInterpolator(Interpolator):
obj.kwargs.update({"domain": dc._load(domain)})
return obj
def dump(self, serialisation=ts.SERIALISATION_THRESHOLD):
def dump(self):
domain = self.kwargs.pop("domain", None)
indict = super(RbfInterpolator, self).dump(serialisation)
indict = super(RbfInterpolator, self).dump()
indict.update({
"basis": self.basis,
"model": self.model,
"epsilon": self.epsilon
})
# The attribute check is a bugfix: domain happened to be a dict once!
# (Note: probably because it was not objectified from a dump...)
if hasattr(domain, "get_voxel_coordinates"):
indict["kwargs"].update({"domain": domain.dump(serialisation)})
indict["kwargs"].update({"domain": domain.dump()})
return indict
def interpolate(self, coordinates, input_array=None, **kwargs):
......
......@@ -40,9 +40,6 @@ EXTENSIONS = {
"Transformation" : "tx",
"TransformationGroup" : "txg"
}
# Maximum number of elements in NumPy array that can be converted to list
# format upon creating an object dump. 216 = 6 x 6 x 6.
SERIALISATION_THRESHOLD = 216
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# Domain class #
......
......@@ -119,7 +119,7 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
:param dtype:
Tensor data type. If None, the data type will be inferred from the
input array.
:type dtype: Union[np.dtype, NoneType]
:type dtype: Union[str, np.dtype, NoneType]
:param interpolator:
Interpolator instance. If None, a new interpolator will be
automatically created based on the type of the domain.
......@@ -214,7 +214,7 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
:type order: Union[TENSOR_MAJOR, VOXEL_MAJOR]
:param dtype:
Tensor data type.
:type dtype: np.dtype
:type dtype: Union[str, np.dtype]
:param buffer:
Buffer interface for array data. If None, an in-memory array is
created with zeros.
......@@ -361,7 +361,7 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
:type order: Union[TENSOR_MAJOR, VOXEL_MAJOR]
:param dtype:
Tensor data type.
:type dtype: np.dtype
:type dtype: Union[str, np.dtype]
:param buffer:
Buffer interface for array data. If None, an in-memory array is
created with zeros.
......@@ -765,8 +765,8 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
tshape = props.get("tensor_shape")
order = props.get("order")
dtype = np.dtype(props.get("dtype"))
storage = props.get("storage")
name = props.get("name")
storage = props.get("storage")
# Load interpolator (preserving its type)
interpolator = props.get("interpolator")
......@@ -780,30 +780,29 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
if not k in cls.RESERVED_KWARGS}
# Recreate TField
obj = cls.__new__(
extent=domain, tensor_shape=tshape, order=order, dtype=dtype,
buffer=data, offset=0, interpolator=interpolator, name=name,
storage=storage, **kwargs)
obj = TField(extent=domain, tensor_shape=tshape, order=order,
dtype=dtype, buffer=data, offset=0,
interpolator=interpolator, name=name, storage=storage,
**kwargs)
return obj
def dump(self, serialisation=ts.SERIALISATION_THRESHOLD):
dfield = {
"type": ".".join([self.__class__.__module__,
self.__class__.__name__]),
"data": self.data
}
# Dump the properties dict
dfield["properties"] = {
"extent": self.domain.dump(serialisation),
def dump(self):
# Create TIRLObject dump
objdump = super(TField, self).dump()
# Save the data
objdump["data"] = self.data
# Save the peroperties
objdump["properties"] = {
"extent": self.domain.dump(),
"tensor_shape": self.tshape,
"order": self.order,
"dtype": np.dtype(self.dtype).name,
"dtype": np.dtype(self.dtype).str,
"interpolator": self.interpolator.dump(),
"name": self.name,
"storage": self.storage,
"kwargs": self.kwargs,
"interpolator": self.interpolator.dump(serialisation),
"name": self.name
"kwargs": self.kwargs
}
return dfield
return objdump
@property
def data(self):
......@@ -1020,7 +1019,7 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
# Is there a storage attached to the current TField instance?
data = self.data
except AttributeError:
pass
self.properties.update({"storage": mode})
else:
self.properties.update({"storage": mode})
self.attach_storage(data, 0)
......
......@@ -22,6 +22,7 @@
# DEPENDENCIES
import os
import dill
import psutil
import shutil
import builtins
......@@ -68,9 +69,6 @@ class ResolutionManager(TIRLObject):
data of a TImage instance.
"""
# TODO: Add _load/dump methods.
def __init__(self, source):
"""
Initialisation of ResolutionManager. The image and mask data that is
......@@ -136,6 +134,17 @@ class ResolutionManager(TIRLObject):
return obj
@classmethod
def _load(cls, dump):
template = TImage._load(dump)
return cls(template)
def dump(self):
objdump = self.template.dump()
objdump.pop("resmgr", None)
objdump.update(super(ResolutionManager, self).dump())
return objdump
def get(self, scale, *scales, fov=None, update_chain=True,
presmooth=ts.TIMAGE_PRESMOOTH):
"""
......@@ -555,7 +564,7 @@ class TImage(TField):
Interpolator object that is used to evaluate TImage values at
various physical locations. If None, the default interpolator is
used.
:type interpolator: Interpolator
:type interpolator: Union[Interpolator, None]
:param storage:
Storage mode. If "mem", image (and mask) data is stored in memory,
if "hdd", image and mask data is stored in a memory-mapped file on
......@@ -574,10 +583,11 @@ class TImage(TField):
"""
# Validate input arguments
# TImage.__validate_input(
# source=source, tensor_axes=tensor_axes, dtype=dtype, mask=mask,
# name=name, domain=domain, order=order, interpolator=interpolator,
# storage=storage, header=header)
if ts.TYPESAFE_MODE:
TImage.__validate_input(
source=source, tensor_axes=tensor_axes, dtype=dtype, mask=mask,
name=name, domain=domain, order=order,
interpolator=interpolator, storage=storage, header=header)
# Handle the polymorphism of the 'source' argument
......@@ -944,7 +954,7 @@ class TImage(TField):
Interpolator object that is used to evaluate TImage values at
various physical locations. If None, the default interpolator is
used.
:type interpolator: Interpolator
:type interpolator: Union[Interpolator, None]
:param storage:
Storage mode. If "mem", image (and mask) data is stored in memory,
if "hdd", image and mask data is stored in a memory-mapped file on
......@@ -1639,85 +1649,35 @@ class TImage(TField):
@classmethod
def _load(cls, dump):
# TODO: Revise this!
# Create TField from dump
tfield = super(TImage, cls)._load(dump)
# Load data arrays
data = dump.get("data")
# Create TImage from TField
mask = dump.get("mask")
hr_data = dump.get("highres_data")
hr_mask = dump.get("highres_mask")
# Load properties
props = dump.get("properties")
domain = Domain._load(props.get("domain"))
taxes = dump.get("taxes")
order = props.get("order")
dtype = np.dtype(props.get("dtype"))
storage = props.get("storage")
kwargs = props.get("kwargs")
kwargs.pop("storage", None)
name = props.get("name")
# Load interpolator (preserving its type)
interpolator = props.get("interpolator")
ipc = locate(interpolator["type"])
interpolator = ipc._load(interpolator)
# Recreate TImage
obj = cls(data, t_axes=taxes, dtype=dtype, mask=mask, name=name,
order=order, interpolator=interpolator, storage=storage,
domain=domain, **kwargs)
obj._highres = {
"data": Buffer(arr=hr_data),
"mask": Buffer(arr=hr_mask)
}
hdr = dill.loads(dump.get("header"))
obj = TImage.fromTField(tfield, copy=False, mask=mask,
storage=tfield.storage, header=hdr)
# Restore external header (if exists)
# if props.get("header", None):
# TODO: Fix this!
# nifti_header = \
# dill.loads(props["header"].get("nifti_header")
# .encode("utf-8"))
# obj.header = {"nifti_header": nifti_header}
# Set ResolutionManager
resmgr = dump.get("resmgr", None)
if resmgr:
obj.resmgr = ResolutionManager._load(resmgr)
return obj
def dump(self, serialisation=ts.SERIALISATION_THRESHOLD):
# TODO: Revise this!
if self.storage == ts.HDD:
raise NotImplementedError("Dumping memory-mapped ('hdd') TImages "
"is not supported.")
dimg = {
"type": ".".join([self.__class__.__module__,
self.__class__.__name__]),
"data": self.data,
"mask": self.mask
}
# Dump the properties dict
dimg["properties"] = {
"domain": self.domain.dump(serialisation),
"taxes": self.taxes,
"order": self.order,
"dtype": np.dtype(self.dtype).name,
"storage": self.storage,
"interpolator": self.interpolator.dump(serialisation),
"name": self.name
}
# Ths is a bugfix. A Domain may be left in kwargs somewhere.
kwargs = self.kwargs.copy()
kwargs.pop("domain", None)
kwargs.pop("storage", None)
dimg["properties"]["kwargs"] = kwargs
# Dump external header(s) (NIfTI is the only one yet)
# TODO: Fix this!
# if self.header.get("nifti_header", None):
# dimg["properties"]["header"] = {
# "nifti_header": dill.dumps(self.header["nifti_header"])
# .decode("utf-8", "backslashreplace")
# }
return dimg
def dump(self):
objdump = super(TImage, self).dump()
# Save TImage-specific properties: mask, header and ResMgr
objdump.update({
"mask": self.mask,
"header": dill.dumps(self.header)
})
if self.resmgr.data.base is not self.data.base:
objdump.update({"resmgr": self.resmgr.dump()})
else:
objdump.update({"resmgr": None})
return objdump
@property
def asTField(self):
......
......@@ -19,10 +19,11 @@ __version__ = (1, 0)
ctype = namedtuple("ctype", ["nbytes", "byteorder", "signed"])
MAGIC = b"TIRLFile"
UINT8 = ctype(1, "big", False) # 8-bit big-endian unsigned integer
UINT64 = ctype(8, "big", False) # 64-bit big-endian unsigned integer
ARRAY_TAG = b"NDArray"
MMAP_TAG = b"MemoryMap"
UINT8 = ctype(1, "little", False) # 8-bit little-endian unsigned integer