Commit 60028036 authored by ihuszar

Finished hierarchical TIRLObject dumping and loading.

parent 188c1170
This was an idea to create a TIRLFile object that manages file I/O within TIRL.
The idea was ultimately discarded in favour of the functional approach to file I/O.
Removed: 27 January 2020
# class TIRLFile(object):
# """
# TIRLFile object for managing read and write access to TIRLFiles.
#
# """
# __version__ = (1, 0) # TIRLfile version
#
# def __init__(self, f):
# """
# TIRLFile initialisation.
#
# :param f: path to TIRLFile
# :type f: str
#
# """
# # Does the file exist?
# if os.path.isfile(f):
# self.file = f
# else:
# raise FileNotFoundError("File not found: {}".format(f))
#
# # Is it compatible?
# self.check_version()
#
# def check_version(self):
#         """ Check that the file version is compatible with this reader. """
# with open(self.file, "rb") as f:
# try:
# magic = f.read(len(MAGIC)) # magic string
# major, minor = bytes2int(f.read(2), UINT8) # version
# except EOFError:
# raise AssertionError("Invalid TIRLFile.")
# if magic != MAGIC:
# raise AssertionError("Invalid TIRLFile.")
#         reader_major, reader_minor = self.__version__  # tuple, not a str
#         if (major, minor) > (reader_major, reader_minor):
#             raise AssertionError(
#                 "File is newer than reader version {}.{}."
#                 .format(reader_major, reader_minor))
# self._version = (major, minor)
#
# @property
# def version(self):
# return ".".join(str(v) for v in self._version)
#
# def get_header(self):
# pass
@@ -183,7 +183,7 @@ class TestLoadSave(unittest.TestCase):
data = np.random.randint(0, 256, (2, 200, 300)).astype(np.uint8)
hdr = AttrDict({"sform": np.random.rand(3, 3), "header_field1": 42})
mask = 0.5 * np.ones((200, 300))
kwargs = {"custom_arument": "whatever"}
kwargs = {"custom_argument": "whatever"}
self.img = TImage(data, tensor_axes=(0,), dtype=np.float32, mask=mask,
name="myimage", order="T", interpolator=None,
storage="mem", header=hdr, **kwargs)
@@ -208,6 +208,8 @@ class TestLoadSave(unittest.TestCase):
# Load image
newimg = tirl.load(target_file)
self.assertEqual(self.img, newimg)
from tirl import utils as tu
tu.assertEqualRecursive(self.img.__dict__, newimg.__dict__)
# Remove file
os.remove(target_file)
......
@@ -55,7 +55,7 @@ class TestObjectDumpEncodeDecode(unittest.TestCase):
self.open_files = [(fi, int_file), (ff, float_file)]
def test_encode(self):
encoded_dump, replacements, counter = tf.encode(self.object_dump)
encoded_dump, replacements = tf.encode(self.object_dump)
self.objcode = (encoded_dump, replacements)
# The unencoded dump should be non-JSON-serialisable
with self.assertRaises(TypeError):
@@ -67,7 +67,7 @@ class TestObjectDumpEncodeDecode(unittest.TestCase):
self.fail("The encoded object dump must be JSON-serialisable.")
def test_decode(self):
encoded_dump, replacements, counter = tf.encode(self.object_dump)
encoded_dump, replacements = tf.encode(self.object_dump)
decoded_dump = tf.decode(encoded_dump, replacements)
# The encoding-decoding process must not modify the initial dump
self.assertDictEqual(self.object_dump, self.original_dump)
......
@@ -3,6 +3,7 @@ from pydoc import locate
from tirl import settings as ts
from tirl import tirlfile
from tirl import tirlobject
if ts.ENABLE_VISUALISATION:
from tirl import tirlvision
@@ -21,19 +22,20 @@ def load(fname):
:rtype: TIRLObject
"""
if not os.path.isfile(fname):
raise FileNotFoundError("File at {} does not exist.".format(fname))
dump = tirlfile.load(fname)
try:
txtype = locate(dump["type"])
except Exception as exc:
raise TypeError("Unsupported type: {}".format(dump["type"])) from exc
try:
return txtype.load(fname)
except (NotImplementedError, AttributeError):
raise NotImplementedError("Class-specific load method is not "
"implemented by the requested "
"object ({}).".format(dump["type"]))
return tirlobject.TIRLObject.load(fname)
# if not os.path.isfile(fname):
# raise FileNotFoundError("File at {} does not exist.".format(fname))
# dump = tirlfile.load(fname)
# try:
# txtype = locate(dump["type"])
# except Exception as exc:
# raise TypeError("Unsupported type: {}".format(dump["type"])) from exc
# try:
# return txtype.load(fname)
# except (NotImplementedError, AttributeError):
# raise NotImplementedError("Class-specific load method is not "
# "implemented by the requested "
# "object ({}).".format(dump["type"]))
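The body of TIRLObject.load is not shown in this diff. Below is a minimal sketch of the type dispatch it presumably centralises, based on the locate(dump["type"])._load(dump) pattern used elsewhere in this commit (load_by_type is a hypothetical name, not part of the library):

from pydoc import locate

def load_by_type(dump):
    # Resolve the dotted class path stored in the dump, e.g.
    # "tirl.domain.Domain", back to the class object.
    cls = locate(dump["type"])
    if cls is None:
        raise TypeError("Unsupported type: {}".format(dump["type"]))
    # Delegate reconstruction to the class-specific _load method.
    return cls._load(dump)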
def expose_package_contents(baseclass, pkg, path, globals=None):
......
@@ -104,10 +104,11 @@ class Buffer(object):
ret = Buffer(arr=data, fname=fname, file_no=file_no)
return ret
def dump(self):
def _dump(self):
dbuffer = {
"type": ".".join([self.__class__.__module__,
self.__class__.__name__]),
"id": str(id(self)),
"file_no": self.file_no,
"fname": self.fname,
"data": self.data,
......
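The new "id" field, together with the "fixed traversing order (needed for object redundancy mapping)" note in encode() below, suggests that shared objects are dumped once and back-referenced thereafter. A hypothetical sketch of that mechanism (dump_once and the registry are assumed names, not part of the commit):

def dump_once(obj, registry):
    # registry maps id-strings to the descriptor of the first encounter
    key = str(id(obj))
    if key in registry:
        return {"type": "ref", "id": key}   # back-reference, no re-dump
    descriptor = {"type": type(obj).__module__ + "." + type(obj).__name__,
                  "id": key}                # shallow _dump-style descriptor
    registry[key] = descriptor
    return descriptor

registry = dict()
shared = object()
first = dump_once(shared, registry)    # full descriptor
second = dump_once(shared, registry)   # {"type": "ref", "id": ...}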
@@ -363,18 +363,20 @@ class Domain(TIRLObject):
kwargs = dump.get("kwargs")
# Load offsets
otx_descriptors = dump.get("offset")
otxs = []
for otx_descr in otx_descriptors:
tx = locate(otx_descr["type"])._load(otx_descr)
otxs.append(tx)
# otx_descriptors = dump.get("offset")
# otxs = []
# for otx_descr in otx_descriptors:
# tx = locate(otx_descr["type"])._load(otx_descr)
# otxs.append(tx)
otxs = dump.get("offset")
# Load transformations
tx_descriptors = dump.get("transformations")
transformations = []
for dtx in tx_descriptors:
tx = locate(dtx["type"])._load(dtx)
transformations.append(tx)
# tx_descriptors = dump.get("transformations")
# transformations = []
# for dtx in tx_descriptors:
# tx = locate(dtx["type"])._load(dtx)
# transformations.append(tx)
transformations = dump.get("transformations")
# Obtain domain definition from object dump
extent = dump.get("extent")
@@ -390,8 +392,8 @@ class Domain(TIRLObject):
return ret
def dump(self):
objdump = super(Domain, self).dump()
def _dump(self):
objdump = super(Domain, self)._dump()
objdump.update({
"name": self.name,
"storage": self.storage,
@@ -411,6 +413,9 @@ class Domain(TIRLObject):
objdump["offset"] = [tx.dump() for tx in self.offset]
# Add volatile transformations to Domain descriptor dict
        # Note: calling the dump method for transformations instead of _dump,
        # because transformations may themselves have delegated TIRLObjects
        # (e.g. a domain), which must also be dumped.
objdump["transformations"] = [tx.dump() for tx in self.transformations]
return objdump
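A minimal sketch of the dump/_dump contract described in the note above, assuming a transformation that delegates a domain (class names are illustrative, not taken verbatim from the library):

class DomainSketch:
    def _dump(self):
        return {"type": "tirl.domain.Domain", "id": str(id(self))}

    def dump(self):
        return self._dump()     # no further delegates in this sketch

class TransformationSketch:
    def __init__(self, domain=None):
        self.domain = domain    # delegated TIRLObject (optional)

    def _dump(self):
        # Shallow: only this object's own identity and fields.
        return {"type": "tirl.transformation.Transformation",
                "id": str(id(self))}

    def dump(self):
        # Deep: also serialise delegated objects via their own dump().
        d = self._dump()
        if self.domain is not None:
            d["domain"] = self.domain.dump()
        return d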
......
@@ -580,13 +580,13 @@ class Interpolator(TIRLObject):
obj.tensor_axes = taxes
return obj
def dump(self):
def _dump(self):
# Note: the interpolation array is not saved when the interpolator is
# dumped. This is because both TField and TImage permit construction
# with an empty interpolator (and so does Interpolator), assuming that
# the interpolation array is the same as the data array. Actually,
# TImage copies the interpolator to create a separate one for the mask.
objdump = super(Interpolator, self).dump()
objdump = super(Interpolator, self)._dump()
objdump.update({
"tensor_axes": self.tensor_axes,
"threads": self.threads,
......
@@ -80,14 +80,15 @@ class RbfInterpolator(Interpolator):
obj.epsilon = dump.get("epsilon")
domain = dump["kwargs"].get("domain", None)
if domain is not None:
from pydoc import locate
dc = locate(domain["type"])
obj.kwargs.update({"domain": dc._load(domain)})
obj.kwargs.update({"domain": domain})
# if domain is not None:
# from pydoc import locate
# dc = locate(domain["type"])
# obj.kwargs.update({"domain": dc._load(domain)})
return obj
def dump(self):
domain = self.kwargs.pop("domain", None)
indict = super(RbfInterpolator, self).dump()
def _dump(self):
indict = super(RbfInterpolator, self)._dump()
indict.update({
"basis": self.basis,
"model": self.model,
@@ -95,6 +96,7 @@ class RbfInterpolator(Interpolator):
})
        # The attribute check below is a bugfix: domain was once found to be a
        # dict (probably because it had not been objectified from a dump).
domain = self.kwargs.pop("domain", None)
if hasattr(domain, "get_voxel_coordinates"):
indict["kwargs"].update({"domain": domain.dump()})
return indict
......
@@ -103,6 +103,7 @@ TIMAGE_MASK_INTERPOLATOR = \
# Presmoothing before resampling: kernel size defined as the number of sigmas
TIMAGE_PRESMOOTH = True
TIMAGE_PRESMOOTH_KERNELSIZE_NSIGMA = 3
TIMAGE_DEFAULT_SNAPSHOT_EXT = "png"
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# Operations #
......
@@ -307,11 +307,13 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
# will be scrambled, so the use of the TField.fromarray() constructor
# instead is highly encouraged.
buffer_requires = (extent, tensor_shape, order)
if (buffer is not None) and any(arg is None for arg in buffer_requires):
raise ValueError("TField constructor requires input for 'extent', "
"'tensor_shape', 'order', and 'dtype' when a "
"buffer is specified. Consider using the "
"TField.fromarray() constructor instead.")
if buffer is not None:
if any(arg is None for arg in buffer_requires):
raise ValueError(
"TField constructor requires input for 'extent', "
"'tensor_shape', 'order', and 'dtype' when a buffer is "
"specified. Consider using the TField.fromarray() "
"constructor instead.")
obj.order = \
order or obj.properties.get("order", ts.TFIELD_DEFAULT_LAYOUT)
obj.dtype = \
@@ -761,7 +763,8 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
# Load properties
props = dump.get("properties")
domain = Domain._load(props.get("extent"))
# domain = Domain._load(props.get("extent"))
domain = props.get("extent")
tshape = props.get("tensor_shape")
order = props.get("order")
dtype = np.dtype(props.get("dtype"))
@@ -770,9 +773,9 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
# Load interpolator (preserving its type)
interpolator = props.get("interpolator")
from pydoc import locate
ipc = locate(interpolator["type"])
interpolator = ipc._load(interpolator)
# from pydoc import locate
# ipc = locate(interpolator["type"])
# interpolator = ipc._load(interpolator)
# Additional keyword arguments
kwargs = props.get("kwargs")
@@ -786,18 +789,18 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
**kwargs)
return obj
def dump(self):
def _dump(self):
# Create TIRLObject dump
objdump = super(TField, self).dump()
objdump = super(TField, self)._dump()
# Save the data
objdump["data"] = self.data
# Save the peroperties
# Save the properties
objdump["properties"] = {
"extent": self.domain.dump(),
"extent": self.domain.dump(), # calling dump, not _dump!
"tensor_shape": self.tshape,
"order": self.order,
"dtype": np.dtype(self.dtype).str,
"interpolator": self.interpolator.dump(),
"interpolator": self.interpolator.dump(), # dump, not _dump!
"name": self.name,
"storage": self.storage,
"kwargs": self.kwargs
@@ -1185,9 +1188,15 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
buffer=buffer, offset=offset)
# Convert to in-memory array if necessary
if isinstance(buffer, np.memmap) and (self.storage == ts.MEM):
if buffer.size != self.numel:
raise ValueError("Buffer size does not match the shape of "
"the {}.".format(self.__class__.__name__))
self._data = Buffer(np.asarray(buffer, dtype=self.dtype))
# Convert to memory-mapped hard disk image if necessary
elif not isinstance(buffer, np.memmap) and (self.storage == ts.HDD):
if reduce(mul, buffer.shape) != self.numel:
raise ValueError("Buffer size does not match the shape of {}."
.format(self.__class__.__name__))
fd, fname = tempfile.mkstemp(dir=ts.TWD, prefix="TField_")
m = np.memmap(fname, mode="r+", shape=buffer.shape,
dtype=self.dtype, order="C")
@@ -1201,6 +1210,10 @@ class TField(TIRLObject, np.lib.mixins.NDArrayOperatorsMixin):
buffer = self._change_memmap_dtype(buffer, self.dtype)
else:
buffer = np.asarray(buffer, dtype=self.dtype)
if buffer.size != self.numel:
raise ValueError("Buffer size does not match the shape of the "
"{}.".format(self.__class__.__name__))
buffer = buffer.reshape(self.shape)
self._data = Buffer(buffer, fname=getattr(buffer, "filename", None))
def attach_default_storage(self):
......
@@ -139,10 +139,9 @@ class ResolutionManager(TIRLObject):
template = TImage._load(dump)
return cls(template)
def dump(self):
objdump = self.template.dump()
objdump.pop("resmgr", None)
objdump.update(super(ResolutionManager, self).dump())
def _dump(self):
objdump = self.template.dump(resmgr=False) # avoid infinite loop
objdump.update(super(ResolutionManager, self)._dump()) # type, id
return objdump
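A hypothetical illustration of the reference cycle broken by resmgr=False: the manager's template is the image that owns the manager, so an unguarded dump would recurse forever (ImageSketch and ManagerSketch are invented names):

class ImageSketch:
    def __init__(self):
        self.resmgr = ManagerSketch(self)

    def dump(self, resmgr=True):
        d = {"kind": "image"}
        if resmgr:    # guard: skip when called back by the manager
            d["resmgr"] = self.resmgr.dump()
        return d

class ManagerSketch:
    def __init__(self, template):
        self.template = template    # the image that owns this manager

    def dump(self):
        return self.template.dump(resmgr=False)    # avoid infinite loop

ImageSketch().dump()    # {"kind": "image", "resmgr": {"kind": "image"}}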
def get(self, scale, *scales, fov=None, update_chain=True,
@@ -1609,7 +1608,6 @@ class TImage(TField):
:rtype: TImage
"""
# TODO: purge kwargs before TImage constructor is called!
# Set parameters of the new object before construction
attributes = {
"tensor_axes": self.taxes,
@@ -1627,6 +1625,8 @@ class TImage(TField):
# Create new object with updated parameters
attributes.update(self.kwargs.copy())
attributes.update(kwargs)
attributes = {k: v for (k, v) in attributes.items()
if k not in self.RESERVED_KWARGS}
obj = TImage.fromarray(self.data, copy=True, **attributes)
# Set mask and resolution manager
@@ -1640,9 +1640,8 @@ class TImage(TField):
# While there might be exceptions to this, this is a design choice to
# make copy() calls as fast as reasonably possible, as copy calls are
# much more common than the few exceptions mentioned before.
# TODO: Is this OK?
# if self.resmgr is not None:
# obj.resmgr = self.resmgr
if self.resmgr is not None:
obj.resmgr = self.resmgr
return obj
@@ -1661,18 +1660,19 @@ class TImage(TField):
# Set ResolutionManager
resmgr = dump.get("resmgr", None)
if resmgr:
obj.resmgr = ResolutionManager._load(resmgr)
# obj.resmgr = ResolutionManager._load(resmgr)
obj.resmgr = resmgr
return obj
def dump(self):
objdump = super(TImage, self).dump()
def _dump(self, resmgr=True):
objdump = super(TImage, self)._dump() # type, id
# Save TImage-specific properties: mask, header and ResMgr
objdump.update({
"mask": self.mask,
"header": dill.dumps(self.header)
})
if self.resmgr.data.base is not self.data.base:
if resmgr and (self.resmgr.data.base is not self.data.base):
objdump.update({"resmgr": self.resmgr.dump()})
else:
objdump.update({"resmgr": None})
@@ -2456,7 +2456,7 @@ class TImage(TField):
fn, ext = os.path.splitext(fname)
fn = fn.rstrip(".")
ext = ext.lstrip(".")
ext = "png" if not ext else ext
ext = ts.TIMAGE_DEFAULT_SNAPSHOT_EXT if not ext else ext
imgdata = tu.to_img(self.data)
if (self.order == "T") and self.taxes:
imgdata = np.moveaxis(imgdata, 0, -1)
......
@@ -2,6 +2,7 @@
# DEPENDENCIES
import re
import os
import json
import inspect
@@ -77,55 +78,7 @@ def int2bytes(i, int_type):
return b"".join(result)
# class TIRLFile(object):
# """
# TIRLFile object for managing read and write access to TIRLFiles.
#
# """
# __version__ = (1, 0) # TIRLfile version
#
# def __init__(self, f):
# """
# TIRLFile initialisation.
#
# :param f: path to TIRLFile
# :type f: str
#
# """
# # Does the file exist?
# if os.path.isfile(f):
# self.file = f
# else:
# raise FileNotFoundError("File not found: {}".format(f))
#
# # Is it compatible?
# self.check_version()
#
# def check_version(self):
#         """ Check that the file version is compatible with this reader. """
# with open(self.file, "rb") as f:
# try:
# magic = f.read(len(MAGIC)) # magic string
# major, minor = bytes2int(f.read(2), UINT8) # version
# except EOFError:
# raise AssertionError("Invalid TIRLFile.")
# if magic != MAGIC:
# raise AssertionError("Invalid TIRLFile.")
#         reader_major, reader_minor = self.__version__  # tuple, not a str
#         if (major, minor) > (reader_major, reader_minor):
#             raise AssertionError(
#                 "File is newer than reader version {}.{}."
#                 .format(reader_major, reader_minor))
# self._version = (major, minor)
#
# @property
# def version(self):
# return ".".join(str(v) for v in self._version)
#
# def get_header(self):
# pass
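For reference, a functional equivalent of the discarded check_version, in line with the commit's stated preference for functional file I/O; the magic value and the reader version used here are assumptions, not the library's actual constants:

READER_VERSION = (1, 0)    # assumed reader version

def check_version(fname, magic=b"TIRL"):    # assumed magic string
    with open(fname, "rb") as f:
        header = f.read(len(magic) + 2)    # magic + (major, minor) bytes
    if len(header) < len(magic) + 2 or not header.startswith(magic):
        raise AssertionError("Invalid TIRLFile.")
    major, minor = header[len(magic)], header[len(magic) + 1]
    if (major, minor) > READER_VERSION:
        raise AssertionError(
            "File version {}.{} is newer than reader version {}.{}."
            .format(major, minor, *READER_VERSION))
    return major, minor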
def save_replacements(f, replacements, compressed=True):
def save_replacements(f, replacements, compressed=False):
"""
Writes the replaced items to a file.
@@ -134,6 +87,8 @@ def save_replacements(f, replacements, compressed=True):
:param replacements: dictionary of replaced non-serialisable elements
:type replacements: dict
:param compressed:
Note: This feature is not yet available in TIRL 1.0, i.e. all data is
saved to file without compression.
If True, in-memory numerical arrays will be saved to the file after
compression, otherwise they are saved as raw byte arrays. Note that
memory-mapped disk arrays will be saved in non-compressed format
@@ -141,6 +96,8 @@ def save_replacements(f, replacements, compressed=True):
:type compressed: bool
"""
replacements = replacements["decoded_objects"]
# Create lookup header
f.write(LH_TAG)
@@ -162,26 +119,11 @@ def save_replacements(f, replacements, compressed=True):
item = replacements[key]
offsets.append(f.tell())
if isinstance(item, np.memmap):
save_memmap(f, item)
save_memmap(f, item, compressed=compressed)
elif isinstance(item, np.ndarray):
f.write(ARRAY_TAG)
array_size_field = f.tell()
f.write(b"\x00" * UINT64.nbytes)
f.flush()
os.fsync(f.fileno())
np.save(f, item)
# if compressed:
# np.savez_compressed(f, item)
# else:
# np.savez(f, item)
f.flush()
os.fsync(f.fileno())
eof = f.tell()
f.seek(array_size_field)
f.write(int2bytes(eof - array_size_field - UINT64.nbytes, UINT64))
f.seek(eof)
save_array(f, item, compressed=compressed)
elif isinstance(item, bytes):
save_bytes(f, item)
save_bytes(f, item, compressed=compressed)
else:
raise NotImplementedError()
@@ -190,7 +132,35 @@ def save_replacements(f, replacements, compressed=True):
f.write(b"".join(int2bytes(val, UINT64) for val in offsets))
def save_bytes(f, b):
def save_array(f, item, compressed=False):
"""
Writes ndarray with header to file.
:param f: file pointer
:type f: File
:param item: array
:type item: np.ndarray
"""
f.write(ARRAY_TAG)
array_size_field = f.tell()
f.write(b"\x00" * UINT64.nbytes) # placeholder
f.flush()
os.fsync(f.fileno())
np.save(f, item)
# if compressed:
# np.savez_compressed(f, item)
# else:
# np.savez(f, item)
f.flush()
os.fsync(f.fileno())
eof = f.tell()
f.seek(array_size_field)
f.write(int2bytes(eof - array_size_field - UINT64.nbytes, UINT64))
f.seek(eof)
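The matching read side can use the patched size field to step over the whole block regardless of payload length. A sketch under that assumption (the commit's actual reader is load_array further below):

import numpy as np

def read_array_block(f):
    # The size field holds the payload length in bytes (patched above).
    blocksize = bytes2int(f.read(UINT64.nbytes), UINT64)
    start = f.tell()
    arr = np.load(f)               # payload written by np.save
    f.seek(start + blocksize)      # land exactly at the end of the block
    return arr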
def save_bytes(f, b, compressed=False):
"""
Writes a bytearray replacement block to file.
@@ -207,7 +177,7 @@ def save_bytes(f, b):
os.fsync(f.fileno())
def save_memmap(f, memmap):
def save_memmap(f, memmap, compressed=False):
"""
Writes a block of memory-mapped arrays to an open file. No data compression
is applied.
@@ -296,9 +266,6 @@ def load_array(f):
"""
import mmap
# if f.read(len(ARRAY_TAG)) != ARRAY_TAG:
# f.seek(-len(ARRAY_TAG), 1)
# raise IndexError("Invalid NDArray at file pointer.")
blocksize = bytes2int(f.read(UINT64.nbytes), UINT64)
fd = os.open(f.name, os.O_RDONLY)
offset = f.tell() - f.tell() % mmap.ALLOCATIONGRANULARITY
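The offset arithmetic above rounds the current file position down to the nearest allocation granule, because mmap offsets must be multiples of mmap.ALLOCATIONGRANULARITY; the remainder is then skipped inside the map. An illustration with hypothetical values:

import mmap

pos = 12345                                        # file position of the data
offset = pos - pos % mmap.ALLOCATIONGRANULARITY    # page-aligned map start
delta = pos - offset                               # bytes to skip in the map
# e.g. with a 4096-byte granularity: offset == 12288, delta == 57
assert offset % mmap.ALLOCATIONGRANULARITY == 0
assert offset + delta == pos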
@@ -344,9 +311,6 @@ def load_memmap(f, mode="r+"):
:rtype: np.memmap
"""
# if f.read(len(MMAP_TAG)) != MMAP_TAG:
# f.seek(-len(MMAP_TAG), 1)
# raise IndexError("Invalid MemoryMap object at file pointer.")
js_descr = f.read(bytes2int(f.read(UINT64.nbytes), UINT64)).decode()
shape, dtype, order = json.loads(js_descr)
order = "C" if order else "F"
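The descriptor read here is a UINT64-length-prefixed JSON triple of (shape, dtype, order). The corresponding writer is not shown in this hunk; it presumably resembles this sketch (write_memmap_descriptor is an assumed name):

import json

def write_memmap_descriptor(f, memmap):
    descr = json.dumps([memmap.shape,                     # e.g. [200, 300]
                        memmap.dtype.str,                 # e.g. "<f8"
                        memmap.flags["C_CONTIGUOUS"]]).encode()
    f.write(int2bytes(len(descr), UINT64))                # UINT64 length prefix
    f.write(descr)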
@@ -437,45 +401,55 @@ def load(fname):
return object_dump
def encode(node, counter=0):
def encode(node):
"""
Recursive function that traverses an object dump and replaces
unserialisable elements with string codes. The replaced elements are
collated in a dictionary with their respective string codes for separate
processing.
:returns: serialisable dump, replaced objects, counter
:rtype: tuple[Union[dict, tuple, list], dict, int]
:returns: serialisable dump, replaced objects
:rtype: tuple[Union[dict, tuple, list], dict]
"""
    replacements = {"decoded_objects": dict()}
assert hasattr(node, "__iter__")
if isinstance(node, dict):
collection = dict()
iterator = node.items()
# Set fixed traversing order (needed for object redundancy mapping)
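The remainder of encode is truncated here. A minimal, self-contained sketch of the recursion it describes, replacing non-serialisable leaves with string codes collected under "decoded_objects" (the code format and the set of replaced leaf types are assumptions based on save_replacements):

import numpy as np

def encode_sketch(node, replacements=None):
    if replacements is None:
        replacements = {"decoded_objects": dict()}
    if isinstance(node, dict):
        out = {k: encode_sketch(v, replacements)[0]
               for k, v in sorted(node.items())}    # fixed traversing order
    elif isinstance(node, (list, tuple)):
        out = [encode_sketch(v, replacements)[0] for v in node]
    elif isinstance(node, (np.ndarray, bytes)):
        code = "OBJ_{}".format(id(node))            # string code
        replacements["decoded_objects"][code] = node
        out = code
    else:
        out = node    # JSON-serialisable leaf is kept as-is
    return out, replacements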