diff --git a/fsl/fslview/displaycontext.py b/fsl/fslview/displaycontext.py index fb2f3ade8d0966a1fe8fdb148b6dca2d587b074a..1bd0ff370b83e4626140330bca7b50ba340a7a4c 100644 --- a/fsl/fslview/displaycontext.py +++ b/fsl/fslview/displaycontext.py @@ -271,6 +271,21 @@ class ImageDisplay(props.SyncableHasProperties): 'imageType']) + def getDisplayBounds(self): + """Calculates and returns the min/max values of a 3D bounding box, + in the display coordinate system, which is big enough to contain + the image associated with this :class:`ImageDisplay` instance. + + The coordinate system in which the bounding box is defined is + determined by the current value of the :attr:`transform` property. + + A tuple containing two values is returned, with the first value + a sequence of three low bounds, and the second value a sequence + of three high bounds. + """ + return transform.axisBounds(self.image.shape[:3], self.voxToDisplayMat) + + def _transformChanged(self, *a): """Called when the :attr:`transform` property is changed. @@ -615,14 +630,12 @@ class DisplayContext(props.SyncableHasProperties): for img in self._imageList.images: display = self._imageDisplays[img] - xform = display.voxToDisplayMat + lo, hi = display.getDisplayBounds() for ax in range(3): - lo, hi = transform.axisBounds(img.shape[:3], xform, ax) - - if lo < minBounds[ax]: minBounds[ax] = lo - if hi > maxBounds[ax]: maxBounds[ax] = hi + if lo[ax] < minBounds[ax]: minBounds[ax] = lo[ax] + if hi[ax] > maxBounds[ax]: maxBounds[ax] = hi[ax] self.bounds[:] = [minBounds[0], maxBounds[0], minBounds[1], maxBounds[1], diff --git a/fsl/utils/transform.py b/fsl/utils/transform.py index c6fedae73d8de156354a6712521dd91210b6e455..2fc298f726cce1b61042b18ede1595a46e265088 100644 --- a/fsl/utils/transform.py +++ b/fsl/utils/transform.py @@ -23,8 +23,18 @@ def concat(x1, x2): return np.dot(x1, x2) -def axisBounds(shape, xform, axis): - """Returns the (lo, hi) bounds of the specified axis.""" +def axisBounds(shape, xform, axes=None): + """Returns the (lo, hi) bounds of the specified axis/axes.""" + + scalar = False + + if axes is None: + axes = [0, 1, 2] + + elif not isinstance(axes, collections.Iterable): + scalar = True + axes = [axes] + x, y, z = shape[:3] x -= 0.5 @@ -44,10 +54,11 @@ def axisBounds(shape, xform, axis): tx = transform(points, xform) - lo = tx[:, axis].min() - hi = tx[:, axis].max() + lo = tx[:, axes].min(axis=0) + hi = tx[:, axes].max(axis=0) - return (lo, hi) + if scalar: return (lo[0], hi[0]) + else: return (lo, hi) def axisLength(shape, xform, axis):