From 33a9763ac6fe8ef4c4576b43b337e4f1ce29d495 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Tue, 27 Feb 2018 16:30:48 +0000
Subject: [PATCH] More little adjustments to assertions module

---
 fsl/utils/assertions.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/fsl/utils/assertions.py b/fsl/utils/assertions.py
index 35dde034d..8d6e0ba10 100644
--- a/fsl/utils/assertions.py
+++ b/fsl/utils/assertions.py
@@ -12,6 +12,7 @@
 import os.path as op
 import nibabel as nib
 
+import fsl.utils.ensure         as ensure
 import fsl.data.melodicanalysis as fslma
 
 
@@ -25,7 +26,7 @@ def assertIsNifti3D(*args):
     """Raise an exception if the specified file/s are not 3D nifti."""
     for f in args:
         assertIsNifti(f)
-        d = nib.load(f)
+        d = ensure.ensureIsImage(f)
         assert len(d.shape) == 3, \
             'incorrect shape for 3D nifti: {}:{}'.format(d.shape, f)
 
@@ -34,7 +35,7 @@ def assertIsNifti4D(*args):
     """Raise an exception if the specified file/s are not 4D nifti."""
     for f in args:
         assertIsNifti(f)
-        d = nib.load(f)
+        d = ensure.ensureIsImage(f)
         assert len(d.shape) == 4, \
             'incorrect shape for 4D nifti: {}:{}'.format(d.shape, f)
 
@@ -42,16 +43,19 @@ def assertIsNifti4D(*args):
 def assertIsNifti(*args):
     """Raise an exception if the specified file/s are not nifti."""
     for f in args:
-        assert isinstance(f, nib.nifti1.Nifti1Image) or \
-            f.endswith('.nii.gz') or f.endswith('.nii'), \
+        f = ensure.ensureIsImage(f)
+
+        # Nifti2Image derives from Nifti1Image,
+        # so we only need to test the latter.
+        assert isinstance(f, nib.nifti1.Nifti1Image), \
             'file must be a nifti (.nii or .nii.gz): {}'.format(f)
 
 
 def assertNiftiShape(shape, *args):
     """Raise an exception if the specified nifti/s are not specified shape."""
     for fname in args:
-        d = nib.load(fname)
-        assert d.shape == shape, \
+        d = ensure.ensureIsImage(fname)
+        assert tuple(d.shape) == tuple(shape), \
             'incorrect shape ({}) for nifti: {}:{}'.format(
                 shape, d.shape, fname)
 
-- 
GitLab