diff --git a/tests/test_atlases_query.py b/tests/test_atlases_query.py index bc492cb2729758d9931ddf49b67bace26c8e0180..1be536ad2226751eeb9dca27cf41eccae8939506 100644 --- a/tests/test_atlases_query.py +++ b/tests/test_atlases_query.py @@ -37,14 +37,20 @@ def _repeat(iterator, n): _atlases = cache.Cache() def _get_atlas(atlasID, res, summary=False): + atlasID = 'striatum-structural' atlas = _atlases.get((atlasID, res, summary), default=None) if atlas is None: + if summary or atlasID in ('talairach', 'striatum-structural', 'jhu-labels'): + kwargs = {} + else: + kwargs = {'loadData' : False, + 'calcRange' : False, + 'indexed' : True} + atlas = fslatlases.loadAtlas(atlasID, loadSummary=summary, resolution=res, - loadData=False, - calcRange=False, - indexed=True) + **kwargs) _atlases.put((atlasID, res, summary), atlas) return atlas @@ -268,12 +274,13 @@ def _eval_mask_query(atlas, query, qtype, qin): res = atlas.pixdim[0] if maskres == res: - mask = mask[:] + rmask = mask[:] else: - mask = mask.resample(atlas.shape[:3], dtype=np.float32, order=0)[0] + rmask = mask.resample(atlas.shape[:3], dtype=np.float32, order=0)[0] - mask = np.array(mask, dtype=np.bool) + rmask = np.array(rmask, dtype=np.bool) + @profile def evalLabel(): if qin == 'out': @@ -283,17 +290,18 @@ def _eval_mask_query(atlas, query, qtype, qin): if qin == 'in': - voxels = np.array(np.where(mask)).T - valcounts = defaultdict(lambda : 0.0) + voxels = np.array(np.where(rmask)).T + maxval = int(atlas[:].max()) + valcounts = np.zeros((maxval + 1, )) nvoxels = voxels.shape[0] for x, y, z in voxels: x, y, z = [int(v) for v in [x, y, z]] - valcounts[atlas[x, y, z]] += 1.0 + valcounts[int(atlas[x, y, z])] += 1.0 # make sure the values are sorted # according to their atlas ordering - expvals = list(valcounts.keys()) + expvals = np.where(valcounts > 0)[0] explabels = [] # There may be more values in an image