diff --git a/tests/test_atlases_query.py b/tests/test_atlases_query.py
index bc492cb2729758d9931ddf49b67bace26c8e0180..1be536ad2226751eeb9dca27cf41eccae8939506 100644
--- a/tests/test_atlases_query.py
+++ b/tests/test_atlases_query.py
@@ -37,14 +37,20 @@ def _repeat(iterator, n):
 
 _atlases = cache.Cache()
 def _get_atlas(atlasID, res, summary=False):
+    atlasID = 'striatum-structural'
     atlas = _atlases.get((atlasID, res, summary), default=None)
     if atlas is None:
+        if summary or atlasID in ('talairach', 'striatum-structural', 'jhu-labels'):
+            kwargs = {}
+        else:
+            kwargs = {'loadData'  : False,
+                      'calcRange' : False,
+                      'indexed'   : True}
+
         atlas = fslatlases.loadAtlas(atlasID,
                                      loadSummary=summary,
                                      resolution=res,
-                                     loadData=False,
-                                     calcRange=False,
-                                     indexed=True)
+                                     **kwargs)
         _atlases.put((atlasID, res, summary), atlas)
 
     return atlas
@@ -268,12 +274,13 @@ def _eval_mask_query(atlas, query, qtype, qin):
     res     = atlas.pixdim[0]
 
     if maskres == res:
-        mask = mask[:]
+        rmask = mask[:]
     else:
-        mask = mask.resample(atlas.shape[:3], dtype=np.float32, order=0)[0]
+        rmask = mask.resample(atlas.shape[:3], dtype=np.float32, order=0)[0]
 
-    mask = np.array(mask, dtype=np.bool)
+    rmask = np.array(rmask, dtype=np.bool)
 
+    @profile
     def evalLabel():
 
         if qin == 'out':
@@ -283,17 +290,18 @@ def _eval_mask_query(atlas, query, qtype, qin):
 
         if qin == 'in':
 
-            voxels    = np.array(np.where(mask)).T
-            valcounts = defaultdict(lambda : 0.0)
+            voxels    = np.array(np.where(rmask)).T
+            maxval    = int(atlas[:].max())
+            valcounts = np.zeros((maxval + 1, ))
             nvoxels   = voxels.shape[0]
 
             for x, y, z in voxels:
                 x, y, z = [int(v) for v in [x, y, z]]
-                valcounts[atlas[x, y, z]] += 1.0
+                valcounts[int(atlas[x, y, z])] += 1.0
 
             # make sure the values are sorted
             # according to their atlas ordering
-            expvals   = list(valcounts.keys())
+            expvals   = np.where(valcounts > 0)[0]
             explabels = []
 
             # There may be more values in an image