diff --git a/tests/__init__.py b/tests/__init__.py
index 09ffa0e26ed068bcda2525dafd6fdb3fe6422e9e..7170e9e75d485a4682fcb0b7c96019ce8d04c62e 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -9,11 +9,27 @@
 
 import            os
 import            shutil
+import            tempfile
 import os.path as op
 import numpy   as np
 import nibabel as nib
 
 
+def testdir():
+    """Returnsa context manager which creates and returns a temporary
+    directory, and then deletes it on exit.
+    """
+    class ctx(object):
+        def __enter__(self):
+            self.testdir = tempfile.mkdtemp()
+            return self.testdir
+
+        def __exit__(self, *a, **kwa):
+            shutil.rmtree(self.testdir)
+
+    return ctx()
+
+
 def make_dummy_file(path):
     """Makes a plain text file. Returns a hash of the file contents. """
     contents = '{}\n'.format(op.basename(path))
diff --git a/tests/test_atlases.py b/tests/test_atlases.py
index 7a32e0fc144bb96f6913c81014675628db51f8f8..7eec22eaf0228bdb736a8f82ff81dcc94e8d2072 100644
--- a/tests/test_atlases.py
+++ b/tests/test_atlases.py
@@ -8,12 +8,11 @@
 
 import            os
 import os.path as op
-import            shutil
-import            tempfile
 import numpy   as np
 
 import pytest
 
+import tests
 import fsl.data.atlases as atlases
 import fsl.data.image   as fslimage
 
@@ -109,43 +108,43 @@ dummy_atlas_desc = """<?xml version="1.0" encoding="ISO-8859-1"?>
 
 def test_add_remove_atlas():
 
-    testdir    = tempfile.mkdtemp()
-    mladir     = op.join(testdir, 'MLA')
-    mlaxmlfile = op.join(testdir, 'MLA.xml')
-    mlaimgfile = op.join(testdir, 'MLA', 'MyLittleAtlas.nii.gz')
- 
-    def _make_dummy_atlas():
+    with tests.testdir() as testdir:
 
-        data = np.zeros((10, 10, 10))
-        data[5, 5, 5] = 1
-        data[6, 6, 6] = 2
+        mladir     = op.join(testdir, 'MLA')
+        mlaxmlfile = op.join(testdir, 'MLA.xml')
+        mlaimgfile = op.join(testdir, 'MLA', 'MyLittleAtlas.nii.gz')
 
-        img = fslimage.Image(data, xform=np.eye(4))
+        def _make_dummy_atlas():
 
-        os.makedirs(mladir)
-        img.save(mlaimgfile)
+            data = np.zeros((10, 10, 10))
+            data[5, 5, 5] = 1
+            data[6, 6, 6] = 2
 
-        with open(mlaxmlfile, 'wt') as f:
-            f.write(dummy_atlas_desc)
+            img = fslimage.Image(data, xform=np.eye(4))
 
-    added   = [False]
-    removed = [False]
-    reg     = atlases.registry
-    reg.rescanAtlases()
+            os.makedirs(mladir)
+            img.save(mlaimgfile)
+
+            with open(mlaxmlfile, 'wt') as f:
+                f.write(dummy_atlas_desc)
+
+        added   = [False]
+        removed = [False]
+        reg     = atlases.registry
+        reg.rescanAtlases()
+
+        def atlas_added(r, topic, val):
+            assert topic == 'add'
+            assert r is reg
+            assert val.atlasID == 'mla'
+            added[0] = True
+
+        def atlas_removed(r, topic, val):
+            assert r is reg
+            assert topic == 'remove'
+            assert val.atlasID == 'mla'
+            removed[0] = True 
 
-    def atlas_added(r, topic, val):
-        assert topic == 'add'
-        assert r is reg
-        assert val.atlasID == 'mla'
-        added[0] = True
-        
-    def atlas_removed(r, topic, val):
-        assert r is reg
-        assert topic == 'remove'
-        assert val.atlasID == 'mla'
-        removed[0] = True 
-    
-    try:
         _make_dummy_atlas()
 
         reg.register('added',   atlas_added,   topic='add')
@@ -164,10 +163,6 @@ def test_add_remove_atlas():
         reg.removeAtlas('mla')
 
         assert removed[0]
-        
-
-    finally:
-        shutil.rmtree(testdir)
 
 
 def test_load_atlas():