wrapperutils.py 6.19 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
#!/usr/bin/env python
#
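# wrapperutils.py - Utility functions and decorators used to write wrapper
# functions (e.g. wrappers around command-line tools).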
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#


import os.path as op
import            os
import            inspect
import            tempfile
import            warnings
import            functools
import            collections

import            six
import nibabel as nib

import fsl.utils.tempdir as tempdir
import fsl.data.image    as fslimage


class _BooleanFlag(object):
    def __init__(self, show):
        self.show = show
    def __eq__(self, other):
        return type(other) == type(self) and self.show == other.show


SHOW_IF_TRUE = _BooleanFlag(True)
HIDE_IF_TRUE = _BooleanFlag(False)


def applyArgStyle(style, argmap=None, valmap=None, **kwargs):

    def fmtarg(arg, style):
        if   style in ('-',  '-='):  arg =  '-{}'.format(arg)
        elif style in ('--', '--='): arg = '--{}'.format(arg)
        return arg

    def fmtval(val, style=None):
        if     isinstance(val, collections.Sequence) and \
           not isinstance(val, six.string_types):
            return ' '.join([str(v) for v in val])
        else:
            return str(val)

    if style not in ('-', '--', '-=', '--='):
        raise ValueError('Invalid style: {}'.format(style))

    args = []

    for k, v in kwargs.items():

        k    = argmap.get(k, k)
        mapv = valmap.get(k, fmtval(v, style))
        k    = fmtarg(k, style)

        if mapv in (SHOW_IF_TRUE, HIDE_IF_TRUE):
            if v == mapv.show:
                args.append(k)
        elif '=' in style:
            args.append('{}={}'.format(k, mapv))
        else:
            args.extend((k, mapv))

    return args


def required(*reqargs):
    """Decorator which makes sure that all specified arguments are present
    before calling the decorated function.

    Required arguments may be given either as keyword arguments, or
    positionally - positional values are matched to parameter names using
    the decorated function's signature.

    :arg reqargs: Names of the arguments which must be provided.

    The returned wrapper raises ``AssertionError`` if a required argument
    is not provided.
    """

    def decorator(func):

        # Resolve positional parameter names once, at decoration time.
        # If func cannot be introspected, positional values cannot be
        # matched to names, and only keyword arguments are checked.
        try:
            posnames = list(inspect.getfullargspec(func).args)
        except (TypeError, AttributeError):
            posnames = []

        # The argspec of a bound method includes the already-bound
        # first parameter (e.g. self), which callers do not pass.
        if inspect.ismethod(func) and \
           getattr(func, '__self__', None) is not None:
            posnames = posnames[1:]

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            given = set(kwargs)
            given.update(posnames[:len(args)])
            for reqarg in reqargs:
                # Raised explicitly rather than via ``assert`` so that the
                # check is not stripped when running under ``python -O``.
                if reqarg not in given:
                    raise AssertionError(
                        'Required argument {} not specified'.format(reqarg))
            return func(*args, **kwargs)

        return wrapper
    return decorator


def argsToKwargs(func, args):
    """Given a function, and a sequence of positional arguments destined
    for that function, converts the positional arguments into a dict
    of keyword arguments. Used by the :class:`_FileOrImage` class.

    :arg func: The function (or other callable) that ``args`` are destined
               for.
    :arg args: Sequence of positional argument values.
    :returns:  A ``collections.OrderedDict`` mapping the names of ``func``'s
               positional parameters to the corresponding values in
               ``args``. Values beyond the number of named positional
               parameters (e.g. destined for ``*args``) are not included.
    """
    # inspect.getfullargspec only exists in Python 3 (it was documented
    # as deprecated in 3.5, and un-deprecated in 3.6), so fall back to
    # inspect.getargspec under Python 2. Deprecation warnings are ignored.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        getspec = getattr(inspect, 'getfullargspec', None)
        if getspec is None:
            getspec = inspect.getargspec
        names = list(getspec(func).args)

    # The argspec of a bound method includes the already-bound first
    # parameter (e.g. self), which has no corresponding value in args.
    if inspect.ismethod(func) and \
       getattr(func, '__self__', None) is not None:
        names = names[1:]

    return collections.OrderedDict(zip(names, args))


RETURN = object()
"""Special value which may be given for an output image argument of a
function decorated with :class:`_FileOrImage` (``fileOrImage``). A file name
is passed to the function in its place; after the function has returned,
the image is loaded from that file and returned, in memory, alongside the
function's own return value (or ``None`` if the function did not create the
file).
"""


class _FileOrImage(object):
    """Decorator which allows in-memory nibabel images to be passed to, and
    returned from, functions which expect file names.

    The names of the arguments to be handled are passed to the constructor.
    Positional arguments are converted to keyword arguments (via
    :func:`argsToKwargs`) before being inspected.

    Inputs:
      - In-memory nibabel images loaded from a file. The image is replaced with
        its file name.

      - In-memory nibabel images. The image is saved to a temporary file, and
        replaced with the temporary file's name. The file is deleted after the
        function has returned.

    Outputs:
      - File name:  The file name is passed straight through to the function.
      - ``RETURN``: A temporary file name is passed to the function. After the
        function has completed, the image is loaded into memory and the
        temporary file is deleted. The image is returned from the function
        call.

    The decorated function returns a tuple containing the function's return
    value, followed by one entry for each argument given as ``RETURN`` (in
    the order that argument names were given to the constructor). Each
    entry is the loaded image, or ``None`` if the function did not create
    the file.
    """


    def __init__(self, *imgargs):
        """Create a ``_FileOrImage`` decorator.

        :arg imgargs: Names of the arguments which may be given as in-memory
                      images, or as ``RETURN``.
        """
        self.__imgargs = imgargs


    def __call__(self, func):
        """Decorate ``func``. The returned wrapper keeps ``func``'s name and
        docstring.
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return self.__wrapper(func, *args, **kwargs)
        return wrapper


    def __wrapper(self, func, *args, **kwargs):
        """Prepare the arguments, call ``func``, and load any outputs that
        were specified as ``RETURN``. See the class documentation for the
        return value.
        """

        # Deferred import - numpy is only needed to take an in-memory
        # copy of output image data (nibabel itself requires numpy).
        import numpy as np

        kwargs.update(argsToKwargs(func, args))

        # Create a tempdir to store any temporary
        # input/output images, but don't change
        # into it, as file paths passed to the
        # function may be relative.
        with tempdir.tempdir(changeto=False) as td:

            kwargs, infiles, outfiles = self.__prepareArgs(td, kwargs)

            # Call the function
            result = func(**kwargs)

            # Load the output images that
            # were specified as RETURN
            outimgs = []
            for of in outfiles:

                # output file didn't get created
                if not op.exists(of):
                    oi = None

                # load the file, and create an in-memory
                # copy (the file is going to get deleted).
                # np.asanyarray(img.dataobj) replaces the
                # deprecated (removed in nibabel 5)
                # img.get_data().
                else:
                    oi   = nib.load(of)
                    data = np.asanyarray(oi.dataobj)
                    oi   = nib.nifti1.Nifti1Image(data, None, oi.header)

                outimgs.append(oi)

            return tuple([result] + outimgs)


    def __prepareArgs(self, workdir, kwargs):
        """Replace in-memory images and ``RETURN`` values in ``kwargs`` with
        file names.

        :arg workdir: Directory in which temporary input/output files are
                      created.
        :arg kwargs:  Keyword arguments destined for the decorated function.
                      Not modified - an updated copy is returned.
        :returns:     A tuple containing:
                       - the updated copy of ``kwargs``
                       - a list of the temporary input files that were
                         created
                       - a list of output file paths, one for each argument
                         given as ``RETURN``
        """

        kwargs   = dict(kwargs)
        infiles  = []
        outfiles = []

        for imgarg in self.__imgargs:

            img = kwargs.get(imgarg, None)

            # Not specified, nothing to do
            if img is None:
                continue

            # This is an output image, and the caller has requested
            # that it be returned from the function call as an
            # in-memory image. Point the function at a file in the
            # work directory, to be loaded after it has returned.
            # (Identity test - RETURN is a sentinel object, and
            # == could misbehave on array-like values.)
            if img is RETURN:
                outfile        = op.join(workdir, '{}.nii.gz'.format(imgarg))
                kwargs[imgarg] = outfile
                outfiles.append(outfile)

            # This is an input image which has been specified as an
            # in-memory nibabel image. If the image has a backing
            # file, replace the image object with the file name.
            # Otherwise, save the image out to a temporary file in
            # the work directory (so it is deleted along with it),
            # and replace the image with the file name.
            elif isinstance(img, nib.nifti1.Nifti1Image):
                imgfile = img.get_filename()

                if imgfile is None:
                    hd, imgfile = tempfile.mkstemp(
                        suffix=fslimage.defaultExt(), dir=workdir)
                    os.close(hd)
                    img.to_filename(imgfile)
                    infiles.append(imgfile)

                kwargs[imgarg] = imgfile

        return kwargs, infiles, outfiles


# Public name for the _FileOrImage decorator class, used as e.g.
# ``@fileOrImage('input', 'output')``.
fileOrImage = _FileOrImage