From 411a1f4940e4187f030ce19f8a48baa039cf2fb7 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Fri, 16 Apr 2021 17:57:02 +0100
Subject: [PATCH] RF: hasExt/addExt accept Path objects

---
 fsl/utils/path.py | 98 +++++++++++++++++++++++++----------------------
 1 file changed, 53 insertions(+), 45 deletions(-)

diff --git a/fsl/utils/path.py b/fsl/utils/path.py
index 9ffc15c9..09546f4b 100644
--- a/fsl/utils/path.py
+++ b/fsl/utils/path.py
@@ -47,7 +47,6 @@ class PathError(Exception):
     """``Exception`` class raised by the functions defined in this module
     when something goes wrong.
     """
-    pass
 
 
 def deepest(path, suffixes):
@@ -58,12 +57,12 @@ def deepest(path, suffixes):
 
     path = path.strip()
 
-    if path == op.sep or path == '':
+    if path in (op.sep, ''):
         return None
 
     path = path.rstrip(op.sep)
 
-    if any([path.endswith(s) for s in suffixes]):
+    if any(path.endswith(s) for s in suffixes):
         return path
 
     return deepest(op.dirname(path), suffixes)
@@ -87,7 +86,7 @@ def shallowest(path, suffixes):
     if parent is not None:
         return parent
 
-    if any([path.endswith(s) for s in suffixes]):
+    if any(path.endswith(s) for s in suffixes):
         return path
 
     return None
@@ -107,19 +106,23 @@ def allFiles(root):
     return files
 
 
-def hasExt(path, allowedExts):
+def hasExt(path        : PathLike,
+           allowedExts : Sequence[str]) -> bool:
     """Convenience function which returns ``True`` if the given ``path``
     ends with any of the given ``allowedExts``, ``False`` otherwise.
     """
-    return any([path.endswith(e) for e in allowedExts])
-
-
-def addExt(prefix,
-           allowedExts=None,
-           mustExist=True,
-           defaultExt=None,
-           fileGroups=None,
-           unambiguous=True):
+    path = str(path)
+    return any(path.endswith(e) for e in allowedExts)
+
+
+def addExt(
+        prefix      : PathLike,
+        allowedExts : Sequence[str]           = None,
+        mustExist   : bool                    = True,
+        defaultExt  : str                     = None,
+        fileGroups  : Sequence[Sequence[str]] = None,
+        unambiguous : bool                    = True
+) -> Union[Sequence[str], str]:
     """Adds a file extension to the given file ``prefix``.
 
     If ``mustExist`` is False, and the file does not already have a
@@ -154,6 +157,8 @@ def addExt(prefix,
                       containing *all* matching files is returned.
     """
 
+    prefix = str(prefix)
+
     if allowedExts is None: allowedExts = []
     if fileGroups  is None: fileGroups  = {}
 
@@ -195,7 +200,8 @@ def addExt(prefix,
 
     # If ambiguity is ok, return
     # all matching paths
-    elif not unambiguous:
+    if not unambiguous:
+
         return allPaths
 
     # Ambiguity is not ok! More than
@@ -483,7 +489,7 @@ def removeDuplicates(paths, allowedExts=None, fileGroups=None):
 
         groupFiles = getFileGroup(path, allowedExts, fileGroups)
 
-        if not any([p in unique for p in groupFiles]):
+        if not any(p in unique for p in groupFiles):
             unique.append(groupFiles[0])
 
     return unique
@@ -510,14 +516,13 @@ def uniquePrefix(path):
             break
 
         # Should never happen if path is valid
-        elif len(hits) == 0 or idx >= len(filename) - 1:
+        if len(hits) == 0 or idx >= len(filename) - 1:
             raise PathError('No unique prefix for {}'.format(filename))
 
         # Not unique - continue looping
-        else:
-            idx    += 1
-            prefix  = prefix + filename[idx]
-            hits    = [h for h in hits if h.startswith(prefix)]
+        idx    += 1
+        prefix  = prefix + filename[idx]
+        hits    = [h for h in hits if h.startswith(prefix)]
 
     return prefix
 
@@ -543,54 +548,56 @@ def commonBase(paths):
 
         last = base
 
-        if all([p.startswith(base) for p in paths]):
+        if all(p.startswith(base) for p in paths):
             return base
 
     raise PathError('No common base')
 
 
-def wslpath(winpath):
-    """
-    Convert Windows path (or a command line argument containing a Windows path)
-    to the equivalent WSL path (e.g. ``c:\\Users`` -> ``/mnt/c/Users``). Also supports
-    paths in the form ``\\wsl$\\(distro)\\users\\...``
-
-    :param winpath: Command line argument which may (or may not) contain a Windows path. It is assumed to be
-                    either of the form <windows path> or --<arg>=<windows path>. Note that we don't need to
-                    handle --arg <windows path> or -a <windows path> since in these cases the argument
-                    and the path will be parsed as separate entities.
-    :return: If ``winpath`` matches a Windows path, the converted argument (including the --<arg>= portion).
-                Otherwise returns ``winpath`` unchanged.
+def wslpath(path):
+    """Convert Windows path (or a command line argument containing a Windows
+    path) to the equivalent WSL path (e.g. ``c:\\Users`` -> ``/mnt/c/Users``).
+    Also supports paths in the form ``\\wsl$\\(distro)\\users\\...``
+
+    :param winpath: Command line argument which may (or may not) contain a
+                    Windows path. It is assumed to be either of the form
+                    <windows path> or --<arg>=<windows path>. Note that we
+                    don't need to handle --arg <windows path> or -a <windows
+                    path> since in these cases the argument and the path will
+                    be parsed as separate entities.
+    :return:        If ``winpath`` matches a Windows path, the converted
+                    argument (including the --<arg>= portion).  Otherwise
+                    returns ``winpath`` unchanged.
+
     """
-    match = re.match(r"^(--[\w-]+=)?\\\\wsl\$[\\\/][^\\^\/]+(.*)$", winpath)
+    match = re.match(r"^(--[\w-]+=)?\\\\wsl\$[\\\/][^\\^\/]+(.*)$", path)
     if match:
         arg, path = match.group(1, 2)
         if arg is None:
             arg = ""
         return arg + path.replace("\\", "/")
 
-    match = re.match(r"^(--[\w-]+=)?([a-zA-z]):(.+)$", winpath)
+    match = re.match(r"^(--[\w-]+=)?([a-zA-z]):(.+)$", path)
     if match:
         arg, drive, path = match.group(1, 2, 3)
         if arg is None:
             arg = ""
         return arg + "/mnt/" + drive.lower() + path.replace("\\", "/")
 
-    return winpath
+    return path
 
 
-def winpath(wslpath):
-    """
-    Convert a WSL-local filepath (for example ``/usr/local/fsl/``) into a path that can be used from
-    Windows.
+def winpath(path):
+    """Convert a WSL-local filepath (for example ``/usr/local/fsl/``) into a
+    path that can be used from Windows.
 
     If ``self.fslwsl`` is ``False``, simply returns ``wslpath`` unmodified
     Otherwise, uses ``FSLDIR`` to deduce the WSL distro in use for FSL.
-    This requires WSL2 which supports the ``\\wsl$\`` network path.
+    This requires WSL2 which supports the ``\\wsl$\\`` network path.
     wslpath is assumed to be an absolute path.
     """
     if not platform.fslwsl:
-        return wslpath
+        return path
     else:
         match = re.match(r"^\\\\wsl\$\\([^\\]+).*$", platform.fsldir)
         if match:
@@ -599,6 +606,7 @@ def winpath(wslpath):
             distro = None
 
         if not distro:
-            raise RuntimeError("Could not identify WSL installation from FSLDIR (%s)" % platform.fsldir)
+            raise RuntimeError('Could not identify WSL installation from '
+                               'FSLDIR (%s)' % platform.fsldir)
 
-        return "\\\\wsl$\\" + distro + wslpath.replace("/", "\\")
+        return "\\\\wsl$\\" + distro + path.replace("/", "\\")
-- 
GitLab