Skip to content
Snippets Groups Projects
featimage.py 8.51 KiB
Newer Older
#!/usr/bin/env python
#
# featimage.py - An Image subclass which has some FEAT-specific functionality.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module provides the :class:`FeatImage` class, a subclass of
:class:`.Image` designed for the ``filtered_func_data`` file of a FEAT
analysis.
"""

import glob
import os.path as op

import numpy   as np
import nibabel as nib

import image   as fslimage


def loadDesignMat(designmat):
    """Loads a FEAT ``design.mat`` file. Returns a ``numpy`` array
    containing the design matrix data, where the first dimension
    corresponds to the data points, and the second to the EVs.

    The returned array is always two-dimensional, even when the design
    contains a single EV or a single data point.

    :arg designmat: Path to the ``design.mat`` file.

    :raises RuntimeError: If the file does not contain a ``/Matrix``
                          line, or if no data follows that line.
    """

    matrix = None

    with open(designmat, 'rt') as f:

        # Skip over the header - the numeric data starts on the line
        # after '/Matrix'. readline returns '' at the end of the file,
        # which terminates the loop if no '/Matrix' line is present.
        for line in iter(f.readline, ''):
            if line.strip() == '/Matrix':

                # ndmin=2 stops a single-column or single-row
                # matrix from being squeezed down to one dimension.
                matrix = np.loadtxt(f, ndmin=2)
                break

    if matrix is None or matrix.size == 0:
        raise RuntimeError('{} does not appear to be a '
                           'valid design.mat file'.format(designmat))

    return matrix


def loadDesignCon(designcon):
    """Loads a FEAT ``design.con`` file. Returns a tuple containing:

      - A list of contrast names, one per contrast, in contrast order.
        A contrast which is not named in the file is given its
        (1-indexed) number, as a string, for a name.

      - A list of contrast vectors (each of which is a list itself).

    :arg designcon: Path to the ``design.con`` file.

    :raises RuntimeError: If the file does not contain a ``/Matrix``
                          line, or if the number of rows following it
                          does not match the ``/NumContrasts`` value.
    """

    matrix       = None
    numContrasts = 0
    names        = {}

    with open(designcon, 'rt') as f:

        # readline returns '' at the end of the file, which ends
        # the loop if the file does not contain a '/Matrix' line.
        for line in iter(f.readline, ''):

            line = line.strip()

            if line.startswith('/ContrastName'):
                tkns = line.split(None, 1)

                # The contrast number is embedded in
                # the keyword, e.g. '/ContrastName3'.
                num = ''.join([c for c in tkns[0] if c.isdigit()])

                # Ignore entries with no number or no name -
                # a default name is filled in for them below.
                if num != '' and len(tkns) == 2:
                    names[int(num)] = tkns[1].strip()

            elif line.startswith('/NumContrasts'):
                numContrasts = int(line.split()[1])

            elif line == '/Matrix':

                # ndmin=2 keeps one row per contrast, even
                # when the file only contains one contrast.
                matrix = np.loadtxt(f, ndmin=2)
                break

    if matrix is None or numContrasts != matrix.shape[0]:
        raise RuntimeError('{} does not appear to be a '
                           'valid design.con file'.format(designcon))

    # Fill in any missing contrast names
    names     = [names.get(c + 1, str(c + 1)) for c in range(numContrasts)]
    contrasts = [list(row) for row in matrix]

    return names, contrasts


def loadDesignFsf(designfsf):
    """Loads the settings stored in a FEAT ``design.fsf`` file. Returns a
    dictionary of ``{name : value}`` mappings, containing one entry for
    every ``set name value`` line in the file.

    All values are returned as strings, with surrounding whitespace and
    quote characters removed. Lines which do not begin with ``set``, and
    ``set`` lines which do not have a value, are ignored.

    :arg designfsf: Path to the ``design.fsf`` file.
    """

    settings = {}

    with open(designfsf, 'rt') as f:

        for line in f:
            line = line.strip()

            if not line.startswith('set '):
                continue

            # Split into 'set', name, and value - the value
            # may contain spaces, so is kept as a single token
            tkns = line.split(None, 2)

            if len(tkns) < 3:
                continue

            key = tkns[1].strip()
            val = tkns[2].strip().strip("'").strip('"')

            settings[key] = val

    return settings


def isFEATData(path):
    """Returns ``True`` if the given path looks like the
    ``filtered_func_data`` file of a FEAT analysis, ``False`` otherwise.

    The path is accepted if it refers to a ``filtered_func_data`` file
    which is located directly within a ``.feat`` or ``.gfeat`` directory,
    and if that directory also contains ``design.fsf``, ``design.mat``
    and ``design.con`` files.

    :arg path: Path to an image file.
    """

    keys = ['.feat{}filtered_func_data' .format(op.sep),
            '.gfeat{}filtered_func_data'.format(op.sep)]

    isfeatdir = any([k in path for k in keys])

    dirname   = op.dirname(path)
    hasdesfsf = op.exists(op.join(dirname, 'design.fsf'))
    hasdesmat = op.exists(op.join(dirname, 'design.mat'))
    hasdescon = op.exists(op.join(dirname, 'design.con'))

    isfeat    = (isfeatdir and
                 hasdesmat and
                 hasdescon and
                 hasdesfsf)

    return isfeat


class FEATImage(fslimage.Image):
    """An ``Image`` which is the ``filtered_func_data`` file of a FEAT
    analysis.

    On creation, the ``design.fsf``, ``design.mat`` and ``design.con``
    files are loaded from the directory which contains the image file.
    The parameter estimate (PE), contrast of parameter estimate (COPE)
    and residual images are loaded from the ``stats`` sub-directory the
    first time that they are requested, and are then kept in memory.
    """


    def __init__(self, image, **kwargs):
        """Create a ``FEATImage``.

        :arg image: Passed through to ``Image.__init__``.

        All other arguments are passed through to ``Image.__init__``. If
        a ``name`` is not specified, the image name is prefixed with the
        name of the FEAT analysis.

        :raises ValueError: If the image data source does not look like
                            FEAT data (see :func:`isFEATData`).
        """
        fslimage.Image.__init__(self, image, **kwargs)

        # NOTE(review): dataSource is presumably set by
        # Image.__init__ to the image file path - confirm.
        if not isFEATData(self.dataSource):
            raise ValueError('{} does not appear to be data from a '
                             'FEAT analysis'.format(self.dataSource))

        featDir     = op.dirname(self.dataSource)
        settings    = loadDesignFsf(op.join(featDir, 'design.fsf'))
        design      = loadDesignMat(op.join(featDir, 'design.mat'))
        names, cons = loadDesignCon(op.join(featDir, 'design.con'))

        self.__analysisName  = op.splitext(op.basename(featDir))[0]
        self.__featDir       = featDir
        self.__design        = design
        self.__contrastNames = names
        self.__contrasts     = cons
        self.__settings      = settings
        self.__evNames       = self.__getEVNames()

        # Caches for the stats images,
        # which are loaded on demand
        self.__residuals     = None
        self.__pes           = [None] * self.numEVs()
        self.__copes         = [None] * self.numContrasts()

        if 'name' not in kwargs:
            self.name = '{}.feat: {}'.format(
                self.__analysisName, self.name)


    def __getEVNames(self):
        """Builds a list containing a name for every EV (every column of
        the design matrix), from the ``fmri(evtitleN)`` and
        ``fmri(deriv_ynN)`` settings in ``design.fsf``.

        An EV for which ``fmri(deriv_ynN)`` is not ``'0'`` contributes
        two names - its title, and its title with a
        ``' - temporal derivative'`` suffix.

        :raises RuntimeError: If the ``evtitle`` and ``deriv_yn``
                              settings do not refer to the same EVs, or
                              if the number of names does not match the
                              number of columns in the design matrix.
        """

        titlePrefix = 'fmri(evtitle'
        derivPrefix = 'fmri(deriv_yn'

        # Index the settings by EV number, which sits between
        # the prefix and the closing bracket of each key.
        titles = {}
        derivs = {}

        for key, val in self.__settings.items():
            if key.startswith(titlePrefix):
                titles[int(key[len(titlePrefix):-1])] = val
            elif key.startswith(derivPrefix):
                derivs[int(key[len(derivPrefix):-1])] = val

        # Sanity check - every evtitle
        # has a matching deriv_yn
        if set(titles) != set(derivs):
            raise RuntimeError('design.fsf seem to be corrupt')

        evnames = []

        # Go through the EVs in numeric order, rather
        # than relying on the ordering of the dictionary
        for evnum in sorted(titles):

            title = titles[evnum]

            evnames.append(title)

            if derivs[evnum] != '0':
                evnames.append('{} - {}'.format(title, 'temporal derivative'))

        if len(evnames) != self.numEVs():
            raise RuntimeError('The number of EVs in design.fsf does not '
                               'match the number of EVs in design.mat')

        return evnames


    def getAnalysisName(self):
        """Returns the name of the FEAT directory, without its suffix. """
        return self.__analysisName


    def getDesign(self):
        """Returns a copy of the design matrix, a ``numpy`` array with
        one row per data point and one column per EV.
        """
        return np.array(self.__design)


    def numPoints(self):
        """Returns the number of rows (data points) in the design matrix.
        """
        return self.__design.shape[0]


    def numEVs(self):
        """Returns the number of columns (EVs) in the design matrix. """
        return self.__design.shape[1]


    def evNames(self):
        """Returns a list containing the name of each EV. """
        return list(self.__evNames)


    def numContrasts(self):
        """Returns the number of contrasts in ``design.con``. """
        return len(self.__contrasts)


    def contrastNames(self):
        """Returns a list containing the name of each contrast. """
        return list(self.__contrastNames)


    def contrasts(self):
        """Returns a copy of the contrast vectors, as a list of lists. """
        return [list(c) for c in self.__contrasts]


    def __getStatsFile(self, prefix, ev=None):
        """Returns the path to a file in the ``stats`` sub-directory of
        the FEAT directory.

        :arg prefix: File name prefix, e.g. ``'pe'`` or ``'res4d'``.

        :arg ev:     If not ``None``, a 0-indexed number; ``ev + 1`` is
                     appended to the prefix.

        :raises IndexError: If no matching file exists.
        """

        if ev is not None: prefix = '{}{}'.format(prefix, ev + 1)

        prefix = op.join(self.__featDir, 'stats', prefix)

        # The file may have any extension. Sort the matches
        # so that the choice is deterministic if there are
        # several.
        matches = sorted(glob.glob('{}.*'.format(prefix)))

        if len(matches) == 0:
            raise IndexError('No file matching {}.* exists'.format(prefix))

        return matches[0]


    def getPE(self, ev):
        """Returns the parameter estimate data for the given (0-indexed)
        EV, loading it from the ``stats`` directory on the first call.
        """

        if self.__pes[ev] is None:
            pefile = self.__getStatsFile('pe', ev)
            self.__pes[ev] = nib.load(pefile).get_data()

        return self.__pes[ev]


    def getResiduals(self):
        """Returns the ``res4d`` (residuals) data, loading it from the
        ``stats`` directory on the first call.
        """

        if self.__residuals is None:
            resfile          = self.__getStatsFile('res4d')
            self.__residuals = nib.load(resfile).get_data()

        return self.__residuals


    def getCOPE(self, num):
        """Returns the COPE data for the given (0-indexed) contrast,
        loading it from the ``stats`` directory on the first call.
        """

        if self.__copes[num] is None:
            copefile = self.__getStatsFile('cope', num)
            self.__copes[num] = nib.load(copefile).get_data()

        return self.__copes[num]


    def fit(self, contrast, xyz, fullmodel=False):
        """Calculates the model fit for the given contrast vector at the
        given voxel.

        Passing in a contrast of all 1s, and ``fullmodel=True`` will
        get you the full model fit. Pass in ``fullmodel=False`` for
        all other contrasts, otherwise the model fit values will not
        be scaled correctly.

        :arg contrast:  Sequence containing one weight for every EV.

        :arg xyz:       Sequence of three voxel coordinates.

        :arg fullmodel: If ``False`` (the default), the contrast vector
                        is normalised to unit length before being used.

        :returns: A ``numpy`` array containing one value for each time
                  point of the image data at the voxel.

        :raises ValueError: If the contrast does not contain one weight
                            for every EV.
        """

        x, y, z = xyz
        numEVs  = self.numEVs()

        if len(contrast) != numEVs:
            raise ValueError('Contrast is wrong length')

        if not fullmodel:

            # Work on a floating point copy, so that integer
            # contrasts can be scaled, and so that the caller's
            # contrast vector is not modified.
            contrast = np.array(contrast, dtype=float)
            norm     = np.sqrt((contrast ** 2).sum())

            # An all-zero contrast cannot be
            # normalised, and is left as it is.
            if norm > 0:
                contrast = contrast / norm

        X        = self.__design
        data     = self.data[x, y, z, :]
        modelfit = np.zeros(len(data))

        # NOTE(review): the end of this method had been lost from the
        # file. The weighted sum below, and the addition of the data
        # mean to the result, are a reconstruction - confirm against
        # the project history.
        for i in range(numEVs):

            pe        = self.getPE(i)[x, y, z]
            modelfit += X[:, i] * pe * contrast[i]

        return modelfit + data.mean()


    def reducedData(self, xyz, contrast, fullmodel=False):
        """Returns the sum of the residuals and the model fit for the
        given contrast (see :meth:`fit`), at the given voxel.

        Passing in a contrast of all 1s, and ``fullmodel=True`` will
        get you the model fit residuals.

        :arg xyz:       Sequence of three voxel coordinates.

        :arg contrast:  Sequence containing one weight for every EV.

        :arg fullmodel: Passed through to :meth:`fit`.
        """

        x, y, z   = xyz
        residuals = self.getResiduals()[x, y, z, :]
        modelfit  = self.fit(contrast, xyz, fullmodel)

        return residuals + modelfit