diff --git a/share/fsl/sbin/update_fsl_package.py b/share/fsl/sbin/update_fsl_package.py new file mode 100755 index 0000000000000000000000000000000000000000..743cd36bb608c7381cf0be9f2c3d94a955410622 --- /dev/null +++ b/share/fsl/sbin/update_fsl_package.py @@ -0,0 +1,338 @@ +#!/usr/bin/env fslpython +"""Install/update one or more FSL packages using conda. + +This script is only intended to be used within conda-based FSL installations +that have been created with the fslinstaller.py script. It is not intended to +be used within conda environments that have had FSL packages installed into +them - in this scenario, conda should be used directly. + +The script performs the following steps: + + 1. Parses command line arguments (primarily the list of packages to + install/update). + + 2. Queries the FSL conda channel(s) to find the latest available versions of + the requested packages. + + 4. Runs "conda install" to install/update the packages. + +""" +# Note: Conda does have a Python API: +# https://docs.conda.io/projects/conda/en/latest/api/index.html +# +# But at the time of writing this script, most of its functionality is marked +# as beta. So this script currently interacts with conda via its command-line +# interface. + + +import argparse +import bisect +import dataclasses +import json +import os +import shlex +import string +import sys +import functools as ft +import subprocess as sp +import urllib.parse as urlparse +import urllib.request as urlrequest +from collections import defaultdict + +from typing import Dict, List, Union, Tuple, Optional, Sequence + + +PUBLIC_FSL_CHANNEL = 'https://fsl.fmrib.ox.ac.uk/fsldownloads/fslconda/public/' +INTERNAL_FSL_CHANNEL = 'https://fsl.fmrib.ox.ac.uk/fsldownloads/fslconda/internal/' + + +def identify_platform() -> str: + """Figures out what platform we are running on. Returns a platform + identifier string - one of: + + - "linux-64" (Linux, x86_64) + - "osx-64" (macOS, x86_64) + - "osx-arm64" (macOS, arm64) + """ + + platforms = { + ('linux', 'x86_64') : 'linux-64', + ('darwin', 'x86_64') : 'osx-64', + ('darwin', 'arm64') : 'osx-arm64', + } + + system = platform.system().lower() + cpu = platform.machine() + key = (system, cpu) + + if key not in platforms: + supported = ', '.join(['[{}, {}]' for s, c in platforms]) + raise Exception('This platform [{}, {}] is unrecognised or ' + 'unsupported! Supported platforms: {}'.format( + system, cpu, supported)) + + return platforms[key] + + +def http_request(url : str, + username : str = None, + password : str = None): + """Download JSON data from the given URL. """ + + if username is not None: + urlbase = urlparse.urlparse(url).netloc + pwdmgr = urlrequest.HTTPPasswordMgrWithDefaultRealm() + pwdmgr.add_password(None, urlbase, username, password) + handler = urlrequest.HTTPBasicAuthHandler(pwdmgr) + opener = urlrequest.build_opener(handler) + opener.open(url) + urlrequest.install_opener(opener) + + print(f'Downloading {url} ...') + + request = urlrequest.Request(url, method='GET') + response = urlrequest.urlopen(request) + payload = response.read() + + if len(payload) == 0: payload = {} + else: payload = json.loads(payload) + + return payload + + +@ft.total_ordering +@dataclasses.dataclass +class Package: + """Represents a single package file hosted on a conda channel. + + A package object corresponds to a specific version of a specific package, + for a specific platform. + """ + + name : str + """Package name.""" + + channel : str + """URL of the channel which hosts the package.""" + + platform : str + """Platform identifier.""" + + version : str + """Package version string.""" + + dependencies : List[str] + """References to all packages which this package depends on. Stored as + "package[ version-constraint]" strings. + """ + + @property + def development(self) -> bool: + """Return True if this is a development version of the package. """ + return 'dev' in self.version + + + @property + def parsed_version(self) -> Tuple[int]: + """Return the version as a tuple of integers. """ + parts = [] + + version = self.version.lower() + if version.startswith('v'): + version = version[1:] + + for part in version.split('.'): + + # FSL development releases may have ".postN" + if part.startswith('post'): + part = part[4:] + # FSL development releases may have ".devYYYYMMDD<githash>" + if part.startswith('dev'): + for end, char in enumerate(part[3:], 3): + if char not in string.digits: + break + part = part[3:end] + try: + parts.append(int(part)) + except Exception: + break + return tuple(parts) + + + def __lt__(self, pkg): + """Only valid when comparing another Package with the same name and + platform. + """ + return self.parsed_version < pkg.parsed_version + + + def __eq__(self, pkg): + """Only valid when comparing another Package with the same name and + platform. + """ + return self.parsed_version == pkg.parsed_version + + +def download_channel_metadata(channel_url : str, **kwargs) -> Tuple[Dict, Dict]: + """Downloads information about packages hosted at the given conda channel. + + Returns two dictionaries: + + - The first contains the contents of <channel_url>/channeldata.json, which + contains information about all packages on the channel, and the + platforms for which they are available. + + - The second contains the contents of + <channel_url>/<platform>/repodata.json for all platforms on the + channel. This dictionary has structure + + {platform : {pkgname : [ <pkginfo> ]}}, + + where <pkginfo> contains the contents of an entry for a single package + file entry from the "packages" section of a repodata.json file. + + Keyword arguments are passed through to the http_request function. + """ + + thisplat = identify_platform() + + # Load channel and platform metadata - the + # first gives us a list of all packages that + # are hosted on the channel and platforms + # they are built for, and the second gives us + # the available versions, and dependencies of + # each package. + cdata = http_request(f'{channel_url}/channeldata.json', **kwargs) + cdata['channel_url'] = channel_url + pdata = {} + + # only consider packages + # relevant to this platform + for platform in cdata['subdirs']: + if platform in ('noarch', thisplat): + purl = f'{channel_url}/{platform}/repodata.json' + pdata[platform] = http_request(purl) + + # Re-arrange the platform repodata to + # {platform : {pkgname : [pkgfiles]}} + # dicts, to make the version retrieval + # below a little easier. + pdatadicts = [] + for pdata in platformdata: + pdatadict = defaultdict(lambda : defaultdict(list)) + pdatadicts.append(pdatadict) + for platform, pkgs in pdata.items(): + for pkg in pkgs['packages'].values(): + if pkg['name'] in pkgnames: + pdatadict[platform][pkg['name']].append(pkg) + platformdata = pdatadicts + + return cdata, pdata + + +def identify_packages(channeldata : List[Tuple[Dict, Dict]], + pkgnames : Sequence[str], + development : bool) -> Dict[str, List[Package]]: + """Return metadata about the requested packages. + + Loads channel and platform metadata from the conda channels. Parses the + metadata, and creates a Package object for every requested package. Returns + a dict of {name : Package} mappings. + + channeldata: Sequence of channel data from one or more conda channels, as + returned by the download_channel_metadata function. + pkgnames: Sequence of package names to return metadata for + development: Whether or not to consider development versions of packages + """ + + # Figure out which channel to source each package from + # (we take the package from the first channel that hosts + # it, ignoring other channels - strict channel priority). + pkgchannels = {} + for pkgname in pkgnames: + for cdata, pdata in channeldata: + if pkgname in cdata['packages']: + pkgchannels[pkgname] = (cdata['channel_url'], pdata) + break + # This package is not available + # todo message user + else: + continue + + # Create Package objects for every available version of + # the requested packages. The packages dict has structure + # + # { platform : { : [Package, Package, ...]}} + # + # where the package lists are sorted from oldest to newest. + packages = defaultdict(lambda : defaultdict(list)) + for pkgname, (curl, pdata) in pkgchannels.items(): + for platform in pdata.keys(): + for pkgfile in pdata[platform][pkgname]: + version = pkgfile['version'] + depends = pkgfile['depends'] + pkg = Package(pkgname, curl, platform, version, depends) + if pkg.development and not (development): + continue + bisect.insort(packages[platform][pkgname], pkg) + + return packages + + +def parse_args(argv : Optional[Sequence[str]]) -> argparse.Namespace: + """Parses command-line arguments, returning an argparse.Namespace object. """ + + parser = argparse.ArgumentParser( + 'update_fsl_package', description='Install/update FSL packages') + parser.add_argument('package', nargs='+', + help='Package[s] to install/update') + parser.add_argument('-d', '--development', action='store_true', + help='Install development versions of packages if available') + parser.add_argument('-y', '--yes', action='store_true', + help='Install package[s] without prompting for confirmation') + + parser.add_argument('--internal', help=argparse.SUPPRESS) + parser.add_argument('--username', help=argparse.SUPPRESS) + parser.add_argument('--password', help=argparse.SUPPRESS) + + return parser.parse_args(argv) + + +def confirm_installation(packages : Sequence[Package]): + pass + + +def install_packages(packages : Sequence[Package]): + """Calls conda to update the given collection of packages. """ + + pkgs = [f'"{p.name}={p.version}"' for p in pkgs] + conda = op.expandvars('$FSLDIR/bin/conda') + cmd = f'{conda} install -n base -y ' + ' '.join(pkgs) + + sp.run(shlex.split(cmd)) + + +def main(argv : Sequence[str] = None): + """Entry point. Parses command-line arguments, then installs/updates the + specified packages. + """ + if argv is None: + argv = sys.argv[1:] + args = parse_args(argv) + + channeldata = [download_channel_metadata(PUBLIC_FSL_CHANNEL)] + if args.internal: + channeldata.insert(0, download_channel_metadata( + INTERNAL_FSL_CHANNEL, + username=args.username, + password=args.password)) + + packages = identify_packages(channeldata, + args.package, + args.development) + + install_packages(args.package, args.development, args.yes) + + +if __name__ == '__main__': + sys.exit(main())