Skip to content
Snippets Groups Projects
Commit 9aa58f77 authored by Paul McCarthy's avatar Paul McCarthy :mountain_bicyclist:
Browse files

Fixes/optimisations to atlas query tests

parent 6d0e6f8b
No related branches found
No related tags found
No related merge requests found
...@@ -37,14 +37,20 @@ def _repeat(iterator, n): ...@@ -37,14 +37,20 @@ def _repeat(iterator, n):
_atlases = cache.Cache() _atlases = cache.Cache()
def _get_atlas(atlasID, res, summary=False): def _get_atlas(atlasID, res, summary=False):
atlasID = 'striatum-structural'
atlas = _atlases.get((atlasID, res, summary), default=None) atlas = _atlases.get((atlasID, res, summary), default=None)
if atlas is None: if atlas is None:
if summary or atlasID in ('talairach', 'striatum-structural', 'jhu-labels'):
kwargs = {}
else:
kwargs = {'loadData' : False,
'calcRange' : False,
'indexed' : True}
atlas = fslatlases.loadAtlas(atlasID, atlas = fslatlases.loadAtlas(atlasID,
loadSummary=summary, loadSummary=summary,
resolution=res, resolution=res,
loadData=False, **kwargs)
calcRange=False,
indexed=True)
_atlases.put((atlasID, res, summary), atlas) _atlases.put((atlasID, res, summary), atlas)
return atlas return atlas
...@@ -268,12 +274,13 @@ def _eval_mask_query(atlas, query, qtype, qin): ...@@ -268,12 +274,13 @@ def _eval_mask_query(atlas, query, qtype, qin):
res = atlas.pixdim[0] res = atlas.pixdim[0]
if maskres == res: if maskres == res:
mask = mask[:] rmask = mask[:]
else: else:
mask = mask.resample(atlas.shape[:3], dtype=np.float32, order=0)[0] rmask = mask.resample(atlas.shape[:3], dtype=np.float32, order=0)[0]
mask = np.array(mask, dtype=np.bool) rmask = np.array(rmask, dtype=np.bool)
@profile
def evalLabel(): def evalLabel():
if qin == 'out': if qin == 'out':
...@@ -283,17 +290,18 @@ def _eval_mask_query(atlas, query, qtype, qin): ...@@ -283,17 +290,18 @@ def _eval_mask_query(atlas, query, qtype, qin):
if qin == 'in': if qin == 'in':
voxels = np.array(np.where(mask)).T voxels = np.array(np.where(rmask)).T
valcounts = defaultdict(lambda : 0.0) maxval = int(atlas[:].max())
valcounts = np.zeros((maxval + 1, ))
nvoxels = voxels.shape[0] nvoxels = voxels.shape[0]
for x, y, z in voxels: for x, y, z in voxels:
x, y, z = [int(v) for v in [x, y, z]] x, y, z = [int(v) for v in [x, y, z]]
valcounts[atlas[x, y, z]] += 1.0 valcounts[int(atlas[x, y, z])] += 1.0
# make sure the values are sorted # make sure the values are sorted
# according to their atlas ordering # according to their atlas ordering
expvals = list(valcounts.keys()) expvals = np.where(valcounts > 0)[0]
explabels = [] explabels = []
# There may be more values in an image # There may be more values in an image
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment