From 6ec19c5eb4332f53a1718a55acb22917cd0bdc04 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Fri, 17 Dec 2021 17:33:34 +0000
Subject: [PATCH] ENH: update_fsl_package script working. One more task to do

---
 ...date_fsl_package.py => update_fsl_package} | 287 +++++++++++++-----
 1 file changed, 215 insertions(+), 72 deletions(-)
 rename share/fsl/sbin/{update_fsl_package.py => update_fsl_package} (58%)

diff --git a/share/fsl/sbin/update_fsl_package.py b/share/fsl/sbin/update_fsl_package
similarity index 58%
rename from share/fsl/sbin/update_fsl_package.py
rename to share/fsl/sbin/update_fsl_package
index 743cd36..5886ae1 100755
--- a/share/fsl/sbin/update_fsl_package.py
+++ b/share/fsl/sbin/update_fsl_package
@@ -15,9 +15,9 @@ The script performs the following steps:
     the requested packages.
 
  4. Runs "conda install" to install/update the packages.
-
 """
 # Note: Conda does have a Python API:
+#
 #   https://docs.conda.io/projects/conda/en/latest/api/index.html
 #
 # But at the time of writing this script, most of its functionality is marked
@@ -29,11 +29,14 @@ import                        argparse
 import                        bisect
 import                        dataclasses
 import                        json
+import                        logging
 import                        os
+import                        platform
 import                        shlex
 import                        string
 import                        sys
 import functools       as     ft
+import os.path         as     op
 import subprocess      as     sp
 import urllib.parse    as     urlparse
 import urllib.request  as     urlrequest
@@ -42,10 +45,35 @@ from   collections     import defaultdict
 from typing import Dict, List, Union, Tuple, Optional, Sequence
 
 
+log = logging.getLogger(__name__)
+
+
 PUBLIC_FSL_CHANNEL   = 'https://fsl.fmrib.ox.ac.uk/fsldownloads/fslconda/public/'
 INTERNAL_FSL_CHANNEL = 'https://fsl.fmrib.ox.ac.uk/fsldownloads/fslconda/internal/'
 
 
+def conda(cmd : str) -> str:
+    """Runs the conda command with the given arguments, returning its standard
+    output as a string.
+    """
+    fsldir   = os.environ['FSLDIR']
+    condabin = op.join(fsldir, 'bin', 'conda')
+
+    log.debug(f'Running {condabin} {cmd}')
+
+    cmd    = [condabin] + shlex.split(cmd)
+    result = sp.run(cmd, check=False, capture_output=True, text=True)
+
+    log.debug(f'Exit code: {result.returncode}')
+
+    if result.returncode != 0:
+        log.debug(f'Standard output:\n{result.stdout}')
+        log.debug(f'Standard error:\n{result.stderr}')
+        raise RuntimeError('Command returned error: {cmd}')
+
+    return result.stdout
+
+
 def identify_platform() -> str:
     """Figures out what platform we are running on. Returns a platform
     identifier string - one of:
@@ -71,9 +99,47 @@ def identify_platform() -> str:
                         'unsupported! Supported platforms: {}'.format(
                             system, cpu, supported))
 
+    log.debug(f'Detected platform: {platforms[key]}')
+
     return platforms[key]
 
 
+@ft.total_ordering
+class Version:
+    """Class for parsing/comparing version strings. """
+    def __init__(self, verstr):
+        self.verstr = verstr
+
+        parts = []
+
+        verstr = verstr.lower()
+        if verstr.startswith('v'):
+            verstr = verstr[1:]
+
+        for part in verstr.split('.'):
+
+            # FSL development releases may have ".postN"
+            if part.startswith('post'):
+                part = part[4:]
+            # FSL development releases may have ".devYYYYMMDD<githash>"
+            if part.startswith('dev'):
+                for end, char in enumerate(part[3:], 3):
+                    if char not in string.digits:
+                        break
+                part = part[3:end]
+            try:
+                parts.append(int(part))
+            except Exception:
+                break
+        self.parsed_version = tuple(parts)
+
+    def __lt__(self, other):
+        return self.parsed_version < other.parsed_version
+
+    def __eq__(self, other):
+        return self.parsed_version == other.parsed_version
+
+
 def http_request(url      : str,
                  username : str = None,
                  password : str = None):
@@ -88,7 +154,7 @@ def http_request(url      : str,
         opener.open(url)
         urlrequest.install_opener(opener)
 
-    print(f'Downloading {url} ...')
+    log.debug(f'Downloading {url} ...')
 
     request  = urlrequest.Request(url, method='GET')
     response = urlrequest.urlopen(request)
@@ -100,6 +166,30 @@ def http_request(url      : str,
     return payload
 
 
+@ft.lru_cache
+def query_installed_packages() -> Dict[str, str]:
+    """Uses conda to find out the versions of all packages installed in
+    $FSLDIR, which are sourced from the FSL conda channels.
+    """
+
+    channels = [PUBLIC_FSL_CHANNEL  .rstrip('/'),
+                INTERNAL_FSL_CHANNEL.rstrip('/')]
+
+    # conda info returns a list of dicts,
+    # one per package. We re-arrange this
+    # into a dict of {pkgname : version}
+    # mappings.
+    fsldir = os.environ['FSLDIR']
+    info   = json.loads(conda(f'list -p {fsldir} --json'))
+    pkgs   = {}
+
+    for pkg in info:
+        if pkg['base_url'].rstrip() in channels:
+            pkgs[pkg['name']] = pkg['version']
+
+    return pkgs
+
+
 @ft.total_ordering
 @dataclasses.dataclass
 class Package:
@@ -133,44 +223,25 @@ class Package:
 
 
     @property
-    def parsed_version(self) -> Tuple[int]:
-        """Return the version as a tuple of integers. """
-        parts = []
-
-        version = self.version.lower()
-        if version.startswith('v'):
-            version = version[1:]
-
-        for part in version.split('.'):
-
-            # FSL development releases may have ".postN"
-            if part.startswith('post'):
-                part = part[4:]
-            # FSL development releases may have ".devYYYYMMDD<githash>"
-            if part.startswith('dev'):
-                for end, char in enumerate(part[3:], 3):
-                    if char not in string.digits:
-                        break
-                part = part[3:end]
-            try:
-                parts.append(int(part))
-            except Exception:
-                break
-        return tuple(parts)
+    def installed_version(self) -> str:
+        """Return the version of this package which is currently installed, or '-'
+        if not installed.
+        """
+        return query_installed_packages().get(self.name, '-')
 
 
     def __lt__(self, pkg):
         """Only valid when comparing another Package with the same name and
         platform.
         """
-        return self.parsed_version < pkg.parsed_version
+        return Version(self.version) < Version(pkg.version)
 
 
     def __eq__(self, pkg):
         """Only valid when comparing another Package with the same name and
         platform.
         """
-        return self.parsed_version == pkg.parsed_version
+        return Version(self.version) == Version(pkg.version)
 
 
 def download_channel_metadata(channel_url : str, **kwargs) -> Tuple[Dict, Dict]:
@@ -202,32 +273,27 @@ def download_channel_metadata(channel_url : str, **kwargs) -> Tuple[Dict, Dict]:
     # they are built for, and the second gives us
     # the available versions, and dependencies of
     # each package.
-    cdata = http_request(f'{channel_url}/channeldata.json', **kwargs)
-    cdata['channel_url'] = channel_url
-    pdata = {}
+    chandata = http_request(f'{channel_url}/channeldata.json', **kwargs)
+    chandata['channel_url'] = channel_url
+    platdata = {}
 
     # only consider packages
     # relevant to this platform
-    for platform in cdata['subdirs']:
+    for platform in chandata['subdirs']:
         if platform in ('noarch', thisplat):
-            purl            = f'{channel_url}/{platform}/repodata.json'
-            pdata[platform] = http_request(purl)
+            purl               = f'{channel_url}/{platform}/repodata.json'
+            platdata[platform] = http_request(purl)
 
     # Re-arrange the platform repodata to
     # {platform : {pkgname : [pkgfiles]}}
-    # dicts, to make the version retrieval
-    # below a little easier.
-    pdatadicts = []
-    for pdata in platformdata:
-        pdatadict = defaultdict(lambda : defaultdict(list))
-        pdatadicts.append(pdatadict)
-        for platform, pkgs in pdata.items():
-            for pkg in pkgs['packages'].values():
-                if pkg['name'] in pkgnames:
-                    pdatadict[platform][pkg['name']].append(pkg)
-    platformdata = pdatadicts
+    # dicts, to make lookup by name easier.
+    platdatadict = defaultdict(lambda : defaultdict(list))
+    for platform, pdata in platdata.items():
+        for pkg in pdata['packages'].values():
+            platdatadict[platform][pkg['name']].append(pkg)
+    platdata = platdatadict
 
-    return cdata, pdata
+    return chandata, platdata
 
 
 def identify_packages(channeldata : List[Tuple[Dict, Dict]],
@@ -252,74 +318,123 @@ def identify_packages(channeldata : List[Tuple[Dict, Dict]],
     for pkgname in pkgnames:
         for cdata, pdata in channeldata:
             if pkgname in cdata['packages']:
-                pkgchannels[pkgname] = (cdata['channel_url'], pdata)
+                pkgchannels[pkgname] = (cdata, pdata)
                 break
         # This package is not available
-        # todo message user
         else:
+            log.debug(f'Package {pkgname} is not available - ignoring.')
             continue
 
     # Create Package objects for every available version of
     # the requested packages. The packages dict has structure
     #
-    # { platform : { : [Package, Package, ...]}}
+    # {pkgname : [Package, Package, ...]}
     #
     # where the package lists are sorted from oldest to newest.
-    packages = defaultdict(lambda : defaultdict(list))
-    for pkgname, (curl, pdata) in pkgchannels.items():
+    packages = defaultdict(list)
+    for pkgname, (cdata, pdata) in pkgchannels.items():
+        curl = cdata['channel_url']
         for platform in pdata.keys():
             for pkgfile in pdata[platform][pkgname]:
-                version = pkgfile['version']
-                depends = pkgfile['depends']
-                pkg     = Package(pkgname, curl, platform, version, depends)
-                if pkg.development and not (development):
+                version   = pkgfile['version']
+                depends   = pkgfile['depends']
+                pkg       = Package(pkgname, curl, platform, version, depends)
+                if pkg.development and (not development):
                     continue
-                bisect.insort(packages[platform][pkgname], pkg)
+                bisect.insort(packages[pkgname], pkg)
 
     return packages
 
 
+def confirm_installation(packages : Sequence[Package], yes : bool) -> bool:
+    """Asks the user for confirmation, before installing/updating the requested
+    packages.
+    """
+    rows = [('Package name', 'Currently installed', 'Updating to'),
+            ('------------', '-------------------', '-----------')]
+
+    for pkg in packages:
+        rows.append((pkg.name, pkg.installed_version, pkg.version))
+
+    len0 = max(len(r[0]) for r in rows)
+    len1 = max(len(r[1]) for r in rows)
+    len2 = max(len(r[2]) for r in rows)
+
+    template = f'{{:{len0}}}  {{:{len1}}}  {{:{len2}}}'
+
+    print('\nThe following updates are available:\n')
+
+    for row in rows:
+        print(template.format(*row))
+
+    if yes:
+        return True
+
+    response = input('\nProceed? [Y/n]: ')
+
+    return response.strip().lower() in ('', 'y', 'yes')
+
+
+def install_packages(packages : Sequence[Package]):
+    """Calls conda to update the given collection of packages. """
+
+    fsldir   = os.environ['FSLDIR']
+    packages = [f'"{p.name}={p.version}"' for p in packages]
+    cmd      = f'install --no-deps -p {fsldir} -y ' + ' '.join(packages)
+
+    print('\nInstalling packages...')
+    conda(cmd)
+
+
 def parse_args(argv : Optional[Sequence[str]]) -> argparse.Namespace:
     """Parses command-line arguments, returning an argparse.Namespace object. """
 
     parser = argparse.ArgumentParser(
         'update_fsl_package', description='Install/update FSL packages')
-    parser.add_argument('package', nargs='+',
+    parser.add_argument('package', nargs='*',
                         help='Package[s] to install/update')
     parser.add_argument('-d', '--development', action='store_true',
                         help='Install development versions of packages if available')
     parser.add_argument('-y', '--yes', action='store_true',
                         help='Install package[s] without prompting for confirmation')
+    parser.add_argument('-a', '--all', action='store_true',
+                        help='Install/update all installed FSL packages')
 
     parser.add_argument('--internal', help=argparse.SUPPRESS)
     parser.add_argument('--username', help=argparse.SUPPRESS)
     parser.add_argument('--password', help=argparse.SUPPRESS)
+    parser.add_argument('--verbose',  help=argparse.SUPPRESS, action='store_true')
 
-    return parser.parse_args(argv)
+    args = parser.parse_args(argv)
 
+    if len(args.package) == 0 and not args.all:
+        parser.error('Specify at least one package, or use --all to update '
+                     'all installed FSL packages.')
 
-def confirm_installation(packages : Sequence[Package]):
-    pass
-
-
-def install_packages(packages : Sequence[Package]):
-    """Calls conda to update the given collection of packages. """
-
-    pkgs  = [f'"{p.name}={p.version}"' for p in pkgs]
-    conda = op.expandvars('$FSLDIR/bin/conda')
-    cmd   = f'{conda} install -n base -y ' + ' '.join(pkgs)
-
-    sp.run(shlex.split(cmd))
+    return args
 
 
 def main(argv : Sequence[str] = None):
     """Entry point. Parses command-line arguments, then installs/updates the
     specified packages.
     """
+
+    if 'FSLDIR' not in os.environ:
+        print('$FSLDIR is not set - aborting')
+        sys.exit(1)
+
     if argv is None:
         argv = sys.argv[1:]
     args = parse_args(argv)
 
+    logging.basicConfig()
+    if args.verbose: log.setLevel(logging.DEBUG)
+    else:            log.setLevel(logging.INFO)
+
+    # Download information about all
+    # available packages on the FSL
+    # conda channels.
+    print('Downloading FSL conda channel information ...')
     channeldata = [download_channel_metadata(PUBLIC_FSL_CHANNEL)]
     if args.internal:
         channeldata.insert(0, download_channel_metadata(
@@ -327,11 +442,39 @@ def main(argv : Sequence[str] = None):
             username=args.username,
             password=args.password))
 
+    if args.all:
+        packages = list(query_installed_packages().keys())
+    else:
+        packages = args.package
+
+    # Identify the versions that are
+    # available for the packages the
+    # user has requested.
     packages = identify_packages(channeldata,
-                                 args.package,
+                                 packages,
                                  args.development)
 
-    install_packages(args.package, args.development, args.yes)
+    to_install = []
+    for pkgs in packages.values():
+        # select the newest available
+        # versions of all packages
+        pkg = pkgs[-1]
+
+        if Version(pkg.version) <= Version(pkg.installed_version):
+            log.debug(f'{pkg.name} is already up to date (available: '
+                      f'{pkg.version}, installed: {pkg.installed_version}) '
+                      '- ignoring.')
+        else:
+            to_install.append(pkg)
+
+    if len(to_install) == 0:
+        print('\nNo packages need updating.')
+        sys.exit(0)
+
+    if confirm_installation(to_install, args.yes):
+        install_packages(to_install)
+    else:
+        print('Aborting update')
 
 
 if __name__ == '__main__':
-- 
GitLab