Commit bcd018e7 authored by Paul McCarthy's avatar Paul McCarthy 🚵
Browse files

Merge branch 'master' into 'master'

Atlases and resampling

See merge request !36
parents ff91db5c 1cc98616
Pipeline #1008 canceled with stages
in 1 minute and 40 seconds
......@@ -623,23 +623,60 @@ class LabelAtlas(Atlas):
Atlas.__init__(self, atlasDesc, resolution, True)
def label(self, worldLoc):
"""Looks up and returns the label of the region at the given world
location, or ``None`` if the location is out of bounds.
def label(self, location, *args, **kwargs):
"""Looks up and returns the label of the region at the given
location.
:arg location: Can be one of the following:
- A sequence of three values, interpreted as
atlas coordinates. In this case, :meth:`coordLabel`
is called.
- An :class:`.Image` which is interpreted as a
weighted mask. In this case, :meth:`maskLabel` is
called.
All other arguments are passed through to the :meth:`coordLabel` or
:meth:`maskLabel` methods.
:returns: The return value of either :meth:`coordLabel` or
:meth:`maskLabel`.
"""
if isinstance(location, fslimage.Image):
return self.maskLabel(location, *args, **kwargs)
else:
return self.coordLabel(location, *args, **kwargs)
def coordLabel(self, loc, voxel=False):
"""Looks up and returns the label at the given location.
:arg loc: A sequence of three values, interpreted as atlas
coordinates. In this case, :meth:`coordLabel` is called.
:arg voxel: Defaults to ``False``. If ``True``, the ``location``
is interpreted as voxel coordinates.
:returns: The label at the given coordinates, or ``None`` if the
coordinates are out of bounds.
"""
voxelLoc = transform.transform([worldLoc], self.worldToVoxMat)[0]
voxelLoc = [int(v) for v in voxelLoc.round()]
if not voxel:
loc = transform.transform([loc], self.worldToVoxMat)[0]
loc = [int(v) for v in loc.round()]
if voxelLoc[0] < 0 or \
voxelLoc[1] < 0 or \
voxelLoc[2] < 0 or \
voxelLoc[0] >= self.shape[0] or \
voxelLoc[1] >= self.shape[1] or \
voxelLoc[2] >= self.shape[2]:
if loc[0] < 0 or \
loc[1] < 0 or \
loc[2] < 0 or \
loc[0] >= self.shape[0] or \
loc[1] >= self.shape[1] or \
loc[2] >= self.shape[2]:
return None
val = self[voxelLoc[0], voxelLoc[1], voxelLoc[2]]
val = self[loc[0], loc[1], loc[2]]
if self.desc.atlasType == 'label':
return val
......@@ -648,6 +685,60 @@ class LabelAtlas(Atlas):
return val - 1
def maskLabel(self, mask):
"""Looks up and returns the proportions of all regions that are present
in the given ``mask``.
:arg mask: A 3D :class:`.Image`` which is interpreted as a weighted
mask. If the ``mask`` shape does not match that of this
``LabelAtlas``, it is resampled using
:meth:`.Image.resample`, with linear interpolation.
:returns: A tuple containing:
- A sequence of all labels which are present in the mask
- A sequence containing the proportion, within the mask,
of each present label. The proportions are returned as
values between 0 and 100.
"""
# Make sure that the mask has the same
# number of voxels as the atlas image.
# Use nearest neighbour interpolation
# for resampling, as it is most likely
# that the mask is binary.
mask = mask.resample(self.shape[:3], dtype=np.float32, order=0)[0]
boolmask = mask > 0
fslimage.Image(mask, xform=self.voxToWorldMat).save('blag.nii.gz')
# Extract the labels that are in
# the mask, and their corresponding
# mask weights
vals = self[boolmask]
weights = mask[boolmask]
weightsum = weights.sum()
labels = np.unique(vals)
props = []
for label in labels:
# Figure out the number of all voxels
# in the mask with this label, weighted
# by the mask.
prop = weights[vals == label].sum()
# Normalise it to be a proportion
# of all voxels in the mask. We
# multiply by 100 because the FSL
# probabilistic atlases store their
# probabilities as percentages.
props.append(100 * prop / weightsum)
return labels, props
class ProbabilisticAtlas(Atlas):
"""A 4D atlas which contains one volume for each region.
......@@ -666,28 +757,101 @@ class ProbabilisticAtlas(Atlas):
Atlas.__init__(self, atlasDesc, resolution, False)
def proportions(self, worldLoc):
def proportions(self, location, *args, **kwargs):
"""Looks up and returns the proportions of of all regions at the given
location.
:arg location: Can be one of the following:
- A sequence of three values, interpreted as atlas
coordinates. In this case, :meth:`coordProportions`
is called.
- An :class:`.Image` which is interpreted as a
weighted mask. In this case, :meth:`maskProportions`
is called.
All other arguments are passed through to the :meth:`coordProportions`
or :meth:`maskProportions` methods.
:returns: The return value of either :meth:`coordProportions` or
:meth:`maskProportions`.
"""
if isinstance(location, fslimage.Image):
return self.maskProportions(location, *args, **kwargs)
else:
return self.coordProportions(location, *args, **kwargs)
def coordProportions(self, loc, voxel=False):
"""Looks up the region probabilities for the given location.
:arg worldLoc: Location in the world coordinate system.
:arg loc: A sequence of three values, interpreted as atlas
world or voxel coordinates.
:arg voxel: Defaults to ``False``. If ``True``, the ``loc``
argument is interpreted as voxel coordinates.
:returns: a list of values, one per region, which represent
the probability of each region for the specified
location. Returns an empty list if the given
location is out of bounds.
"""
voxelLoc = transform.transform([worldLoc], self.worldToVoxMat)[0]
voxelLoc = [int(v) for v in voxelLoc.round()]
if voxelLoc[0] < 0 or \
voxelLoc[1] < 0 or \
voxelLoc[2] < 0 or \
voxelLoc[0] >= self.shape[0] or \
voxelLoc[1] >= self.shape[1] or \
voxelLoc[2] >= self.shape[2]:
if not voxel:
loc = transform.transform([loc], self.worldToVoxMat)[0]
loc = [int(v) for v in loc.round()]
if loc[0] < 0 or \
loc[1] < 0 or \
loc[2] < 0 or \
loc[0] >= self.shape[0] or \
loc[1] >= self.shape[1] or \
loc[2] >= self.shape[2]:
return []
return self[voxelLoc[0], voxelLoc[1], voxelLoc[2], :]
return self[loc[0], loc[1], loc[2], :]
def maskProportions(self, mask):
"""Looks up the probabilities of all regions in the given ``mask``.
:arg mask: A 3D :class:`.Image`` which is interpreted as a weighted
mask. If the ``mask`` shape does not match that of this
``ProbabilisticAtlas``, it is resampled using
:meth:`.Image.resample`, with linear interpolation.
:returns: A tuple containing:
- A sequence of all labels which are present in the mask
- A sequence containing the proportion, within the mask,
of each present label. The proportions are returned as
values between 0 and 100.
"""
labels = []
props = []
# Make sure that the mask has the same
# number of voxels as the atlas image
mask = mask.resample(self.shape[:3], dtype=np.float32, order=0)[0]
boolmask = mask > 0
weights = mask[boolmask]
weightsum = weights.sum()
for label in range(self.shape[3]):
vals = self[..., label]
vals = vals[boolmask] * weights
prop = vals.sum() / weightsum
if not np.isclose(prop, 0):
labels.append(label)
props .append(prop)
return labels, props
registry = AtlasRegistry()
......
......@@ -40,6 +40,7 @@ import logging
import six
import deprecation
import numpy as np
import scipy.ndimage as ndimage
import nibabel as nib
import nibabel.fileslice as fileslice
......@@ -787,6 +788,15 @@ class Image(Nifti):
indexed = False
threaded = False
# Take a copy of the header if one has
# been provided
#
# NOTE: Nifti extensions are copied by
# reference, which may cause issues in
# the future.
if header is not None:
header = header.copy()
# The image parameter may be the name of an image file
if isinstance(image, six.string_types):
......@@ -825,6 +835,12 @@ class Image(Nifti):
if header is None:
ctr = nib.nifti1.Nifti1Image
# make sure that the data type is correct,
# in case this header was passed in from
# a different image
if header is not None:
header.set_data_dtype(image.dtype)
# But if a nibabel header has been provided,
# we use the corresponding image type
if isinstance(header, nib.nifti2.Nifti2Header):
......@@ -1115,6 +1131,93 @@ class Image(Nifti):
self.notify(topic='saveState')
def resample(self,
newShape,
sliceobj=None,
dtype=None,
order=1,
smooth=True):
"""Returns a copy of the data in this ``Image``, resampled to the
specified ``shape``.
:arg newShape: Desired shape. May containg floating point values,
in which case the resampled image will have shape
``round(newShape)``, but the voxel sizes will
have scales ``self.shape / newShape``.
:arg sliceobj: Slice into this ``Image``. If ``None``, the whole
image is resampled, and it is assumed that it has the
same number of dimensions as ``shape``.
:arg dtype: ``numpy`` data type of the resampled data. If ``None``,
the :meth:`dtype` of this ``Image`` is used.
:arg order: Spline interpolation order, passed through to the
``scipy.ndimage.affine_transform`` function - ``0``
corresponds to nearest neighbour interpolation, ``1``
(the default) to linear interpolation, and ``3`` to
cubic interpolation.
:arg smooth: If ``True`` (the default), the data is smoothed before
being resampled, but only along axes which are being
down-sampled (i.e. where
``newShape[i] < self.shape[i]``).
:returns: A tuple containing:
- A ``numpy`` array of shape ``shape``, containing an
interpolated copy of the data in this ``Image``.
- A ``numpy`` array of shape ``(4, 4)``, containing the
adjusted voxel-to-world transformation for the resampled
data.
"""
if sliceobj is None: sliceobj = slice(None)
if dtype is None: dtype = self.dtype
data = self[sliceobj]
data = np.array(data, dtype=dtype, copy=False)
oldShape = np.array(data.shape, dtype=np.float)
newShape = np.array(newShape, dtype=np.float)
if not np.all(np.isclose(oldShape, newShape)):
ratio = oldShape / newShape
newShape = np.array(np.round(newShape), dtype=np.int)
scale = transform.scaleOffsetXform(ratio, 0)
# If interpolating and smoothing, we apply a
# gaussian filter along axes with a resampling
# ratio greater than 1.1. We do this so that
# interpolation has an effect when down-sampling
# to a resolution where the voxel centres are
# aligned (as otherwise any interpolation regime
# will be equivalent to nearest neighbour). This
# more-or-less mimics the behaviour of FLIRT.
if order > 0 and smooth:
sigma = np.array(ratio)
sigma[ratio < 1.1] = 0
sigma[ratio >= 1.1] *= 0.425
data = ndimage.gaussian_filter(data, sigma)
data = ndimage.affine_transform(data,
scale[:3, :3],
output_shape=newShape,
order=order)
# Construct an affine transform which
# puts the resampled image into the
# same world coordinate system as this
# image.
xform = transform.concat(self.voxToWorldMat, scale)
else:
xform = self.voxToWorldMat
return data, xform
def __getitem__(self, sliceobj):
"""Access the image data with the specified ``sliceobj``.
......
......@@ -31,7 +31,7 @@ is compatible with PEP 440 (https://www.python.org/dev/peps/pep-0440/):
which primarily involve bug-fixes and minor changes.
The sole exception to the above convention are evelopment versions, which end
The sole exception to the above convention are development versions, which end
in ``'.dev'``.
"""
......@@ -56,7 +56,7 @@ def parseVersionString(versionString):
components = versionString.split('.')
# Truncate after three elements -
# a development (unreleased0 version
# a development (unreleased version
# number will end with '.dev', but
# we ignore this for the purposes of
# comparison.
......@@ -89,7 +89,7 @@ def compareVersions(v1, v2, ignorePoint=False):
- -1 if ``v1`` < ``v2`` (i.e. ``v1`` is older than ``v2``)
- 0 if ``v1`` == ``v2``
- 0 if ``v1`` > ``v2``
- 1 if ``v1`` > ``v2``
"""
v1 = parseVersionString(v1)
......
......@@ -6,9 +6,10 @@
#
import os
import os.path as op
import numpy as np
import os
import os.path as op
import itertools as it
import numpy as np
import mock
import pytest
......@@ -18,12 +19,14 @@ import fsl.data.atlases as atlases
import fsl.data.image as fslimage
datadir = op.join(op.dirname(__file__), 'testdata')
def setup_module():
if os.environ.get('FSLDIR', None) is None:
raise Exception('FSLDIR is not set - atlas tests cannot be run')
dummy_atlas_desc = """<?xml version="1.0" encoding="ISO-8859-1"?>
<atlas version="1.0">
......@@ -52,7 +55,7 @@ def _make_dummy_atlas(savedir, name, shortName, filename):
data[6, 6, 6] = 2
img = fslimage.Image(data, xform=np.eye(4))
os.makedirs(mladir)
img.save(mlaimgfile)
......@@ -69,12 +72,12 @@ def _make_dummy_atlas(savedir, name, shortName, filename):
def test_registry():
registry = atlases.registry
registry.rescanAtlases()
assert len(registry.listAtlases()) > 0
assert registry.hasAtlas('harvardoxford-cortical')
adesc = registry.getAtlasDescription('harvardoxford-cortical')
assert isinstance(adesc, atlases.AtlasDescription)
......@@ -87,7 +90,7 @@ def test_AtlasDescription():
registry.rescanAtlases()
tal = registry.getAtlasDescription('talairach')
cort = registry.getAtlasDescription('harvardoxford-cortical')
cort = registry.getAtlasDescription('harvardoxford-cortical')
assert tal.atlasID == 'talairach'
......@@ -107,7 +110,7 @@ def test_AtlasDescription():
lbl.x
lbl.y
lbl.z
assert cort.atlasID == 'harvardoxford-cortical'
assert cort.name == 'Harvard-Oxford Cortical Structural Atlas'
assert cort.specPath
......@@ -124,7 +127,7 @@ def test_AtlasDescription():
lbl.index
lbl.x
lbl.y
lbl.z
lbl.z
with pytest.raises(Exception):
registry.getAtlasDescription('non-existent-atlas')
......@@ -151,7 +154,7 @@ def test_add_remove_atlas():
assert r is reg
assert topic == 'remove'
assert val.atlasID == 'mla'
removed[0] = True
removed[0] = True
xmlfile = _make_dummy_atlas(testdir, 'My Little Atlas', 'MLA', 'MyLittleAtlas')
......@@ -183,7 +186,7 @@ def test_extra_atlases():
badspec = op.join(testdir, 'badSpec.xml')
with open(badspec, 'wt') as f:
f.write('Bwahahahah!')
extraAtlases = ':'.join([
'myatlas1={}'.format(atlas1spec),
'myatlas2={}'.format(atlas2spec),
......@@ -217,7 +220,7 @@ def test_load_atlas():
assert isinstance(lblatlas, atlases.LabelAtlas)
def test_label_atlas():
def test_label_atlas_coord():
reg = atlases.registry
reg.rescanAtlases()
......@@ -233,7 +236,8 @@ def test_label_atlas():
([ 6, -78, 50], 862)]
for coords, expected in taltests:
assert atlas.label(coords) == expected
assert atlas.label( coords) == expected
assert atlas.coordLabel(coords) == expected
assert atlas.label([ 999, 999, 999]) is None
assert atlas.label([-999, -999, -999]) is None
......@@ -248,13 +252,14 @@ def test_label_atlas():
([ 54, -44, -27], 15)]
for coords, expected in hoctests:
assert atlas.label(coords) == expected
assert atlas.label( coords) == expected
assert atlas.coordLabel(coords) == expected
assert atlas.label([ 999, 999, 999]) is None
assert atlas.label([-999, -999, -999]) is None
def test_prob_atlas():
def test_prob_atlas_coord():
reg = atlases.registry
reg.rescanAtlases()
......@@ -273,13 +278,73 @@ def test_prob_atlas():
([-29, -42, -11], [(34, 21), (35, 23), (37, 26), (38, 24)])]
for coords, expected in hoctests:
result = atlas.proportions(coords)
expidxs = [e[0] for e in expected]
for i in range(len(result)):
if i not in expidxs:
assert result[i] == 0
for expidx, expprob in expected:
assert result[expidx] == expprob
def test_prob_atlas_mask():
# test the maskProportions function
reg = atlases.registry
reg.rescanAtlases()
hotests = [
'test_atlases_ho_mask_1mm',
'test_atlases_ho_mask_2mm'
]
resolutions = [1, 2]
for prefix, res in it.product(hotests, resolutions):
maskfile = op.join(datadir, '{}.nii.gz' .format(prefix))
resultsfile = op.join(datadir, '{}_res{}.txt'.format(prefix, res))
atlas = reg.loadAtlas('harvardoxford-cortical', resolution=res)
mask = fslimage.Image(maskfile)
labels, props = atlas.maskProportions(mask)
labels2, props2 = atlas.proportions(mask)
expected = np.loadtxt(resultsfile)
explabels = expected[:, 0]
expprops = expected[:, 1]
assert np.all(np.isclose(labels, labels2))
assert np.all(np.isclose(props, props2))
assert np.all(np.isclose(labels, explabels))
assert np.all(np.isclose(props, expprops))
def test_label_atlas_mask():
# Test the maskLabel function
reg = atlases.registry
reg.rescanAtlases()
taltests = [
'test_atlases_tal_mask_1mm',
'test_atlases_tal_mask_2mm'
]
resolutions = [1, 2]
for prefix, res in it.product(taltests, resolutions):
maskfile = op.join(datadir, '{}.nii.gz' .format(prefix))
resultsfile = op.join(datadir, '{}_res{}.txt'.format(prefix, res))
atlas = reg.loadAtlas('talairach', resolution=res)
mask = fslimage.Image(maskfile)
labels, props = atlas.maskLabel(mask)
labels2, props2 = atlas.label(mask)
expected = np.loadtxt(resultsfile)
explabels = expected[:, 0]
expprops = expected[:, 1]
assert np.all(np.isclose(labels, labels2))
assert np.all(np.isclose(props, props2))
assert np.all(np.isclose(labels, explabels))
assert np.all(np.isclose(props, expprops))
......@@ -25,6 +25,7 @@ import fsl.data.constants as constants
import fsl.data.image as fslimage
import fsl.data.imagewrapper as imagewrapper
import fsl.utils.path as fslpath
import fsl.utils.transform as transform
from . import make_random_image
from . import make_dummy_file
......@@ -1000,3 +1001,83 @@ def _test_Image_save(imgtype):
finally:
shutil.rmtree(testdir)
def test_image_resample(seed):