#!/usr/bin/env python
#
# imagewrapper.py - The ImageWrapper class.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module provides the :class:`ImageWrapper` class, which can be used
to manage data access to ``nibabel`` NIFTI images.


Terminology
-----------

There are some confusing terms used in this module, so it may be useful to
get their definitions straight:

  - *Coverage*:    The portion of an image that has been covered in the data
                   range calculation. The ``ImageWrapper`` keeps track of
                   the coverage for individual volumes within a 4D image (or
                   slices in a 3D image).

  - *Slice*:       Portion of the image data which is being accessed. A slice
                   comprises either a tuple of ``slice`` objects (or integers),
                   or a sequence of ``(low, high)`` tuples, specifying the
                   index range into each image dimension that is covered by
                   the slice.

  - *Expansion*:   A sequence of ``(low, high)`` tuples, specifying an
                   index range into each image dimension, that is used to
                   *expand* the *coverage* of an image, based on a given set
                   of *slices*.

  - *Fancy slice*: Any object which is used to slice an array, and is not
                   an ``int``, ``slice``, or ``Ellipsis``, or sequence of
                   these.
"""


import                    logging
import                    collections
import collections.abc as abc
import itertools       as it

import numpy     as np
import nibabel   as nib

import fsl.utils.notifier    as notifier
import fsl.utils.naninfrange as nir
import fsl.utils.idle        as idle


log = logging.getLogger(__name__)


class ImageWrapper(notifier.Notifier):
    """The ``ImageWrapper`` class is a convenience class which manages data
    access to ``nibabel`` NIFTI images. The ``ImageWrapper`` class can be
    used to:

      - Control whether the image is loaded into memory, or kept on disk

      - Incrementally update the known image data range, as more image
        data is read in.


    *In memory or on disk?*

    The image data will be kept on disk, and accessed through the
    ``nibabel.Nifti1Image.dataobj`` (or ``nibabel.Nifti2Image.dataobj``) array
    proxy, if:

     - The ``loadData`` parameter to :meth:`__init__` is ``False``.
     - The image is not already in memory (``image.in_memory`` is ``False``).
     - The :meth:`loadData` method never gets called.
     - The image data is not modified (via :meth:`__setitem__`).

    If any of these conditions does not hold, the image data will be loaded
    into memory and accessed directly.


    *Image dimensionality*

    The ``ImageWrapper`` abstracts away trailing image dimensions of length 1.
    This means that if the header for a NIFTI image specifies that the image
    has four dimensions, but the fourth dimension is of length 1, you do not
    need to worry about indexing that fourth dimension. However, all NIFTI
    images will be presented as having at least three dimensions, so if your
    image header specifies a third dimension of length 1, you will still
    need to provide an index of 0 for that dimension, for all data accesses.
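
    The following sketch illustrates this shape logic on plain tuples.
    ``presentedShape`` is a hypothetical helper written for this example
    only; the shape actually used comes from
    ``fsl.data.image.canonicalShape``, whose exact rules may differ in
    edge cases:

    .. code-block:: python

        def presentedShape(shape):
            # Drop trailing dimensions of length 1 ...
            shape = list(shape)
            while len(shape) > 0 and shape[-1] == 1:
                shape.pop()
            # ... but always present at least three dimensions
            while len(shape) < 3:
                shape.append(1)
            return tuple(shape)

        presentedShape((64, 64, 30, 1))  # -> (64, 64, 30)
        presentedShape((64, 64, 1))      # -> (64, 64, 1)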


    *Data access*

    The ``ImageWrapper`` can be indexed in one of two ways:

       - With basic ``numpy``-like multi-dimensional array slicing (with step
         sizes of 1)

       - With boolean array indexing, where the boolean/mask array has the
         same shape as the image data.

    See https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html for
    more details on numpy indexing.
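
    Both indexing styles follow ``numpy`` semantics. This sketch uses a
    plain ``numpy`` array as a stand-in for the image data, rather than
    constructing a real ``ImageWrapper``:

    .. code-block:: python

        import numpy as np

        data = np.arange(24, dtype=np.float32).reshape((2, 3, 4))

        # Basic slicing (step size 1) returns a sub-array
        block = data[0, 1:3, :]      # shape (2, 4)

        # Boolean mask indexing, with a mask of the same
        # shape as the data, returns a 1D array of values
        mask   = data > 20
        values = data[mask]          # [21., 22., 23.]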


    *Data range*

    In order to avoid the computational overhead of calculating the image data
    range (its minimum/maximum values) when an image is first loaded in, an
    ``ImageWrapper`` incrementally updates the known image data range as data
    is accessed. The ``ImageWrapper`` keeps track of the image data *coverage*,
    the portion of the image which has already been considered in the data
    range calculation. When data from a region of the image not in the coverage
    is accessed, the coverage is expanded to include this region. The coverage
    is always expanded in a rectilinear manner, i.e. the coverage is always
    rectangular for a 2D image, or cuboid for a 3D image.
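
    A rectilinear expansion amounts to taking the bounding box of the
    current coverage and the newly accessed region. The sketch below shows
    this for a single 2D slice, with each region stored as one
    ``(low, high)`` row per dimension (high indices exclusive). It only
    illustrates the idea, and is not necessarily how ``adjustCoverage`` is
    implemented:

    .. code-block:: python

        import numpy as np

        coverage = np.array([[0, 10], [0, 10]])
        accessed = np.array([[5, 20], [2,  4]])

        # Bounding box of both regions: the coverage stays
        # rectangular, but now also encloses voxels which were
        # in neither region, and whose data must also be read
        # for the data range to be correct for the new coverage
        expanded = np.column_stack(
            (np.minimum(coverage[:, 0], accessed[:, 0]),
             np.maximum(coverage[:, 1], accessed[:, 1])))
        # expanded -> [[0, 20], [0, 10]]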

    For a 4D image, the ``ImageWrapper`` internally maintains a separate
    coverage and known data range for each 3D volume within the image. For a 3D
    image, separate coverages and data ranges are stored for each 2D slice.
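
    The per-volume bookkeeping can be sketched as follows, with one
    ``(min, max)`` row per volume and ``nan`` meaning *not yet known*.
    ``updateVolume`` is a hypothetical helper, ``np.nanmin`` and
    ``np.nanmax`` stand in for ``fsl.utils.naninfrange`` (whose handling of
    ``nan`` and infinite values may differ), and coverage tracking is
    omitted:

    .. code-block:: python

        import numpy as np

        volRanges = np.full((3, 2), np.nan)

        def updateVolume(vol, voldata):
            # Widen the stored range of one volume with newly
            # read data; nanmin/nanmax skip the nan placeholders
            oldlo, oldhi   = volRanges[vol]
            volRanges[vol] = (np.nanmin([oldlo, np.nanmin(voldata)]),
                              np.nanmax([oldhi, np.nanmax(voldata)]))

        updateVolume(0, np.array([3.0, 7.0]))
        updateVolume(0, np.array([1.0, 5.0]))
        updateVolume(2, np.array([10.0]))

        # The known range of the whole image is the min/max over
        # all volumes seen so far; volume 1 is still unknown
        knownRange = (np.nanmin(volRanges[:, 0]),
                      np.nanmax(volRanges[:, 1]))
        # knownRange -> (1.0, 10.0)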


    The ``ImageWrapper`` implements the :class:`.Notifier` interface.
    Listeners can register to be notified whenever the known image data range
    is updated. The data range can be accessed via the :attr:`dataRange`
    property.

    The ``ImageWrapper`` class uses the following functions (also defined in
    this module) to keep track of the portion of the image that has currently
    been included in the data range calculation:

    .. autosummary::
       :nosignatures:

       isValidFancySliceObj
       canonicalSliceObj
       sliceObjToSliceTuple
       sliceTupleToSliceObj
       sliceCovered
       calcExpansion
       adjustCoverage
    """


    def __init__(self,
                 image,
                 name=None,
                 loadData=False,
                 dataRange=None,
                 threaded=False):
        """Create an ``ImageWrapper``.

        :arg image:     A ``nibabel.Nifti1Image`` or ``nibabel.Nifti2Image``.

        :arg name:      A name for this ``ImageWrapper``, solely used for
                        debug log messages.

        :arg loadData:  If ``True``, the image data is loaded into memory.
                        Otherwise it is kept on disk (and data access is
                        performed through the ``nibabel.Nifti1Image.dataobj``
                        array proxy).

        :arg dataRange: A tuple containing the initial ``(min, max)`` data
                        range to use. See the :meth:`reset` method for
                        important information about this parameter.

        :arg threaded:  If ``True``, the data range is updated on a
                        :class:`.TaskThread`. Otherwise (the default), the
                        data range is updated directly on reads/writes.
        """

        import fsl.data.image as fslimage

        self.__image      = image
        self.__name       = name
        self.__taskThread = None

        # Save the number of 'real' dimensions,
        # that is the number of dimensions minus
        # any trailing dimensions of length 1
        self.__numRealDims = len(image.shape)
        for d in reversed(image.shape):
            if d == 1: self.__numRealDims -= 1
            else:      break

        # Degenerate case - fewer
        # than three real dimensions
        if self.__numRealDims < 3:
            self.__numRealDims = min(3, len(image.shape))

        # And save the number of
        # 'padding' dimensions too.
        self.__numPadDims = len(image.shape) - self.__numRealDims

        # Too many shapes! Figure out
        # what shape we should present
        # the data as (e.g. at least 3
        # dimensions). This is used in
        # __getitem__ to force the
        # result to have the correct
        # dimensionality.
        self.__canonicalShape = fslimage.canonicalShape(image.shape)

        # The internal state is stored
        # in these attributes - they're
        # initialised in the reset method.
        self.__range     = None
        self.__coverage  = None
        self.__volRanges = None
        self.__covered   = False

        self.reset(dataRange)

        # We keep an internal ref to
        # the data numpy array if/when
        # it is loaded in memory
        self.__data = None
        if loadData or image.in_memory:
            self.loadData()

        if threaded:
            self.__taskThread = idle.TaskThread()
            self.__taskThread.daemon = True
            self.__taskThread.start()


    def __del__(self):
        """If this ``ImageWrapper`` was created with ``threaded=True``,
        the :class:`.TaskThread` is stopped.
        """
        self.__image = None
        self.__data  = None
        if self.__taskThread is not None:
            self.__taskThread.stop()
            self.__taskThread = None


    def getTaskThread(self):
        """If this ``ImageWrapper`` was created with ``threaded=True``,
        this method returns the ``TaskThread`` that is used for running
        data range calculation tasks. Otherwise, this method returns
        ``None``.
        """
        return self.__taskThread


    def reset(self, dataRange=None):
        """Reset the internal state and known data range of this
        ``ImageWrapper``.

        :arg dataRange: A tuple containing the initial ``(min, max)`` data
                        range to use.

        .. note:: The ``dataRange`` parameter is intended for situations where
                  the image data range is known in advance (e.g. it was
                  calculated earlier, and the image is being re-loaded). If a
                  ``dataRange`` is passed in, it will *not* be overwritten by
                  any range calculated from the data, unless the calculated
                  data range is wider than the provided ``dataRange``.
        """

        if dataRange is None:
            dataRange = None, None

        image =             self.__image
        ndims =             self.__numRealDims - 1
        nvols = image.shape[self.__numRealDims - 1]

        # The current known image data range. This
        # gets updated as more image data gets read.
        self.__range = dataRange

        # The coverage array is used to keep track of
        # the portions of the image which have been
        # considered in the data range calculation.
        # We use this coverage to avoid unnecessarily
        # re-calculating the data range on the same
        # part of the image.
        #
        # First of all, we're going to store a separate
        # 'coverage' for each 2D slice in the 3D image
        # (or 3D volume for 4D images). This effectively
        # means a separate coverage for each index in the
        # last 'real' image dimension (see above).
        #
        # For each slice/volume, the coverage is
        # stored as sequences of (low, high) indices, one
        # for each dimension in the slice/volume (e.g.
        # row/column for a slice, or row/column/depth
        # for a volume).
        #
        # All of these indices are stored in a numpy array:
        #   - first dimension:  low/high index
        #   - second dimension: image dimension
        #   - third dimension:  slice/volume index
        self.__coverage = np.zeros((2, ndims, nvols), dtype=np.float32)

        # Internally, we calculate and store the
        # data range for each volume/slice/vector
        #
        # We use nan as a placeholder, so the
        # dtype must be non-integral. The
        # len(dtype) check takes into account
        # structured data (e.g. RGB)
        dtype = self.__image.get_data_dtype()
        if np.issubdtype(dtype, np.integer) or len(dtype) > 0:
            dtype = np.float32
        self.__volRanges = np.zeros((nvols, 2),
                                    dtype=dtype)

        self.__coverage[ :] = np.nan
        self.__volRanges[:] = np.nan

        # This flag is set to true if/when the
        # full image data range becomes known
        # (i.e. when all data has been loaded in).
        self.__covered = False


    @property
    def dataRange(self):
        """Returns the currently known data range as a tuple of ``(min, max)``
        values.
        """
        # If no image data has been accessed, we
        # default to whatever is stored in the
        # header (which may or may not contain
        # useful values).
        low, high = self.__range
        hdr       = self.__image.header

        if low  is None: low  = float(hdr['cal_min'])
        if high is None: high = float(hdr['cal_max'])

        return low, high


    @property
    def covered(self):
        """Returns ``True`` if this ``ImageWrapper`` has read the entire
        image data, ``False`` otherwise.
        """
        return self.__covered


    @property
    def shape(self):
        """Returns the shape that the image data is presented as. This is
        the same as the underlying image shape, but with trailing dimensions
        of length 1 removed, and at least three dimensions.
        """
        return self.__canonicalShape


    def coverage(self, vol):
        """Returns the current image data coverage for the specified volume
        (for a 4D image, slice for a 3D image, or vector for a 2D image).

        :arg vol: Index of the volume/slice/vector to return the coverage
                  for.

        :returns: The coverage for the specified volume, as a ``numpy``
                  array of shape ``(2, nd)``, where ``nd`` is the number
                  of dimensions in the volume. The first row contains the
                  low indices, and the second row the high indices.

        .. note:: If the specified volume is not covered, the returned array
                  will contain ``np.nan`` values.
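
        The coverage array for all volumes has shape ``(2, nd, nvols)``,
        so indexing its last axis yields one ``(2, nd)`` array per volume.
        A standalone ``numpy`` sketch of that layout (the sizes and index
        values are arbitrary):

        .. code-block:: python

            import numpy as np

            ndims, nvols = 3, 5
            coverage = np.full((2, ndims, nvols), np.nan, dtype=np.float32)

            # Volume 1 covers the index ranges [0, 10), [5, 20)
            # and [0, 4) (high indices are exclusive)
            coverage[0, :, 1] = [0,  5,  0]
            coverage[1, :, 1] = [10, 20, 4]

            vol1        = np.array(coverage[..., 1])  # shape (2, 3)
            lows, highs = vol1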
        """
        return np.array(self.__coverage[..., vol])


    def loadData(self):
        """Forces all of the image data to be loaded into memory.

        .. note:: This method will be called by :meth:`__init__` if its
                  ``loadData`` parameter is ``True``, or if the image data
                  is already in memory. It will also be called on all write
                  operations (see :meth:`__setitem__`).
        """
        if self.__data is None:
            self.__data = np.asanyarray(self.__image.dataobj)


    def __getData(self, sliceobj, isTuple=False):
        """Retrieves the image data at the location specified by ``sliceobj``.

        :arg sliceobj: Something which can be used to slice an array, or
                       a sequence of (low, high) index pairs.

        :arg isTuple:  Set to ``True`` if ``sliceobj`` is a sequence of
                       (low, high) index pairs.
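
        When ``isTuple`` is ``True``, each ``(low, high)`` pair is turned
        back into a slice object by ``sliceTupleToSliceObj`` (defined
        elsewhere in this module). Assuming that function maps each pair to
        ``slice(low, high)``, the conversion behaves like this sketch:

        .. code-block:: python

            import numpy as np

            pairs    = ((0, 10), (5, 6), (0, 3))
            sliceobj = tuple(slice(lo, hi) for lo, hi in pairs)

            data = np.zeros((10, 10, 3))
            # Unlike an integer index, slice(5, 6) keeps its
            # dimension, so the result has the full rank
            data[sliceobj].shape   # -> (10, 1, 3)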
        """

        if isTuple:
            sliceobj = sliceTupleToSliceObj(sliceobj)

        # If the image has not been loaded
        # into memory, we can use the nibabel
        # ArrayProxy. Otherwise if it is in
        # memory, we can access it directly.
        #
        # Note also that if the caller has
        # given us a 'fancy' slice object (a
        # boolean numpy array), but the image
        # data is not in memory, we can't access
        # the data, as the nibabel ArrayProxy
        # (the dataobj attribute) cannot handle
        # fancy indexing. In this case an error
        # will be raised.
        if self.__data is not None: return self.__data[         sliceobj]
        else:                       return self.__image.dataobj[sliceobj]


    def __imageIsCovered(self):
        """Returns ``True`` if all portions of the image have been covered
        in the data range calculation, ``False`` otherwise.
        """

        shape  = self.__image.shape
        slices = [[0, s] for s in shape]
        return sliceCovered(slices, self.__coverage)


    def __expandCoverage(self, slices):
        """Expands the current image data range and coverage to encompass the
        given ``slices``.
        """

        _, expansions = calcExpansion(slices, self.__coverage)
        expansions    = collapseExpansions(expansions, self.__numRealDims - 1)

        log.debug('Updating image %s data range [slice: %s] '
                  '(current range: [%s, %s]; '
                  'number of expansions: %s; '
                  'current coverage: %s; '
                  'volume ranges: %s)',
                  self.__name,
                  slices,
                  self.__range[0],
                  self.__range[1],
                  len(expansions),
                  self.__coverage,
                  self.__volRanges)

        # As we access the data for each expansion,
        # we want it to have the same dimensionality
        # as the full image, so we can access data
        # for each volume in the image separately.
        # So we squeeze out the padding dimensions,
        # but not the volume dimension.
        squeezeDims = tuple(range(self.__numRealDims,
                                  self.__numRealDims + self.__numPadDims))

        # The calcExpansion function splits up the
        # expansions on volumes - here we calculate
        # the min/max per volume/expansion, and
        # iteratively update the stored per-volume
        # coverage and data range.
        for exp in expansions:

            data     = self.__getData(exp, isTuple=True)
            data     = data.squeeze(squeezeDims)

            vlo, vhi = exp[self.__numRealDims - 1]

            for vi, vol in enumerate(range(vlo, vhi)):

                oldvlo, oldvhi = self.__volRanges[vol, :]
                voldata        = data[..., vi]
                newvlo, newvhi = nir.naninfrange(voldata)

                if np.isnan(newvlo) or \
                   (not np.isnan(oldvlo) and oldvlo < newvlo):
                    newvlo = oldvlo
                if np.isnan(newvhi) or \
                   (not np.isnan(oldvhi) and oldvhi > newvhi):
                    newvhi = oldvhi

                # Update the stored range and
                # coverage for each volume
                self.__volRanges[vol, :]  = newvlo, newvhi
                self.__coverage[..., vol] = adjustCoverage(
                    self.__coverage[..., vol], exp)

        # Calculate the new known data
        # range over the entire image
        # (i.e. over all volumes).
        newmin, newmax = nir.naninfrange(self.__volRanges)

        oldmin, oldmax = self.__range
        self.__range   = (newmin, newmax)
        self.__covered = self.__imageIsCovered()

        if any((oldmin is None, oldmax is None)) or \
           not np.all(np.isclose([oldmin, oldmax], [newmin, newmax])):
            log.debug('Image %s range changed: [%s, %s] -> [%s, %s]',
                      self.__name,
                      oldmin,
                      oldmax,
                      newmin,
                      newmax)
            self.notify()


    def __updateDataRangeOnRead(self, slices, data):
        """Called by :meth:`__getitem__`. Calculates the minimum/maximum
        values of the given data (which has been extracted from the portion of
        the image specified by ``slices``), and updates the known data range
        of the image.

        :arg slices: A tuple of tuples, each tuple being a ``(low, high)``
                     index pair, one for each dimension in the image.

        :arg data:   The image data at the given ``slices`` (as a ``numpy``
                     array).
        """

        # TODO You could do something with
        #      the provided data to avoid
        #      reading it in again.

        if self.__taskThread is None:
            self.__expandCoverage(slices)
        else:
            name = '{}_read_{}'.format(id(self), slices)
            if not self.__taskThread.isQueued(name):
                self.__taskThread.enqueue(
                    self.__expandCoverage, slices, taskName=name)


    def __updateDataRangeOnWrite(self, slices, data):
        """Called by :meth:`__setitem__`. Assumes that the image data has
        been changed (the data at ``slices`` has been replaced with ``data``).
        Updates the image data coverage, and known data range accordingly.

        :arg slices: A tuple of tuples, each tuple being a ``(low, high)``
                     index pair, one for each dimension in the image.

        :arg data:   The image data at the given ``slices`` (as a ``numpy``
                     array).
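
        Whether the known range has to be invalidated depends on whether
        the written region intersects the current coverage. For one volume,
        with both regions given as ``(low, high)`` pairs (high exclusive),
        that test can be sketched as an axis-aligned box intersection.
        ``boxesOverlap`` is a hypothetical helper; the actual decision is
        made by ``sliceOverlap``, which appears to also distinguish partial
        from full overlap:

        .. code-block:: python

            def boxesOverlap(a, b):
                # Two boxes intersect only if their index
                # ranges intersect along every dimension
                return all(alo < bhi and blo < ahi
                           for (alo, ahi), (blo, bhi) in zip(a, b))

            covered = [(0, 10), (0, 10), (0, 5)]
            boxesOverlap(covered, [(5, 6),   (5, 6),  (2, 3)])  # -> True
            boxesOverlap(covered, [(20, 30), (0, 10), (0, 5)])  # -> False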
        """

        overlap = sliceOverlap(slices, self.__coverage)

        # If there's no overlap between the written
        # area and the current coverage, then it's
        # easy - we just expand the coverage to
        # include the newly written area.
        #
        # But if there is overlap between the written
        # area and the current coverage, things are
        # more complicated, because the portion of
        # the image that has been written over may
        # have contained the currently known data
        # minimum/maximum. We have no way of knowing
        # this, so we have to reset the coverage (on
        # the affected volumes), and recalculate the
        # data range.
        if overlap in (OVERLAP_SOME, OVERLAP_ALL):

            # TODO Could you store the location of the
            #      data minimum/maximum (in each volume),
            #      so you know whether resetting the
            #      coverage is necessary?
            lowVol, highVol = slices[self.__numRealDims - 1]

            # We create a single slice which
            # encompasses the given slice, and
            # all existing coverages for each
            # volume in the given slice. The
            # data range for this slice will
            # be recalculated.
            slices = adjustCoverage(self.__coverage[:, :, lowVol], slices)
            for vol in range(lowVol + 1, highVol):
                slices = adjustCoverage(slices, self.__coverage[:, :, vol].T)

            slices = np.array(slices.T, dtype=np.uint32)
            slices = tuple(it.chain(map(tuple, slices), [(lowVol, highVol)]))

            log.debug('Image %s data written - clearing known data '
                      'range on volumes %s - %s (write slice: %s; '
                      'coverage: %s; volRanges: %s)',
                      self.__name,
                      lowVol,
                      highVol,
                      slices,
                      self.__coverage[:, :, lowVol:highVol],
                      self.__volRanges[lowVol:highVol, :])

            for vol in range(lowVol, highVol):
                self.__coverage[:, :, vol]    = np.nan
                self.__volRanges[     vol, :] = np.nan

        if self.__taskThread is None:
            self.__expandCoverage(slices)
        else:
            name = '{}_write_{}'.format(id(self), slices)
            if not self.__taskThread.isQueued(name):
                self.__taskThread.enqueue(
                    self.__expandCoverage, slices, taskName=name)


    def __getitem__(self, sliceobj):
        """Returns the image data for the given ``sliceobj``, and updates
        the known image data range if necessary.

        :arg sliceobj: Something which can slice the image data.
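
        When the slice selects a single voxel (``expNdims == 0`` in the
        code below), the data is first reshaped to a 0-dimensional array,
        and then indexed with ``()`` to obtain a ``numpy`` scalar. A
        standalone illustration of that last step:

        .. code-block:: python

            import numpy as np

            data  = np.arange(6.0).reshape((2, 3))
            block = data[1:2, 2:3].reshape(())  # 0-d array
            value = block[()]                   # numpy scalar

            type(value)    # -> numpy.float64
            block.item()   # -> 5.0, as a Python float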
        """

        log.debug('Getting image data: %s', sliceobj)

        shape              = self.__canonicalShape
        realShape          = self.__image.shape
        sliceobj           = canonicalSliceObj(   sliceobj, shape)
        fancy              = isValidFancySliceObj(sliceobj, shape)
        expNdims, expShape = expectedShape(       sliceobj, shape)

        # TODO Cache 3D images for large 4D volumes,
        #      so you don't have to hit the disk?

        # Make the slice object compatible with the
        # actual image shape, and retrieve the data.
        sliceobj = canonicalSliceObj(sliceobj, realShape)
        data     = self.__getData(sliceobj)

        # Update data range for the
        # data that we just read in
        if not self.__covered:

            slices = sliceObjToSliceTuple(sliceobj, realShape)

            if not sliceCovered(slices, self.__coverage):
                self.__updateDataRangeOnRead(slices, data)

        # Make sure that the result has the
        # shape that the caller is expecting.
        if fancy: data = data.reshape((data.size, ))
        else:     data = data.reshape(expShape)

        # If expNdims == 0, we should
        # return a scalar. If expNdims
        # == 0, but data.size != 1,
        # something is wrong somewhere
        # (and is not being handled
        # here).
        if expNdims == 0 and data.size == 1:

            # Funny behaviour with numpy scalar arrays.
            # data[()] returns a numpy scalar (which is
            # what we want). But data.item() returns a
            # python scalar. And if the data is an
            # ndarray with 0 dims, data[0] will raise
            # an error!
            data = data[()]

        return data


    def __setitem__(self, sliceobj, values):
        """Writes the given ``values`` to the image at the given ``sliceobj``.

        :arg sliceobj: Something which can be used to slice the array.
        :arg values:   Data to write to the image.

        .. note:: Modifying image data will cause the entire image to be
                  loaded into memory.
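
        If the real image shape differs from the shape it is presented as
        (e.g. the header has a trailing dimension of length 1), values
        shaped for the presented image may not broadcast onto the real
        data. A plain ``numpy`` sketch of the problem, and of the kind of
        re-shaping applied to the values in that case:

        .. code-block:: python

            import numpy as np

            # Real data with a trailing dimension of length 1
            data   = np.zeros((4, 4, 3, 1), dtype=np.float32)
            # Values for one 4x3 plane of the presented (4, 4, 3) image
            values = np.ones((4, 3), dtype=np.float32)

            try:
                data[1, :, :, :] = values   # (4, 3) into (4, 3, 1)
                broadcastOk = True
            except ValueError:
                broadcastOk = False         # numpy refuses to broadcast

            # Re-shaping to the expected shape makes it work
            data[1, :, :, :] = values.reshape((4, 3, 1))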
        """

        realShape = self.__image.shape
        sliceobj  = canonicalSliceObj(   sliceobj, realShape)
        slices    = sliceObjToSliceTuple(sliceobj, realShape)

        # If the image shape does not match its
        # 'display' shape (either fewer than three
        # dims, or trailing dims of length 1), we
        # might need to re-shape the values to
        # prevent numpy from raising an error in
        # the assignment below.
        if realShape != self.__canonicalShape:

            expNdims, expShape = expectedShape(sliceobj, realShape)

            # If we are slicing a scalar, the
            # assigned value has to be scalar.
            if expNdims == 0 and isinstance(values, abc.Sequence):

                if len(values) > 1:
                    raise IndexError('Invalid assignment: [{}] = {}'.format(
                        sliceobj, len(values)))

                values = np.array(values).flatten()[0]

            else:
                # Make sure that the values
                # have a compatible shape.
                values = np.array(values)
                if values.shape != expShape:
                    values = values.reshape(expShape)

        # The image data has to be in memory
        # for the data to be changed. If it's
        # already in memory, this call won't
        # have any effect.
        self.loadData()

        self.__data[sliceobj] = values
        self.__updateDataRangeOnWrite(slices, values)


def isValidFancySliceObj(sliceobj, shape):
    """Returns ``True`` if the given ``sliceobj`` is a valid and fancy slice
    object.

    ``nibabel`` refers to slice objects as "fancy" if they comprise anything
    but tuples of integers and simple ``slice`` objects. The ``ImageWrapper``
    class supports one type of "fancy" slicing, where the ``sliceobj`` is a
    boolean ``numpy`` array with the same number of elements as the image
    (typically, a mask of the same shape as the image).

    This function returns ``True`` if the given ``sliceobj`` adheres to these
    requirements, ``False`` otherwise.
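
    For example, with a plain ``numpy`` mask. ``isFancyMask`` is a
    hypothetical standalone re-write of the check made by this function,
    used here so that the example does not depend on this module:

    .. code-block:: python

        import numpy as np

        def isFancyMask(sliceobj, shape):
            # A boolean array with as many elements as the image
            return (isinstance(sliceobj, np.ndarray) and
                    sliceobj.dtype == bool           and
                    sliceobj.size == int(np.prod(shape)))

        shape = (4, 4, 3)
        mask  = np.zeros(shape, dtype=bool)

        isFancyMask(mask,             shape)  # -> True
        isFancyMask(mask.ravel(),     shape)  # -> True, same size
        isFancyMask(mask[..., 0],     shape)  # -> False, wrong size
        isFancyMask((0, slice(None)), shape)  # -> False, not an array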
    """

    # We only support boolean numpy arrays
    # which have the same size as the image
    return (isinstance(sliceobj, np.ndarray) and
            sliceobj.dtype == bool           and
            np.prod(sliceobj.shape) == np.prod(shape))


def canonicalSliceObj(sliceobj, shape):
    """Returns a canonical version of the given ``sliceobj``. See the
    ``nibabel.fileslice.canonical_slicers`` function.
    """

    # Fancy slice objects must have
    # the same shape as the data
    if isValidFancySliceObj(sliceobj, shape):
        return sliceobj.reshape(shape)

    else:

        if not isinstance(sliceobj, tuple):
            sliceobj = (sliceobj,)

        if len(sliceobj) > len(shape):
            sliceobj = sliceobj[:len(shape)]

        return nib.fileslice.canonical_slicers(sliceobj, shape)


def expectedShape(sliceobj, shape):
    """Given a slice object, and the shape of an array to which
    that slice object is going to be applied, returns the expected
    shape of the result.

    .. note:: It is assumed that the ``sliceobj`` has been passed through
              the :func:`canonicalSliceObj` function.

    :arg sliceobj: Something which can be used to slice an array
                   of shape ``shape``.

    :arg shape:    Shape of the array being sliced.

    :returns:      A tuple containing:

                     - Expected number of dimensions of the result

                     - Expected shape of the result (or ``None`` if
                       ``sliceobj`` is fancy).
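
    As a worked example, for an array of shape ``(10, 20, 30)`` and the
    slice object ``(0, slice(None), slice(2, 5))``, the integer index
    collapses the first dimension, ``slice(None)`` keeps all 20 elements
    of the second, and ``slice(2, 5)`` keeps 3 elements of the third, so
    the expected result is ``(2, [20, 3])``. The same rule as a standalone
    sketch (``expectedShapeSketch`` is a hypothetical name), checked
    against ``numpy``:

    .. code-block:: python

        import numpy as np

        def expectedShapeSketch(sliceobj, shape):
            expShape = []
            for s, n in zip(sliceobj, shape):
                # Integer indices collapse their dimension
                if isinstance(s, int):
                    continue
                start = 0 if s.start is None else s.start
                stop  = n if s.stop  is None else min(s.stop, n)
                expShape.append(stop - start)
            return len(expShape), expShape

        sliceobj = (0, slice(None), slice(2, 5))
        expectedShapeSketch(sliceobj, (10, 20, 30))   # -> (2, [20, 3])
        np.zeros((10, 20, 30))[sliceobj].shape        # -> (20, 3)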
    """

782
783
784
785
786
787
788
789
790
791
792
    if isValidFancySliceObj(sliceobj, shape):
        return 1, None

    # Truncate some dimensions from the
    # slice object if it has too many
    # (e.g. trailing dims of length 1).
    elif len(sliceobj) > len(shape):
        sliceobj = sliceobj[:len(shape)]

    # Figure out the number of dimensions
    # that the result should have, given
    # this slice object.
    expShape = []

    for i in range(len(sliceobj)):

        # Each dimension which has an
        # int slice will be collapsed
        if isinstance(sliceobj[i], int):
            continue

        start = sliceobj[i].start
        stop  = sliceobj[i].stop

        if start is None: start = 0
        if stop  is None: stop  = shape[i]

        stop = min(stop, shape[i])

        expShape.append(stop - start)

    return len(expShape), expShape


def sliceObjToSliceTuple(sliceobj, shape):
    """Turns an array slice object into a tuple of (low, high) index
    pairs, one pair for each dimension in the given shape.

    :arg sliceobj: Something which can be used to slice an array of shape
                   ``shape``.

    :arg shape:    Shape of the array being sliced.
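
    For example, for a ``(10, 10, 10)`` image (an illustrative sketch)::

        # image[:, 3] - an integer index i becomes the range (i, i + 1),
        # and missing trailing dimensions span the whole axis
        sliceObjToSliceTuple((slice(None), 3), (10, 10, 10))
        # -> ((0, 10), (3, 4), (0, 10))

        # A bare integer is also accepted, e.g. image[6]
        sliceObjToSliceTuple(6, (10, 10, 10))
        # -> ((6, 7), (0, 10), (0, 10))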
    """

    if isValidFancySliceObj(sliceobj, shape):
        return tuple((0, s) for s in shape)

    indices = []

    # The sliceobj could be a single slice
    # object or integer, instead of a tuple
    if not isinstance(sliceobj, abc.Sequence):
        sliceobj = [sliceobj]

    # Turn e.g. array[6] into array[6, :, :]
    if len(sliceobj) != len(shape):
        missing  = len(shape) - len(sliceobj)
        sliceobj = list(sliceobj) + [slice(None) for i in range(missing)]

    for dim, s in enumerate(sliceobj):

        # each element in the slices tuple should
        # be a slice object or an integer
        if isinstance(s, slice): i = [s.start, s.stop]
        else:                    i = [s,       s + 1]

        if i[0] is None: i[0] = 0
        if i[1] is None: i[1] = shape[dim]

        indices.append(tuple(i))

    return tuple(indices)


def sliceTupleToSliceObj(slices):
    """Turns a sequence of (low, high) index pairs into a tuple of array
    ``slice`` objects.

    :arg slices: A sequence of (low, high) index pairs.
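
    For example (an illustrative sketch)::

        sliceTupleToSliceObj(((0, 10), (3, 4), (0, 10)))
        # -> (slice(0, 10, 1), slice(3, 4, 1), slice(0, 10, 1))

    Note that, when used to invert :func:`sliceObjToSliceTuple`, integer
    indices come back as length-1 slices, so the corresponding dimensions
    are not collapsed when the result is used to index an array.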
    """

    sliceobj = []

    for lo, hi in slices:
        sliceobj.append(slice(lo, hi, 1))

    return tuple(sliceobj)


def adjustCoverage(oldCoverage, slices):
    """Adjusts/expands the given ``oldCoverage`` so that it covers the
    given set of ``slices``.

    :arg oldCoverage: A ``numpy`` array of shape ``(2, n)`` containing
                      the (low, high) index pairs for ``n`` dimensions of
                      a single slice/volume in the image.

    :arg slices:      A sequence of (low, high) index pairs. If ``slices``
                      contains more dimensions than are specified in
                      ``oldCoverage``, the trailing dimensions are ignored.

    :return: A ``numpy`` array containing the adjusted/expanded coverage.
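
    For example, starting from an empty (``nan``) coverage of a 2D
    slice/volume (an illustrative sketch)::

        import numpy as np

        # Row 0 holds the low indices, row 1 the high indices
        cov = np.full((2, 2), np.nan)
        cov = adjustCoverage(cov, [(0, 5), (2, 4)])
        cov.tolist()   # -> [[0.0, 2.0], [5.0, 4.0]]

        # Expansion only ever grows the coverage
        cov = adjustCoverage(cov, [(3, 8), (0, 3)])
        cov.tolist()   # -> [[0.0, 0.0], [8.0, 4.0]]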
    """

    newCoverage = np.zeros(oldCoverage.shape, dtype=oldCoverage.dtype)

    for dim in range(oldCoverage.shape[1]):

        low,      high      = slices[        dim]
        lowCover, highCover = oldCoverage[:, dim]

        if np.isnan(lowCover)  or low  < lowCover:  lowCover  = low
        if np.isnan(highCover) or high > highCover: highCover = high

        newCoverage[:, dim] = lowCover, highCover

    return newCoverage


OVERLAP_ALL = 0
"""Indicates that the slice is wholly contained within the coverage.  This is
a return code for the :func:`sliceOverlap` function.
"""


OVERLAP_SOME = 1
"""Indicates that the slice partially overlaps with the coverage. This is a
return code for the :func:`sliceOverlap` function.
"""


OVERLAP_NONE = 2
"""Indicates that the slice does not overlap with the coverage. This is a
return code for the :func:`sliceOverlap` function.
"""


def sliceOverlap(slices, coverage):
    """Determines whether the given ``slices`` overlap with the given
    ``coverage``.

    :arg slices:    A sequence of (low, high) index pairs, assumed to cover
                    all image dimensions.
    :arg coverage:  A ``numpy`` array of shape ``(2, nd, nv)`` (where ``nd``
                    is the number of dimensions being covered, and ``nv`` is
                    the number of volumes (or vectors/slices) in the image),
                    which contains the (low, high) index pairs describing
                    the current image coverage.

    :returns: One of the following codes:

              .. autosummary::

                 OVERLAP_ALL
                 OVERLAP_SOME
                 OVERLAP_NONE
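
    For example, for a 4D image of shape ``(4, 4, 4, 2)``, where only the
    voxels with X index below 2 in the first volume have been covered (an
    illustrative sketch; the coverage has one (low, high) column for each
    of the three spatial dimensions, for each of the two volumes)::

        import numpy as np

        # coverage[0] holds the low indices, coverage[1] the high indices
        cov = np.full((2, 3, 2), np.nan)
        cov[:, :, 0] = [[0, 0, 0], [2, 4, 4]]

        sliceOverlap([(0, 2), (0, 4), (0, 4), (0, 1)], cov)  # -> OVERLAP_ALL
        sliceOverlap([(0, 4), (0, 4), (0, 4), (0, 1)], cov)  # -> OVERLAP_SOME
        sliceOverlap([(0, 4), (0, 4), (0, 4), (1, 2)], cov)  # -> OVERLAP_NONE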
    """

    numDims         = coverage.shape[1]
    lowVol, highVol = slices[numDims]

    # Overlap state is calculated for each volume
    overlapStates = np.zeros(highVol - lowVol)

    for i, vol in enumerate(range(lowVol, highVol)):

        state = OVERLAP_ALL

        for dim in range(numDims):

            lowCover, highCover = coverage[:, dim, vol]
            lowSlice, highSlice = slices[     dim]

            # No coverage
            if np.isnan(lowCover) or np.isnan(highCover):
                state = OVERLAP_NONE
                break

            # The slice is contained within the
            # coverage on this dimension - check
            # the other dimensions.
            if lowSlice >= lowCover and highSlice <= highCover:
                continue

            # The slice does not overlap with the
            # coverage on this dimension, so it does
            # not overlap at all - no need to check
            # the other dimensions.
            if lowSlice >= highCover or highSlice <= lowCover:
                state = OVERLAP_NONE
                break

            # There is some overlap between the
            # slice and coverage on this dimension
            # - check the other dimensions.
            state = OVERLAP_SOME

        overlapStates[i] = state

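    # If some volumes are fully covered and others
    # are not covered at all, the slice partially
    # overlaps with the coverage.
    if   np.any(overlapStates == OVERLAP_SOME): return OVERLAP_SOME
    elif np.all(overlapStates == OVERLAP_NONE): return OVERLAP_NONE
    elif np.all(overlapStates == OVERLAP_ALL):  return OVERLAP_ALL
    else:                                       return OVERLAP_SOME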


def sliceCovered(slices, coverage):
    """Returns ``True`` if the portion of the image data specified by
    the given ``slices`` has already been calculated, ``False`` otherwise.

    :arg slices:    A sequence of (low, high) index pairs, assumed to cover
                    all image dimensions.
    :arg coverage:  A ``numpy`` array of shape ``(2, nd, nv)`` (where ``nd``
                    is the number of dimensions being covered, and ``nv`` is
                    the number of volumes (or vectors/slices) in the image),
                    which contains the (low, high) index pairs describing
                    the current image coverage.
    """

    numDims         = coverage.shape[1]
    lowVol, highVol = slices[numDims]