diff --git a/share/fsl/sbin/update_fsl_package.py b/share/fsl/sbin/update_fsl_package similarity index 58% rename from share/fsl/sbin/update_fsl_package.py rename to share/fsl/sbin/update_fsl_package index 743cd36bb608c7381cf0be9f2c3d94a955410622..5886ae1b074c120278eb8443d7409407d01b1856 100755 --- a/share/fsl/sbin/update_fsl_package.py +++ b/share/fsl/sbin/update_fsl_package @@ -15,9 +15,9 @@ The script performs the following steps: the requested packages. 4. Runs "conda install" to install/update the packages. - """ # Note: Conda does have a Python API: +# # https://docs.conda.io/projects/conda/en/latest/api/index.html # # But at the time of writing this script, most of its functionality is marked @@ -29,11 +29,14 @@ import argparse import bisect import dataclasses import json +import logging import os +import platform import shlex import string import sys import functools as ft +import os.path as op import subprocess as sp import urllib.parse as urlparse import urllib.request as urlrequest @@ -42,10 +45,35 @@ from collections import defaultdict from typing import Dict, List, Union, Tuple, Optional, Sequence +log = logging.getLogger(__name__) + + PUBLIC_FSL_CHANNEL = 'https://fsl.fmrib.ox.ac.uk/fsldownloads/fslconda/public/' INTERNAL_FSL_CHANNEL = 'https://fsl.fmrib.ox.ac.uk/fsldownloads/fslconda/internal/' +def conda(cmd : str) -> str: + """Runs the conda command with the given arguments, returning its standard + output as a string. + """ + fsldir = os.environ['FSLDIR'] + condabin = op.join(fsldir, 'bin', 'conda') + + log.debug(f'Running {condabin} {cmd}') + + cmd = [condabin] + shlex.split(cmd) + result = sp.run(cmd, check=False, capture_output=True, text=True) + + log.debug(f'Exit code: {result.returncode}') + + if result.returncode != 0: + log.debug(f'Standard output:\n{result.stdout}') + log.debug(f'Standard error:\n{result.stderr}') + raise RuntimeError('Command returned error: {cmd}') + + return result.stdout + + def identify_platform() -> str: """Figures out what platform we are running on. Returns a platform identifier string - one of: @@ -71,9 +99,47 @@ def identify_platform() -> str: 'unsupported! Supported platforms: {}'.format( system, cpu, supported)) + log.debug(f'Detected platform: {platforms[key]}') + return platforms[key] +@ft.total_ordering +class Version: + """Class for parsing/comparing version strings. """ + def __init__(self, verstr): + self.verstr = verstr + + parts = [] + + verstr = verstr.lower() + if verstr.startswith('v'): + verstr = verstr[1:] + + for part in verstr.split('.'): + + # FSL development releases may have ".postN" + if part.startswith('post'): + part = part[4:] + # FSL development releases may have ".devYYYYMMDD<githash>" + if part.startswith('dev'): + for end, char in enumerate(part[3:], 3): + if char not in string.digits: + break + part = part[3:end] + try: + parts.append(int(part)) + except Exception: + break + self.parsed_version = tuple(parts) + + def __lt__(self, other): + return self.parsed_version < other.parsed_version + + def __eq__(self, other): + return self.parsed_version == other.parsed_version + + def http_request(url : str, username : str = None, password : str = None): @@ -88,7 +154,7 @@ def http_request(url : str, opener.open(url) urlrequest.install_opener(opener) - print(f'Downloading {url} ...') + log.debug(f'Downloading {url} ...') request = urlrequest.Request(url, method='GET') response = urlrequest.urlopen(request) @@ -100,6 +166,30 @@ def http_request(url : str, return payload +@ft.lru_cache +def query_installed_packages() -> Dict[str, str]: + """Uses conda to find out the versions of all packages installed in + $FSLDIR, which are sourced from the FSL conda channels. + """ + + channels = [PUBLIC_FSL_CHANNEL .rstrip('/'), + INTERNAL_FSL_CHANNEL.rstrip('/')] + + # conda info returns a list of dicts, + # one per package. We re-arrange this + # into a dict of {pkgname : version} + # mappings. + fsldir = os.environ['FSLDIR'] + info = json.loads(conda(f'list -p {fsldir} --json')) + pkgs = {} + + for pkg in info: + if pkg['base_url'].rstrip() in channels: + pkgs[pkg['name']] = pkg['version'] + + return pkgs + + @ft.total_ordering @dataclasses.dataclass class Package: @@ -133,44 +223,25 @@ class Package: @property - def parsed_version(self) -> Tuple[int]: - """Return the version as a tuple of integers. """ - parts = [] - - version = self.version.lower() - if version.startswith('v'): - version = version[1:] - - for part in version.split('.'): - - # FSL development releases may have ".postN" - if part.startswith('post'): - part = part[4:] - # FSL development releases may have ".devYYYYMMDD<githash>" - if part.startswith('dev'): - for end, char in enumerate(part[3:], 3): - if char not in string.digits: - break - part = part[3:end] - try: - parts.append(int(part)) - except Exception: - break - return tuple(parts) + def installed_version(self) -> str: + """Return the version of this package which is currently installed, or '-' + if not installed. + """ + return query_installed_packages().get(self.name, '-') def __lt__(self, pkg): """Only valid when comparing another Package with the same name and platform. """ - return self.parsed_version < pkg.parsed_version + return Version(self.version) < Version(pkg.version) def __eq__(self, pkg): """Only valid when comparing another Package with the same name and platform. """ - return self.parsed_version == pkg.parsed_version + return Version(self.version) == Version(pkg.version) def download_channel_metadata(channel_url : str, **kwargs) -> Tuple[Dict, Dict]: @@ -202,32 +273,27 @@ def download_channel_metadata(channel_url : str, **kwargs) -> Tuple[Dict, Dict]: # they are built for, and the second gives us # the available versions, and dependencies of # each package. - cdata = http_request(f'{channel_url}/channeldata.json', **kwargs) - cdata['channel_url'] = channel_url - pdata = {} + chandata = http_request(f'{channel_url}/channeldata.json', **kwargs) + chandata['channel_url'] = channel_url + platdata = {} # only consider packages # relevant to this platform - for platform in cdata['subdirs']: + for platform in chandata['subdirs']: if platform in ('noarch', thisplat): - purl = f'{channel_url}/{platform}/repodata.json' - pdata[platform] = http_request(purl) + purl = f'{channel_url}/{platform}/repodata.json' + platdata[platform] = http_request(purl) # Re-arrange the platform repodata to # {platform : {pkgname : [pkgfiles]}} - # dicts, to make the version retrieval - # below a little easier. - pdatadicts = [] - for pdata in platformdata: - pdatadict = defaultdict(lambda : defaultdict(list)) - pdatadicts.append(pdatadict) - for platform, pkgs in pdata.items(): - for pkg in pkgs['packages'].values(): - if pkg['name'] in pkgnames: - pdatadict[platform][pkg['name']].append(pkg) - platformdata = pdatadicts + # dicts, to make lookup by name easier. + platdatadict = defaultdict(lambda : defaultdict(list)) + for platform, pdata in platdata.items(): + for pkg in pdata['packages'].values(): + platdatadict[platform][pkg['name']].append(pkg) + platdata = platdatadict - return cdata, pdata + return chandata, platdata def identify_packages(channeldata : List[Tuple[Dict, Dict]], @@ -252,74 +318,123 @@ def identify_packages(channeldata : List[Tuple[Dict, Dict]], for pkgname in pkgnames: for cdata, pdata in channeldata: if pkgname in cdata['packages']: - pkgchannels[pkgname] = (cdata['channel_url'], pdata) + pkgchannels[pkgname] = (cdata, pdata) break # This package is not available - # todo message user else: + log.debug(f'Package {pkgname} is not available - ignoring.') continue # Create Package objects for every available version of # the requested packages. The packages dict has structure # - # { platform : { : [Package, Package, ...]}} + # {pkgname : [Package, Package, ...]} # # where the package lists are sorted from oldest to newest. - packages = defaultdict(lambda : defaultdict(list)) - for pkgname, (curl, pdata) in pkgchannels.items(): + packages = defaultdict(list) + for pkgname, (cdata, pdata) in pkgchannels.items(): + curl = cdata['channel_url'] for platform in pdata.keys(): for pkgfile in pdata[platform][pkgname]: - version = pkgfile['version'] - depends = pkgfile['depends'] - pkg = Package(pkgname, curl, platform, version, depends) - if pkg.development and not (development): + version = pkgfile['version'] + depends = pkgfile['depends'] + pkg = Package(pkgname, curl, platform, version, depends) + if pkg.development and (not development): continue - bisect.insort(packages[platform][pkgname], pkg) + bisect.insort(packages[pkgname], pkg) return packages +def confirm_installation(packages : Sequence[Package], yes : bool) -> bool: + """Asks the user for confirmation, before installing/updating the requested + packages. + """ + rows = [('Package name', 'Currently installed', 'Updating to'), + ('------------', '-------------------', '-----------')] + + for pkg in packages: + rows.append((pkg.name, pkg.installed_version, pkg.version)) + + len0 = max(len(r[0]) for r in rows) + len1 = max(len(r[1]) for r in rows) + len2 = max(len(r[2]) for r in rows) + + template = f'{{:{len0}}} {{:{len1}}} {{:{len2}}}' + + print('\nThe following updates are available:\n') + + for row in rows: + print(template.format(*row)) + + if yes: + return True + + response = input('\nProceed? [Y/n]: ') + + return response.strip().lower() in ('', 'y', 'yes') + + +def install_packages(packages : Sequence[Package]): + """Calls conda to update the given collection of packages. """ + + fsldir = os.environ['FSLDIR'] + packages = [f'"{p.name}={p.version}"' for p in packages] + cmd = f'install --no-deps -p {fsldir} -y ' + ' '.join(packages) + + print('\nInstalling packages...') + conda(cmd) + + def parse_args(argv : Optional[Sequence[str]]) -> argparse.Namespace: """Parses command-line arguments, returning an argparse.Namespace object. """ parser = argparse.ArgumentParser( 'update_fsl_package', description='Install/update FSL packages') - parser.add_argument('package', nargs='+', + parser.add_argument('package', nargs='*', help='Package[s] to install/update') parser.add_argument('-d', '--development', action='store_true', help='Install development versions of packages if available') parser.add_argument('-y', '--yes', action='store_true', help='Install package[s] without prompting for confirmation') + parser.add_argument('-a', '--all', action='store_true', + help='Install/update all installed FSL packages') parser.add_argument('--internal', help=argparse.SUPPRESS) parser.add_argument('--username', help=argparse.SUPPRESS) parser.add_argument('--password', help=argparse.SUPPRESS) + parser.add_argument('--verbose', help=argparse.SUPPRESS, action='store_true') - return parser.parse_args(argv) + args = parser.parse_args(argv) + if len(args.package) == 0 and not args.all: + parser.error('Specify at least one package, or use --all to update ' + 'all installed FSL packages.') -def confirm_installation(packages : Sequence[Package]): - pass - - -def install_packages(packages : Sequence[Package]): - """Calls conda to update the given collection of packages. """ - - pkgs = [f'"{p.name}={p.version}"' for p in pkgs] - conda = op.expandvars('$FSLDIR/bin/conda') - cmd = f'{conda} install -n base -y ' + ' '.join(pkgs) - - sp.run(shlex.split(cmd)) + return args def main(argv : Sequence[str] = None): """Entry point. Parses command-line arguments, then installs/updates the specified packages. """ + + if 'FSLDIR' not in os.environ: + print('$FSLDIR is not set - aborting') + sys.exit(1) + if argv is None: argv = sys.argv[1:] args = parse_args(argv) + logging.basicConfig() + if args.verbose: log.setLevel(logging.DEBUG) + else: log.setLevel(logging.INFO) + + # Download information about all + # available packages on the FSL + # conda channels. + print('Downloading FSL conda channel information ...') channeldata = [download_channel_metadata(PUBLIC_FSL_CHANNEL)] if args.internal: channeldata.insert(0, download_channel_metadata( @@ -327,11 +442,39 @@ def main(argv : Sequence[str] = None): username=args.username, password=args.password)) + if args.all: + packages = list(query_installed_packages().keys()) + else: + packages = args.package + + # Identify the versions that are + # available for the packages the + # user has requested. packages = identify_packages(channeldata, - args.package, + packages, args.development) - install_packages(args.package, args.development, args.yes) + to_install = [] + for pkgs in packages.values(): + # select the newest available + # versions of all packages + pkg = pkgs[-1] + + if Version(pkg.version) <= Version(pkg.installed_version): + log.debug(f'{pkg.name} is already up to date (available: ' + f'{pkg.version}, installed: {pkg.installed_version}) ' + '- ignoring.') + else: + to_install.append(pkg) + + if len(to_install) == 0: + print('\nNo packages need updating.') + sys.exit(0) + + if confirm_installation(to_install, args.yes): + install_packages(to_install) + else: + print('Aborting update') if __name__ == '__main__':