From c68da9a3529024090f73069d23ee78d2c5ccc685 Mon Sep 17 00:00:00 2001 From: Paul McCarthy <pauldmccarthy@gmail.com> Date: Fri, 11 May 2018 16:25:38 +0100 Subject: [PATCH] ENH: fileOrImage works with fsl.data.images. --- fsl/wrappers/wrapperutils.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/fsl/wrappers/wrapperutils.py b/fsl/wrappers/wrapperutils.py index ee056553f..a1bc938d9 100644 --- a/fsl/wrappers/wrapperutils.py +++ b/fsl/wrappers/wrapperutils.py @@ -635,13 +635,25 @@ class _FileOrThing(object): def fileOrImage(*imgargs): """Decorator which can be used to ensure that any NIfTI images are saved to file, and output images can be loaded and returned as ``nibabel`` - image objects. + image objects or :class:`.Image` objects. """ + # keep track of the input argument + # types on each call, so we know + # whether to return a fsl.Image or + # a nibabel image + intypes = [] + def prepIn(workdir, name, val): infile = None + if isinstance(val, (fslimage.Image, nib.nifti1.Nifti1Image)): + intypes.append(type(val)) + + if isinstance(val, fslimage.Image): + val = val.nibImage + if isinstance(val, nib.nifti1.Nifti1Image): infile = val.get_filename() @@ -661,13 +673,28 @@ def fileOrImage(*imgargs): # create an independent in-memory # copy of the image file img = nib.load(path) - return nib.nifti1.Nifti1Image(img.get_data(), None, img.header) + + # if any arguments were fsl images, + # that takes precedence. + if fslimage.Image in intypes: + return fslimage.Image(img.get_data(), header=img.header) + # but if all inputs were file names, + # nibabel takes precedence + elif nib.nifti1.Nifti1Image in intypes or len(intypes) == 0: + return nib.nifti1.Nifti1Image(img.get_data(), None, img.header) + + # this function should not be called + # under any other circumstances + else: + raise RuntimeError('Cannot handle type: {}'.format(intypes)) def decorator(func): fot = _FileOrThing(func, prepIn, prepOut, load, *imgargs) def wrapper(*args, **kwargs): - return fot(*args, **kwargs) + result = fot(*args, **kwargs) + intypes[:] = [] + return result return _update_wrapper(wrapper, func) -- GitLab