Skip to content
Snippets Groups Projects
Forked from FSL / fslpy
1101 commits behind the upstream repository.
atlasq.py 23.98 KiB
#!/usr/bin/env python
#
# atlasq.py - The atlasq program.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module contains the FSL ``atlasq`` program, the successor to
``atlasquery``.
"""


from __future__ import print_function

import itertools as it
import              sys
import              argparse
import              textwrap
import              warnings
import              logging
import numpy     as np

# if h5py <= 2.7.1 is installed,
# it will be imported via nibabel,
# and will cause a numpy warning
# to be emitted.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    import fsl.data.image as fslimage

# If wx is not present, then fsl.utils.platform
# will complain that it is not present.
logging.getLogger('fsl.utils.platform').setLevel(logging.ERROR)

import fsl.data.atlases as fslatlases  # noqa
import fsl.version      as fslversion  # noqa


# Module-level logger, named after this module.
log = logging.getLogger(__name__)


# Delimiter which ``queryShortOutput`` places between
# the fields of each line that it prints.
SHORT_QUERY_DELIM = '\t'


class IdentifyError(Exception):
    """Raised by the ``identifyAtlas`` function when the requested atlas
    cannot be identified - either because nothing matches, or because the
    match is ambiguous.
    """


class HelpFormatter(argparse.RawDescriptionHelpFormatter):
    """An ``argparse`` help formatter which suppresses the ``'usage: '``
    prefix that would otherwise be prepended to the usage text.
    """

    def _format_usage(self, usage, actions, groups, prefix):

        # Whatever prefix was requested is discarded,
        # and an empty one is passed on in its place.
        noPrefix = ''
        parent   = super(HelpFormatter, self)

        return parent._format_usage(usage, actions, groups, noPrefix)


def listAtlases(namespace):
    """Print all available atlases.

    When ``namespace.extended`` is set, a multi-line description is printed
    for every atlas. Otherwise a two-column table of atlas IDs and full
    names is printed.
    """
    descs = fslatlases.listAtlases()

    # Brief listing - just a table of IDs and names
    if not namespace.extended:
        atlasIDs  = [desc.atlasID for desc in descs]
        fullNames = [desc.name    for desc in descs]
        printColumns((atlasIDs, fullNames), ('ID', 'Full name'))
        return

    # Extended listing - one paragraph
    # per atlas, followed by a blank line
    for desc in descs:
        print('{} [{}]'.format(desc.name, desc.atlasID))
        print('  Spec:          {}'.format(desc.specPath))
        print('  Type:          {}'.format(desc.atlasType))
        print('  Labels:        {}'.format(len(desc.labels)))

        for image in desc.images:
            print('  Image:         {}'.format(image))

        for image in desc.summaryImages:
            print('  Summary image: {}'.format(image))

        print()


def summariseAtlas(namespace):
    """Print a description of the atlas identified by ``namespace.atlas``,
    followed by a table listing the index, name, and X/Y/Z values of each
    of its labels.
    """

    desc = identifyAtlas(namespace.atlas)

    print('{} [{}]'.format(desc.name, desc.atlasID))
    print('  Spec:          {}'.format(desc.specPath))
    print('  Type:          {}'.format(desc.atlasType))
    print('  Labels:        {}'.format(len(desc.labels)))

    for image in desc.images:
        print('  Image:         {}'.format(image))

    for image in desc.summaryImages:
        print('  Summary image: {}'.format(image))

    # One column per label attribute, in the
    # same order as the corresponding titles
    titles  = ('Index', 'Label', 'X', 'Y', 'Z')
    attrs   = ('index', 'name',  'x', 'y', 'z')
    columns = tuple([getattr(lbl, attr) for lbl in desc.labels]
                    for attr in attrs)

    printColumns(columns, titles)


def queryAtlas(namespace):
    """Run every mask, world coordinate and voxel coordinate query held in
    ``namespace`` against the requested atlas, and print the results in
    the order given by the ``*_order`` attributes of ``namespace``.

    Short or long output is produced depending on ``namespace.short``.
    """

    desc    = identifyAtlas(namespace.atlas)
    wcoords = namespace.coord
    vcoords = namespace.voxel
    masks   = [fslimage.Image(m) for m in namespace.mask]

    # Position of every query, listed in the same
    # mask -> coordinate -> voxel order that is
    # used for the results below.
    order = list(it.chain(namespace.mask_order,
                          namespace.coord_order,
                          namespace.voxel_order))

    atlas = fslatlases.loadAtlas(desc.atlasID,
                                 loadSummary=namespace.label,
                                 resolution=namespace.resolution)

    mlabels, mprops = maskQuery( atlas, masks)
    wlabels, wprops = coordQuery(atlas, wcoords, False)
    vlabels, vprops = coordQuery(atlas, vcoords, True)

    def inQueryOrder(maskItems, coordItems, voxelItems):
        # Concatenate the per-type items, and then
        # re-arrange them according to ``order``.
        items = it.chain(maskItems, coordItems, voxelItems)
        return [item for (_, item) in sorted(zip(order, items))]

    labels  = inQueryOrder(mlabels, wlabels, vlabels)
    props   = inQueryOrder(mprops,  wprops,  vprops)
    sources = inQueryOrder(masks,   wcoords, vcoords)
    types   = inQueryOrder(['mask']       * len(masks),
                           ['coordinate'] * len(wcoords),
                           ['voxel']      * len(vcoords))

    if namespace.short: output = queryShortOutput
    else:               output = queryLongOutput

    output(atlas, sources, types, labels, props)


def queryShortOutput(atlas, sources, types, allLabels, allProps):
    """Called by ``queryAtlas`` when short output is requested.

    One line is printed per query, made up of the query type, the query
    source, and the results, all joined with ``SHORT_QUERY_DELIM``.
    """

    queries = zip(sources, types, allLabels, allProps)

    for source, stype, labels, props in queries:

        # Turn the query source into a string
        if stype == 'mask':
            source = source.dataSource
        elif stype == 'voxel':
            source = '{:0.0f} {:0.0f} {:0.0f}'.format(*source)
        elif stype == 'coordinate':
            source = '{:0.2f} {:0.2f} {:0.2f}'.format(*source)

        names       = labelNames(atlas, labels)
        singleLabel = stype in ('coordinate', 'voxel') and \
            isinstance(atlas, fslatlases.LabelAtlas)

        # A coordinate lookup in a label
        # atlas yields exactly one label
        if singleLabel:
            results = ['{}'.format(names[0])]

        # Everything else yields labels and
        # proportions, which are reported
        # from the highest proportion to
        # the lowest
        else:
            ranked  = sorted(zip(props, names))[::-1]
            results = ['{} {:0.4f}'.format(name, prop)
                       for (prop, name) in ranked]

        print(SHORT_QUERY_DELIM.join([stype, source] + results))


def queryLongOutput(atlas, sources, types, allLabels, allProps):
    """Called by ``queryAtlas`` when long output is requested.

    For every query a boxed title is printed, followed by a table of
    results, followed by a blank line.
    """

    def printSingleLabel(labels, names):
        # Used for coordinate lookups in a label
        # atlas, which yield exactly one label.
        first = labels[0]
        name  = names[0]

        if first is None: index = np.nan
        else:             index = int(first)

        if atlas.desc.atlasType == 'probabilistic':
            fields = ['name', 'index', 'summary value']
            values = [ name,   index,   index + 1]
        else:
            fields = ['name', 'label']
            values = [ name,   index]

        printColumns((fields, values))

    def printProportions(labels, props, names):
        # Used for all other queries, which yield
        # a list of labels and proportions.
        if len(labels) == 0:
            print('No results')
            return

        # Highest proportion first
        ranked  = sorted(zip(props, labels, names))[::-1]
        propcol = ['{:0.4f}'.format(r[0]) for r in ranked]
        lblcol  = [r[1] for r in ranked]
        namecol = [r[2] for r in ranked]

        if atlas.desc.atlasType == 'probabilistic':
            titles  = ['name', 'index', 'summary value', 'proportion']
            columns = [namecol, lblcol, [l + 1 for l in lblcol], propcol]
        else:
            titles  = ['name', 'label', 'proportion']
            columns = [namecol, lblcol, propcol]

        printColumns(columns, titles)

    queries = zip(sources, types, allLabels, allProps)

    for source, stype, labels, props in queries:

        if stype == 'mask':
            sourcestr = '{}'.format(source.dataSource)
        elif stype == 'voxel':
            sourcestr = '{:0.0f} {:0.0f} {:0.0f}'.format(*source)
        elif stype == 'coordinate':
            sourcestr = '{:0.2f} {:0.2f} {:0.2f}'.format(*source)

        # Boxed title identifying the query
        title = '{} {}'.format(stype, sourcestr)
        rule  = '-' * (4 + len(title))

        print(rule)
        print('| {} |'.format(title))
        print(rule)
        print()

        names       = labelNames(atlas, labels)
        singleLabel = stype in ('coordinate', 'voxel') and \
            isinstance(atlas, fslatlases.LabelAtlas)

        if singleLabel: printSingleLabel(labels, names)
        else:           printProportions(labels, props, names)

        print()


def ohi(namespace):
    """Emulates the FSL ``atlasquery`` tool.

    Depending on the contents of ``namespace``, one of the following is
    performed, with the results printed to standard output:

      - ``namespace.dumpatlases``: the names of all atlases are printed.

      - ``namespace.ohiMask``: the atlas is queried with a mask image.

      - ``namespace.coord``: the atlas is queried at a world coordinate,
        given as a ``'X,Y,Z'`` string.

    The atlas is selected by an exact match of ``namespace.atlas`` against
    the atlas names. A message is printed, and nothing else is done, if
    the atlas name is not valid, or if neither a mask nor a coordinate was
    provided.
    """

    atlasDesc = None

    def dumpatlases():
        atlases = [a.name for a in fslatlases.listAtlases()]
        print('\n'.join(sorted(atlases)))

    if namespace.dumpatlases:
        dumpatlases()
        return

    for a in fslatlases.listAtlases():
        if a.name == namespace.atlas:
            atlasDesc = a
            break

    if atlasDesc is None:
        print('Invalid atlas name. Try one of:')
        dumpatlases()
        return

    # atlasquery always uses 2mm atlas
    # versions when a 2mm is available
    reses = [p[0] for p in atlasDesc.pixdims]

    if 2 in reses: res = 2
    else:          res = max(reses)

    # Mask query.
    if namespace.ohiMask is not None:

        mask          = fslimage.Image(namespace.ohiMask)
        labels, props = maskQuery(atlasDesc, [mask], resolution=res)
        labels        = labels[0]
        props         = props[ 0]

        for lbl, prop in zip(labels, props):

            if atlasDesc.atlasType == 'probabilistic':
                lbl = atlasDesc.labels[int(lbl)].name
            elif atlasDesc.atlasType == 'label':
                lbl = atlasDesc.find(value=int(lbl)).name
            print('{}:{:0.4f}'.format(lbl, prop))

    # Coordinate query
    else:

        # Neither a mask nor a coordinate was
        # given, so there is nothing to query.
        if namespace.coord is None:
            print('No query specified. Provide either '
                  'a coordinate (-c) or a mask (-m).')
            return

        # The coordinate may have been wrapped in
        # double quotes by parseArgs, so that
        # argparse accepts negative values.
        coord         = namespace.coord.strip('"')
        coord         = [float(c) for c in coord.split(',')]
        labels, props = coordQuery(atlasDesc,
                                   [coord],
                                   False,
                                   resolution=res)

        labels = labels[0]
        props  = props[ 0]

        if atlasDesc.atlasType == 'label':

            labels = labels[0]

            if labels is None: label = 'Unclassified'
            else:              label = atlasDesc.find(value=int(labels)).name
            print('<b>{}</b><br>{}'.format(atlasDesc.name, label))

        elif atlasDesc.atlasType == 'probabilistic':

            labelStrs = []

            # Report from highest to lowest proportion
            if len(labels) > 0:
                props, labels = zip(*reversed(sorted(zip(props, labels))))

            for label, prop in zip(labels, props):
                label = atlasDesc.labels[int(label)].name
                labelStrs.append('{:d}% {}'.format(int(round(prop)), label))

            if len(labelStrs) == 0: labels = 'No label found!'
            else:                   labels = ', '.join(labelStrs)

            print('<b>{}</b><br>{}'.format(atlasDesc.name, labels))


def atlasOrDesc(aord, *args, **kwargs):
    """Return an ``Atlas`` for ``aord``.

    If ``aord`` is already an ``Atlas``, it is returned unchanged.
    Otherwise it is assumed to be an ``AtlasDescription``, and the atlas
    with its ``atlasID`` is loaded and returned. All other arguments are
    passed through to ``loadAtlas``.
    """
    if not isinstance(aord, fslatlases.Atlas):
        aord = fslatlases.loadAtlas(aord.atlasID, *args, **kwargs)

    return aord


def labelNames(atlas, labels):
    """Return a list containing the region name for each of the given
    ``labels``.

    Each label is used as an index into ``atlas.desc.labels``. A label of
    ``None`` is given the name ``'No label'``.
    """
    regions = atlas.desc.labels

    return ['No label' if lbl is None else regions[int(lbl)].name
            for lbl in labels]


def maskQuery(atlas, masks, *args, **kwargs):
    """Queries the ``atlas`` with each of the given ``masks``.

    ``atlas`` may be an ``Atlas`` or an ``AtlasDescription`` - in the latter
    case, any additional arguments are used when the atlas is loaded.

    Returns two lists, each containing one entry per mask - the labels
    found within the mask, and the corresponding proportions.
    """

    atlas               = atlasOrDesc(atlas, *args, **kwargs)
    allLabels, allProps = [], []

    for mask in masks:

        if isinstance(atlas, fslatlases.LabelAtlas):

            labels, props = atlas.maskLabel(mask)

            # For probabilistic atlases, summary
            # image label values are one greater
            # than the label index.
            if atlas.desc.atlasType == 'probabilistic':
                labels = [lbl - 1 for lbl in labels]

        elif isinstance(atlas, fslatlases.ProbabilisticAtlas):

            # Only keep regions with a value above zero
            hits   = [(i, val)
                      for i, val in enumerate(atlas.maskValues(mask))
                      if val > 0]
            labels = [atlas.desc.labels[i].index for (i, _) in hits]
            props  = [val for (_, val) in hits]

        allLabels.append(labels)
        allProps .append(props)

    return allLabels, allProps


def coordQuery(atlas, coords, voxel, *args, **kwargs):
    """Queries the ``atlas`` at each of the given ``coords``.

    ``atlas`` may be an ``Atlas`` or an ``AtlasDescription`` - in the latter
    case, any additional arguments are used when the atlas is loaded. The
    ``voxel`` flag is passed through to the atlas lookup methods.

    Returns two lists, each containing one entry per coordinate - the
    labels found at the coordinate, and the corresponding proportions
    (``[None]`` for label atlases).
    """

    atlas               = atlasOrDesc(atlas, *args, **kwargs)
    allLabels, allProps = [], []

    for coord in coords:

        if isinstance(atlas, fslatlases.ProbabilisticAtlas):

            # Only keep regions with a non-zero value
            values = atlas.values(coord, voxel=voxel)
            hits   = [(i, val) for i, val in enumerate(values) if val != 0]

            allLabels.append([atlas.desc.labels[i].index for (i, _) in hits])
            allProps .append([val for (_, val) in hits])

        elif isinstance(atlas, fslatlases.LabelAtlas):

            label = atlas.label(coord, voxel=voxel)

            # For probabilistic summary images, a
            # value of 0 is background, and every
            # other value is one greater than the
            # label index.
            if atlas.desc.atlasType == 'probabilistic' and label is not None:
                if label == 0: label = None
                else:          label = label - 1

            allLabels.append([label])
            allProps .append([None])

    return allLabels, allProps


def identifyAtlas(idOrName):
    """Find the atlas which is uniquely matched by the given full or partial
    atlas ID or name. Matching is case-insensitive.

    An ``IdentifyError`` is raised if no atlas matches, or if more than one
    atlas matches.
    """

    # TODO Use difflib or some fuzzy matching library?

    query   = idOrName.lower().strip()
    atlases = fslatlases.listAtlases()
    names   = [a.name   .lower() for a in atlases]
    ids     = [a.atlasID.lower() for a in atlases]

    # An exact match takes precedence
    exactNames = [i for i, name in enumerate(names) if query == name]
    exactIDs   = [i for i, aid  in enumerate(ids)   if query == aid]

    if len(exactNames) + len(exactIDs) == 1:
        if len(exactNames) == 1: return atlases[exactNames[0]]
        else:                    return atlases[exactIDs[  0]]

    # Otherwise fall back to sub-string matching
    partNames = [i for i, name in enumerate(names) if query in name]
    partIDs   = [i for i, aid  in enumerate(ids)   if query in aid]
    total     = len(partNames) + len(partIDs)

    if total == 0:
        raise IdentifyError('Could not find any atlas '
                            'matching {}'.format(query))

    # Two matches are acceptable only if they
    # are the name and ID of the same atlas
    sameAtlas = (total == 2) and (partNames == partIDs)

    if total > 1 and not sameAtlas:

        candidates = [names[i] for i in partNames] + \
                     [ids[  i] for i in partIDs]

        raise IdentifyError('{} matched multiple atlases! Could match one '
                            'of: {}'.format(query, ', '.join(candidates)))

    # Either a single name or ID matched, or
    # the name and ID of one atlas both matched
    if len(partNames) == 1: return atlases[partNames[0]]
    else:                   return atlases[partIDs[  0]]


def printColumns(columns, titles=None, delim=' | ', sep=True, strip=False):
    """Convenience function which pretty-prints a collection of columns in a
    tabular format.

    Every column is left-aligned, and padded to the width of its longest
    entry (including its title, if one is given). Nothing is printed if
    ``columns`` is empty. All columns are expected to have as many entries
    as the first one.

    :arg columns: A sequence of columns, where each column is a sequence of
                  values. The values are converted to strings.

    :arg titles:  A sequence of titles (strings), one for each column.

    :arg delim:   String placed between adjacent columns.

    :arg sep:     If ``True``, and ``titles`` are provided, a row of dashes
                  is printed between the titles and the column contents.
                  If ``False``, the titles are still printed as the first
                  row, but without the dashes.

    :arg strip:   If ``True``, leading and trailing white space is removed
                  from every row of column contents.
    """

    if len(columns) == 0:
        return

    # Convert every value to a string, and put
    # the title at the top of each column, so
    # that it contributes to the column width.
    columns = [list(map(str, col)) for col in columns]

    if titles is not None:
        columns = [[titles[i]] + col for i, col in enumerate(columns)]

    # An empty column has a width of 0 - the "or [0]"
    # prevents max from being given an empty sequence.
    colLens = [max([len(r) for r in col] or [0]) for col in columns]
    fmtStr  = delim.join(['{{:<{}s}}'.format(l) for l in colLens])

    if titles is not None and sep:
        titles    = [col[0]  for col in columns]
        columns   = [col[1:] for col in columns]
        separator = ['-' * l for l in colLens]

        print(fmtStr.format(*titles))
        print(fmtStr.format(*separator))

    nrows = len(columns[0])
    for i in range(nrows):

        row = fmtStr.format(*[col[i] for col in columns])
        if strip:
            row = row.strip()

        print(row)


def parseArgs(args):
    """Parses command line arguments, returning an ``argparse.Namespace``
    object.

    :arg args: Sequence of command line arguments, not including the
               program name. The sequence is copied, and is never
               modified in place.

    :returns:  The ``argparse.Namespace`` containing the parsed arguments.
               The selected command is stored in its ``command``
               attribute. For the ``query`` command, the ``mask``,
               ``coord`` and ``voxel`` attributes, and the corresponding
               ``mask_order``, ``coord_order`` and ``voxel_order``
               attributes, are always lists.
    """

    # Work on a copy, as the arguments may be
    # modified below, and the sequence owned by
    # the caller should be left untouched.
    args = list(args)

    # Show help if no args are provided
    if  len(args) == 0 or \
       (len(args) == 1 and args[0] in ('ohi', 'summary', 'query')):
        args = args + ['-h']

    # Hack to make argparse accept
    # coordinates with a negative sign
    # (ohi/atlasquery interface only).
    # The value which follows a -c/--coord
    # flag is wrapped in double quotes,
    # which the ohi function strips off
    # again. A flag with nothing after it
    # is left for argparse to complain about.
    if args[0] == 'ohi':
        for i in range(len(args) - 1):
            if args[i] in ('-c', '--coord'):
                args[i + 1] = '"{}"'.format(args[i + 1])

    prolog = 'FSL atlasq {}'.format(fslversion.__version__)

    helps = {
        'ohi'     : 'Emulate the FSL atlasquery tool',
        'list'    : 'List available atlases',
        'summary' : 'Print a summary of one atlas',
        'query'   : 'Query an atlas at specific coordinates'
    }

    usages = {
        'main'    : 'usage: atlasq [-h] command [options]',
        'ohi'     : textwrap.dedent("""
                     usage: atlasq ohi -h
                            atlasq ohi --dumpatlases
                            atlasq ohi -a atlas -c X,Y,Z
                            atlasq ohi -a atlas -m mask
        """).strip(),
        'list'    : 'usage: atlasq list [-e]',
        'summary' : 'usage: atlasq summary atlas',
        'query' : textwrap.dedent("""
                     usage: atlasq query atlas [options] -m mask  [-m mask ...]
                     usage: atlasq query atlas [options] -c X Y Z [-c X Y Z...]
                     usage: atlasq query atlas [options] -v X Y Z [-v X Y Z...]
                     usage: atlasq query atlas [options] -v X Y Z \\
                                                        [-c X Y Z [-m mask]...]
        """).strip(),
    }

    # Every usage message is preceded
    # by the program name and version
    for k in usages:
        usages[k] = '{}\n\n{}'.format(prolog, usages[k])

    parser = argparse.ArgumentParser(
        prog='atlasq',
        usage=usages['main'],
        formatter_class=HelpFormatter)

    subParsers = parser.add_subparsers(title='Commands', dest='command')
    ohiParser = subParsers.add_parser(
        'ohi',
        help=helps['ohi'],
        usage=usages['ohi'],
        formatter_class=HelpFormatter)
    listParser = subParsers.add_parser(
        'list',
        help=helps['list'],
        usage=usages['list'],
        formatter_class=HelpFormatter)
    sumParser = subParsers.add_parser(
        'summary',
        help=helps['summary'],
        usage=usages['summary'],
        formatter_class=HelpFormatter)
    queryParser = subParsers.add_parser(
        'query',
        help=helps['query'],
        usage=usages['query'],
        formatter_class=HelpFormatter)

    # This is a custom argparse.Action used by the
    # query command parser which keeps track of the
    # order in which query arguments are passed. The
    # three query types (mask, voxel, coord) are
    # each added to separate lists. For each type, a
    # second list (called mask_order, voxel_order,
    # and coord_order) is maintained which stores
    # the index of each query across all types.
    class QueryAction(argparse.Action):

        # Number of queries seen so far, across all
        # query types. The class is re-created on
        # every parseArgs call, so the count always
        # starts from 0.
        queryCount = 0

        def __init__(self, *args, **kwargs):
            argparse.Action.__init__(self, *args, **kwargs)

        def __call__(self, parser, namespace, values, option_string=None):

            dest  = getattr(namespace,                   self.dest,  None)
            order = getattr(namespace, '{}_order'.format(self.dest), None)

            if dest  is None: dest  = []
            if order is None: order = []

            dest .append(values)
            order.append(QueryAction.queryCount)

            setattr(namespace,                   self.dest,  dest)
            setattr(namespace, '{}_order'.format(self.dest), order)

            QueryAction.queryCount += 1


    # OldHorribleInterface parser
    ohiParser.add_argument(
        '-a', '--atlas',
        help='Name of atlas to use')
    ohiParser.add_argument(
        '-V', '--verbose',
        action='store_true',
        help='Switch on diagnostic messages')
    ohiSubParser = ohiParser.add_mutually_exclusive_group()
    ohiSubParser.add_argument(
        '-m', '--mask',
        dest='ohiMask',
        metavar='MASK',
        help='A mask image to use during structural lookups')
    ohiSubParser.add_argument(
        '-c', '--coord',
        help='Coordinate to query')
    ohiParser.add_argument(
        '--dumpatlases',
        action='store_true',
        help='Dump a list of the available atlases')

    # List parser
    listParser.add_argument(
        '-e', '--extended',
        action='store_true',
        help='Print more information about each atlas')

    # Summary parser
    sumParser.add_argument('atlas', help='Name or ID of atlas to summarise')

    # Query parser
    queryParser.add_argument(
        'atlas',
        help='Name or ID of atlas to summarise.')
    queryParser.add_argument(
        '-r', '--resolution',
        type=float,
        help='Desired atlas resolution (mm). Default is highest available '
             'resolution.')
    queryParser.add_argument(
        '-s', '--short',
        action='store_true',
        help='Output in short (machine-friendly) format.')
    queryParser.add_argument(
        '-l', '--label',
        action='store_true',
        help='Query label/maxprob version of atlas (for probabilistic '
             'atlases).')
    queryParser.add_argument(
        '-m', '--mask',
        action=QueryAction,
        help='Mask to query with.')
    queryParser.add_argument(
        '-c', '--coord',
        nargs=3,
        type=float,
        metavar=('X', 'Y', 'Z'),
        action=QueryAction,
        help='World coordinates to look up.')
    queryParser.add_argument(
        '-v', '--voxel',
        nargs=3,
        type=float,
        metavar=('X', 'Y', 'Z'),
        action=QueryAction,
        help='Voxel coordinates to look up. Must be in terms of the atlas '
        'at the specified (or default) --resolution.')

    namespace = parser.parse_args(args)

    if namespace.command != 'query':
        return namespace

    # Make life easier for the queryAtlas code,
    # by making sure that every query list, and
    # every query order list, is present.
    if namespace.mask  is None:               namespace.mask        = []
    if namespace.coord is None:               namespace.coord       = []
    if namespace.voxel is None:               namespace.voxel       = []
    if not hasattr(namespace, 'mask_order'):  namespace.mask_order  = []
    if not hasattr(namespace, 'coord_order'): namespace.coord_order = []
    if not hasattr(namespace, 'voxel_order'): namespace.voxel_order = []

    return namespace


def main(args=None):
    """Entry point for ``atlasq``. Parses arguments, and runs the requested
    command.

    :arg args: Command line arguments - ``sys.argv[1:]`` is used if not
               provided.

    :returns:  ``1`` if the command failed with an ``IdentifyError`` or a
               ``MaskError`` (the error message is printed), ``0``
               otherwise.
    """

    argv      = sys.argv[1:] if args is None else args
    namespace = parseArgs(argv)

    # The atlas library must be
    # initialised before use
    fslatlases.rescanAtlases()

    # Look up the function which
    # implements the requested command
    commands = {
        'list'    : listAtlases,
        'query'   : queryAtlas,
        'summary' : summariseAtlas,
        'ohi'     : ohi,
    }
    handler = commands.get(namespace.command)

    try:
        if handler is not None:
            handler(namespace)

    except (IdentifyError, fslatlases.MaskError) as e:
        print(str(e))
        return 1

    return 0


def atlasquery_emulation(args=None):
    """Entry point for ``atlasquery``. Calls ``main`` with ``'ohi'``
    prepended to the arguments, so that ``atlasq`` runs in ``ohi`` mode.

    :arg args: Command line arguments - ``sys.argv[1:]`` is used if not
               provided.
    """

    argv = sys.argv[1:] if args is None else args

    return main(['ohi'] + argv)


# When run as a script, exit with the code returned by main.
if __name__ == '__main__':
    raise SystemExit(main())