From 88e8e87c58bfbe562bdccd7a8f46789dc0b119ba Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauld.mccarthy@gmail.com>
Date: Tue, 7 Jul 2015 16:42:02 +0100
Subject: [PATCH] FEAT model fit plotting sorta working, but I need to think
 about the design a bit. Partial model fits are not being calculated
 correctly.

---
 fsl/data/featimage.py                         | 200 ++++++++++++++++--
 fsl/data/strings.py                           |   5 +-
 .../controls/timeseriescontrolpanel.py        |  23 +-
 fsl/fslview/views/timeseriespanel.py          | 180 ++++++++++++++--
 4 files changed, 372 insertions(+), 36 deletions(-)

diff --git a/fsl/data/featimage.py b/fsl/data/featimage.py
index b8abe038b..19d267a57 100644
--- a/fsl/data/featimage.py
+++ b/fsl/data/featimage.py
@@ -10,15 +10,126 @@ analysis.
 """
 
 import os.path as op
+import            glob
 
-import image as fslimage
+import numpy   as np
+
+import nibabel as nib
+
+import image   as fslimage
+
+
+def loadDesignMat(designmat):
+    """Loads a FEAT ``design.mat`` file. Returns a ``numpy`` array
+    containing the design matrix data, where the first dimension
+    corresponds to the data points, and the second to the EVs.
+    """
+
+    matrix = None
+    with open(designmat, 'rt') as f:
+
+        while True:
+            line = f.readline()
+            if line.strip() == '/Matrix':
+                break
+
+        matrix = np.loadtxt(f)
+
+    if matrix is None or matrix.size == 0:
+        raise RuntimeError('{} does not appear to be a '
+                           'valid design.mat file'.format(designmat))
+
+    return matrix
+
+
+def loadDesignCon(designcon):
+    """Loads a FEAT ``design.con`` file. Returns a tuple containing:
+    
+      - A dictionary of ``{contrastnum : name}`` mappings
+    
+      - A list of contrast vectors (each of which is a list itself).
+    """
+
+    matrix       = None
+    numContrasts = 0
+    names        = {}
+    with open(designcon, 'rt') as f:
+
+        while True:
+            line = f.readline().strip()
+
+            if line.startswith('/ContrastName'):
+                tkns       = line.split(None, 1)
+                num        = [c for c in tkns[0] if c.isdigit()]
+                num        = int(''.join(num))
+                name       = tkns[1].strip()
+                names[num] = name
+
+            elif line.startswith('/NumContrasts'):
+                numContrasts = int(line.split()[1])
+
+            elif line == '/Matrix':
+                break
+
+        matrix = np.loadtxt(f)
+
+    if matrix       is None             or \
+       numContrasts != matrix.shape[0]:
+        raise RuntimeError('{} does not appear to be a '
+                           'valid design.con file'.format(designcon))
+
+    # Fill in any missing contrast names
+    if len(names) != numContrasts:
+        for i in range(numContrasts):
+            if i + 1 not in names:
+                names[i + 1] = str(i + 1)
+
+    names     = [names[c + 1] for c in range(numContrasts)]
+    contrasts = []
+
+    for row in matrix:
+        contrasts.append(list(row))
+
+    return names, contrasts
+
+
+def loadDesignFsf(designfsf):
+    """
+    """
+
+    settings = {}
+
+    with open(designfsf, 'rt') as f:
+
+        for line in f.readlines():
+            line = line.strip()
+
+            if not line.startswith('set '):
+                continue
+
+            tkns = line.split(None, 2)
+
+            settings[tkns[1].strip()] = tkns[2]
+    
+    return settings
 
 
 def isFEATData(path):
+    
     keys = ['.feat{}filtered_func_data' .format(op.sep),
             '.gfeat{}filtered_func_data'.format(op.sep)]
 
-    isfeat = any([k in path for k in keys])
+    isfeatdir = any([k in path for k in keys])
+
+    dirname   = op.dirname(path)
+    hasdesfsf = op.exists(op.join(dirname, 'design.fsf'))
+    hasdesmat = op.exists(op.join(dirname, 'design.mat'))
+    hasdescon = op.exists(op.join(dirname, 'design.con'))
+
+    isfeat    = (isfeatdir and
+                 hasdesmat and
+                 hasdescon and
+                 hasdesfsf)
     
     return isfeat
 
@@ -32,36 +143,87 @@ class FEATImage(fslimage.Image):
             raise ValueError('{} does not appear to be data from a '
                              'FEAT analysis'.format(self.dataSource))
 
+        featDir     = op.dirname(self.dataSource)
+        settings    = loadDesignFsf(op.join(featDir, 'design.fsf'))
+        design      = loadDesignMat(op.join(featDir, 'design.mat'))
+        names, cons = loadDesignCon(op.join(featDir, 'design.con'))
 
-    # A FEATImage is an Image which has 
-    # some extra utility methods, something 
-    # like all of the below things:
+        self.__featDir       = featDir
+        self.__design        = design
+        self.__contrastNames = names
+        self.__contrasts     = cons
+        self.__settings      = settings
 
-    # def numParameterEstimates(self):
-    #     return 0
+        self.__pes           = [None] * self.numEVs()
+        self.__copes         = [None] * self.numContrasts()
+
+        
+
+    def getDesign(self):
+        return np.array(self.__design)
+        
     
-    # def numCOPEs(self):
-    #     return 0
+    def numPoints(self):
+        return self.__design.shape[0] 
 
     
-    # def getParameterEstimate(self, num):
-    #     pass
+    def numEVs(self):
+        return self.__design.shape[1]
 
     
-    # def getFullModelFit(self):
-    #     pass
+    def numContrasts(self):
+        return len(self.__contrasts)
 
     
-    # def getPartialModelFIt(self, contrast):
-    #     pass
+    def contrastNames(self):
+        return list(self.__contrastNames)
+
+
+    def contrasts(self):
+        return [list(c) for c in self.__contrasts]
+
+
+    def __getPEFile(self, prefix, ev):
+        prefix = op.join(self.__featDir, 'stats', '{}{}'.format(
+            prefix, ev + 1))
+        return glob.glob('{}.*'.format(prefix))[0]
+
+
+    def getPE(self, ev):
+
+        if self.__pes[ev] is None:
+            pefile = self.__getPEFile('pe', ev)
+            self.__pes[ev] = nib.load(pefile).get_data()
+
+        return self.__pes[ev]
 
     
-    # def getCOPEs(self):
-    #     pass
+    def getCOPE(self, num):
+        if self.__copes[num] is None:
+            copefile = self.__getPEFile('cope', num)
+            self.__copes[num] = nib.load(copefile).get_data()
 
+        return self.__copes[num] 
+        
 
-    # def getZStats(self):
-    #     pass
+    def fit(self, contrast, xyz):
+
+        x, y, z = xyz
+        numEVs  = self.numEVs()
+
+        if len(contrast) != numEVs:
+            raise ValueError('Contrast is wrong length')
+
+        X        = self.__design
+        data     = self.data[x, y, z, :]
+        modelfit = np.zeros(len(data))
+
+        for i in range(numEVs):
+
+            pe        = self.getPE(i)[x, y, z]
+            modelfit += np.dot(X[:, i], pe) * contrast[i]
+
+        return modelfit + data.mean()
 
     
     # def getThresholdedZStats(self):
diff --git a/fsl/data/strings.py b/fsl/data/strings.py
index 1478a86f3..b19a417cd 100644
--- a/fsl/data/strings.py
+++ b/fsl/data/strings.py
@@ -213,7 +213,8 @@ labels = TypeDict({
     'PlotPanel.xlabel'          : 'X',
     'PlotPanel.ylabel'          : 'Y',
 
-    'TimeSeriesControlPanel.currentFEATSettings' : 'FEAT settings for {}',
+    'TimeSeriesControlPanel.currentFEATSettings' : 'FEAT settings for '
+                                                   'selected overlay ({})',
     
 })
 
@@ -276,6 +277,8 @@ properties = TypeDict({
     'HistogramSeries.showOverlay'     : 'Show 3D histogram overlay',
 
     'FEATTimeSeries.plotFullModelFit' : 'Plot full model fit',
+    'FEATTimeSeries.plotPEFits'       : 'Plot {} fit',
+    'FEATTimeSeries.plotCOPEFits'     : 'Plot COPE {} ({}) fit',
 
     'OrthoEditProfile.selectionSize'          : 'Selection size',
     'OrthoEditProfile.selectionIs3D'          : '3D selection',
diff --git a/fsl/fslview/controls/timeseriescontrolpanel.py b/fsl/fslview/controls/timeseriescontrolpanel.py
index 1a7eb77d8..befaf57ae 100644
--- a/fsl/fslview/controls/timeseriescontrolpanel.py
+++ b/fsl/fslview/controls/timeseriescontrolpanel.py
@@ -153,8 +153,27 @@ class TimeSeriesControlPanel(fslpanel.FSLViewPanel):
             displayName=strings.labels[self, 'currentFEATSettings'].format(
                 display.name))
 
-        widg = props.makeWidget(self.__widgets, ts, 'plotFullModelFit')
+        full    = props.makeWidget(     self.__widgets, ts, 'plotFullModelFit')
+        pes     = props.makeListWidgets(self.__widgets, ts, 'plotPEFits')
+        copes   = props.makeListWidgets(self.__widgets, ts, 'plotCOPEFits')
         self.__widgets.AddWidget(
-            widg,
+            full,
             displayName=strings.properties[ts, 'plotFullModelFit'],
             groupName='currentFEATSettings')
+
+        for i, pe in enumerate(pes):
+            peName = 'PE {}'.format(i + 1)
+            self.__widgets.AddWidget(
+                pe,
+                displayName=strings.properties[ts, 'plotPEFits'].format(
+                    peName),
+                groupName='currentFEATSettings') 
+
+
+        copeNames = overlay.contrastNames()
+        for i, (cope, name) in enumerate(zip(copes, copeNames)):
+            self.__widgets.AddWidget(
+                cope,
+                displayName=strings.properties[ts, 'plotCOPEFits'].format(
+                    i, name),
+                groupName='currentFEATSettings') 
diff --git a/fsl/fslview/views/timeseriespanel.py b/fsl/fslview/views/timeseriespanel.py
index 32b7e4a1b..0ea9bb14a 100644
--- a/fsl/fslview/views/timeseriespanel.py
+++ b/fsl/fslview/views/timeseriespanel.py
@@ -38,9 +38,16 @@ class TimeSeries(plotpanel.DataSeries):
         self.coords  = map(int, coords)
         self.data    = overlay.data[coords[0], coords[1], coords[2], :]
 
+        
     def update(self, coords):
-        self.coords = map(int, coords)
+        
+        coords = map(int, coords)
+        if coords == self.coords:
+            return False
+        
+        self.coords = coords
         self.data   = self.overlay.data[coords[0], coords[1], coords[2], :]
+        return True
 
         
     def getData(self):
@@ -63,19 +70,159 @@ class FEATTimeSeries(TimeSeries):
     """
 
     plotFullModelFit = props.Boolean(default=False)
-    # plotResiduals          =            props.Boolean(default=False)
-    # plotParameterEstimates = props.List(props.Boolean(default=False))
-    # plotCopes              = props.List(props.Boolean(default=False))
-
-    # Reduced against what? It has to
-    # be w.r.t. a specific PE/COPE. 
-    # plotReducedData = props.Boolean(default=False)
+    plotPEFits       = props.List(props.Boolean(default=False))
+    plotCOPEFits     = props.List(props.Boolean(default=False))
 
 
     def __init__(self, *args, **kwargs):
         TimeSeries.__init__(self, *args, **kwargs)
         self.name = '{}_{}'.format(type(self).__name__, id(self))
 
+        numEVs   = self.overlay.numEVs()
+        numCOPEs = self.overlay.numContrasts()
+
+        for i in range(numEVs):
+            self.plotPEFits.append(False)
+
+        for i in range(numCOPEs):
+            self.plotCOPEFits.append(False) 
+
+        self.__fullModelTs =  None
+        self.__peTs        = [None] * numEVs
+        self.__copeTs      = [None] * numCOPEs
+        
+        self.addListener('plotFullModelFit',
+                         self.name,
+                         self.__plotFullModelFitChanged)
+        
+        for i, plotPEFit in enumerate(
+                self.plotPEFits.getPropertyValueList()):
+
+            def onChange(ctx, value, valid, name, pe=i):
+                self.__plotPEFitChanged(pe)
+
+            plotPEFit.addListener(self.name, onChange)
+
+        
+        for i, plotCOPEFit in enumerate(
+                self.plotCOPEFits.getPropertyValueList()):
+
+            def onChange(ctx, value, valid, name, cope=i):
+                self.__plotCOPEFitChanged(cope)
+
+            plotCOPEFit.addListener(self.name, onChange)
+ 
+            
+
+    def getModelTimeSeries(self):
+        modelts = []
+
+        if self.plotFullModelFit:
+            modelts.append(self.__fullModelTs)
+
+        for i in range(self.overlay.numEVs()):
+            if self.plotPEFits[i]:
+                modelts.append(self.__peTs[i])
+
+        for i in range(self.overlay.numContrasts()):
+            if self.plotCOPEFits[i]:
+                modelts.append(self.__copeTs[i]) 
+        
+        return modelts
+
+    
+    def __plotCOPEFitChanged(self, copenum):
+        if not self.plotCOPEFits[copenum]:
+            self.__copeTs[copenum] = None
+            return
+
+        con  = self.overlay.contrasts()[copenum]
+
+        copets = FEATModelFitTimeSeries(
+            con,
+            self.tsPanel,
+            self.overlay,
+            self.coords)
+        
+        copets.colour    = (0, 1, 0)
+        copets.alpha     = self.alpha
+        copets.label     = self.label
+        copets.lineWidth = self.lineWidth
+        copets.lineStyle = self.lineStyle
+
+        self.__copeTs[copenum] = copets 
+
+
+    def __plotPEFitChanged(self, evnum):
+        if not self.plotPEFits[evnum]:
+            self.__peTs[evnum] = None
+            return
+
+        con     = [0] * self.overlay.numEVs()
+        con[evnum] = 1
+
+        pets = FEATModelFitTimeSeries(
+            con,
+            self.tsPanel,
+            self.overlay,
+            self.coords)
+        
+        pets.colour    = (1, 0, 0)
+        pets.alpha     = self.alpha
+        pets.label     = self.label
+        pets.lineWidth = self.lineWidth
+        pets.lineStyle = self.lineStyle
+
+        self.__peTs[evnum] = pets
+
+
+    def __plotFullModelFitChanged(self, *a):
+        if not self.plotFullModelFit:
+            self.__fullModelTs = None
+            return
+
+        self.__fullModelTs = FEATModelFitTimeSeries(
+            [1] * self.overlay.numEVs(),
+            self.tsPanel,
+            self.overlay,
+            self.coords)
+        self.__fullModelTs.colour    = (0, 0, 1)
+        self.__fullModelTs.alpha     = self.alpha
+        self.__fullModelTs.label     = self.label
+        self.__fullModelTs.lineWidth = self.lineWidth
+        self.__fullModelTs.lineStyle = self.lineStyle
+
+        
+    def update(self, coords):
+        if not TimeSeries.update(self, coords):
+            return False
+            
+        if self.__fullModelTs is not None:
+            self.__fullModelTs.update(coords)
+
+        return True
+
+
+class FEATModelFitTimeSeries(TimeSeries):
+    
+    def __init__(self, contrast, *args, **kwargs):
+        TimeSeries.__init__(self, *args, **kwargs)
+        self.contrast = contrast
+        self.updateModelFit()
+
+        
+    def update(self, coords):
+        if not TimeSeries.update(self, coords):
+            return
+        self.updateModelFit()
+        
+
+    def updateModelFit(self):
+        x, y, z = self.coords
+
+        self.data = self.overlay.fit(self.contrast, (x, y, z))
+ 
+
 
 class TimeSeriesPanel(plotpanel.PlotPanel):
     """A panel with a :mod:`matplotlib` canvas embedded within.
@@ -154,6 +301,9 @@ class TimeSeriesPanel(plotpanel.PlotPanel):
         prevTs      = self.__currentTs
         prevOverlay = self.__currentOverlay
 
+        if prevTs is not None:
+            prevTs.removeGlobalListener(self._name)
+
         self.__currentTs      = None
         self.__currentOverlay = None
 
@@ -201,6 +351,8 @@ class TimeSeriesPanel(plotpanel.PlotPanel):
             self.__currentTs      = ts
             self.__currentOverlay = overlay
 
+        self.__currentTs.addGlobalListener(self._name, self.draw)
+
         
     def getCurrent(self):
         return self.__currentTs
@@ -209,16 +361,16 @@ class TimeSeriesPanel(plotpanel.PlotPanel):
     def draw(self, *a):
 
         self.__calcCurrent()
-        current = self.getCurrent()
+        current = self.__currentTs
 
         if self.showCurrent and \
            current is not None:
 
-            # Turn current into a list
-            # if it is not already one
-            try:    len(current)
-            except: current = [current]
+            extras = [current]
+
+            if isinstance(current, FEATTimeSeries):
+                extras += current.getModelTimeSeries()
 
-            self.drawDataSeries(current)
+            self.drawDataSeries(extras)
         else:
             self.drawDataSeries()
-- 
GitLab