#!/usr/bin/env fslpython
"""Utility functions for use by FEEDS tests. """


import os
import os.path as op
import uuid
import tempfile
import subprocess
import collections
import collections.abc

import numpy as np


def fslbin(name):
    """Return the path to the FSL executable ``name``.

    The path is ``$FSLDIR/bin/<name>``. A ``KeyError`` is raised if the
    ``FSLDIR`` environment variable is not set.
    """
    fsldir = os.environ["FSLDIR"]
    return op.join(fsldir, "bin", name)


def imrm(imageName):
    """Remove the image ``imageName`` with FSL's ``imrm`` tool.

    Returns the exit status of the ``imrm`` process; a non-zero status is
    not turned into an exception.
    """
    return subprocess.call([fslbin("imrm"), imageName])


def readfsf(filename, keys):
    """Read the values of one or more ``set`` entries from an FSF file.

    :arg filename: Path to the FSF file.
    :arg keys:     A single key (e.g. ``'fmri(npts)'``), or a sequence of
                   keys.
    :returns:      A list containing the value of each key, in the same order
                   as ``keys``. Each value is everything after the key on its
                   ``set`` line, with runs of whitespace collapsed to a single
                   space. If a key is set more than once, the last occurrence
                   wins.
    :raises KeyError: If any of the requested keys is not set in the file.
    """

    # A str is itself a Sequence, so test for it explicitly - otherwise a
    # single key would be treated as a sequence of one-character keys.
    # (collections.Sequence was removed in Python 3.10; use collections.abc.)
    if isinstance(keys, str) or \
       not isinstance(keys, collections.abc.Sequence):
        keys = [keys]

    values = {}

    with open(filename, 'rt') as f:
        for line in f:

            if not line.startswith('set '):
                continue

            parts = line.split()

            # Skip malformed lines such as a bare "set " with no key.
            if len(parts) < 2:
                continue

            k = parts[1]
            v = ' '.join(parts[2:])

            if k in keys:
                values[k] = v

    return [values[key] for key in keys]


def writefsf(filename, keys, values):
    """Append one or more ``set <key> <value>`` entries to an FSF file.

    Each entry is written on its own line, preceded by a newline, so the
    entries are separated from whatever the file previously ended with.

    :arg filename: Path to the FSF file (created if it does not exist).
    :arg keys:     A single key, or a sequence of keys.
    :arg values:   A single value, or a sequence of values - one per key.
                   Values are written using ``str.format``.
    :raises ValueError: If the number of keys and values differ.
    """

    # A str is itself a Sequence, so test for it explicitly - otherwise a
    # single key/value string would be split into characters.
    # (collections.Sequence was removed in Python 3.10; use collections.abc.)
    if isinstance(keys, str) or \
       not isinstance(keys, collections.abc.Sequence):
        keys = [keys]
    if isinstance(values, str) or \
       not isinstance(values, collections.abc.Sequence):
        values = [values]

    if len(keys) != len(values):
        raise ValueError('len(keys) != len(values)')

    with open(filename, 'at') as f:
        for k, v in zip(keys, values):
            f.write('\nset {} {}'.format(k, v))


def distance(coord1, coord2):
    """Returns the euclidean distance between the provided coordinates.

    :arg coord1: A coordinate, as any array-like of numbers.
    :arg coord2: Another coordinate, of the same length as ``coord1``.
    :returns:    ``sqrt(sum((coord1 - coord2) ** 2))``
    """
    # Cast to float so that subtracting unsigned-integer inputs cannot wrap.
    coord1 = np.asarray(coord1, dtype=np.float64)
    coord2 = np.asarray(coord2, dtype=np.float64)

    return np.sqrt(np.sum((coord1 - coord2) ** 2))


def normalisedImageError(image1, image2):
    """Return the mean normalised absolute difference between two images.

    Runs ``fslmaths image1 -sub image2 -div image1 -abs`` into a temporary
    image. fslmaths applies its operations left to right, so this should be
    ``abs((image1 - image2) / image1)``. It then returns the mean of that
    temporary image as reported by ``fslstats -m``.

    The temporary image is removed even if either FSL command fails.

    :arg image1: Reference image, which is also the normalising image.
    :arg image2: Image to compare against ``image1``.
    :returns:    The mean error, as a ``float``.
    :raises subprocess.CalledProcessError: If fslmaths or fslstats fails.
    """

    tempFile = op.join(tempfile.gettempdir(), str(uuid.uuid4()))

    cmd = [fslbin("fslmaths"),
           image1,
           "-sub", image2,
           "-div", image1,
           "-abs",
           tempFile]

    try:
        subprocess.check_call(cmd, stderr=subprocess.STDOUT)

        cmd    = [fslbin("fslstats"), tempFile, "-m"]
        result = float(subprocess.check_output(cmd))
    finally:
        # Always clean up the intermediate image, so that a failing command
        # does not leave files behind in the temp directory.
        imrm(tempFile)

    return result


def unscaledImageError(image1, image2):
    """Compare two images with ``fslstats image1 -d image2 -a -m``.

    Returns the single number printed by fslstats, converted to ``float``.
    Given fslstats' ``-d`` (difference), ``-a`` (absolute value) and ``-m``
    (mean) options, this should be the mean absolute voxelwise difference.
    No normalisation is applied.
    """
    statsArgs = [image1, "-d", image2, "-a", "-m"]
    output    = subprocess.check_output([fslbin("fslstats")] + statsArgs)
    return float(output)


def testImage(image1, image2, tolerance, normalised=True):
    """Check whether two images agree to within a tolerance.

    :arg image1:     Reference image.
    :arg image2:     Image to compare against ``image1``.
    :arg tolerance:  Maximum acceptable error.
    :arg normalised: If ``True`` (the default), the error is computed with
                     ``normalisedImageError``. Otherwise it is computed with
                     ``unscaledImageError``.
    :returns:        A tuple ``(passed, error)``, where ``passed`` is
                     ``error <= tolerance``.
    """
    if normalised:
        compare = normalisedImageError
    else:
        compare = unscaledImageError

    error = compare(image1, image2)
    return (error <= tolerance, error)


def testNumber(num1, num2, tolerance, normalised=False):

    num1 = np.array(num1, dtype=np.float)
    num2 = np.array(num2, dtype=np.float)

    if normalised: error = np.abs(num1 - num2) / num1
    else:          error = np.abs(num1 - num2)

    if error.size > 1:
        error = error.mean()

    return (error <= tolerance, error)
       

def worstError(errors, strict):
    """Reduce a collection of test results to an exit status.

    :arg errors: A sequence of ``(passed, error)`` tuples, as returned by
                 ``testImage`` / ``testNumber``.
    :arg strict: If ``True``, any non-zero error counts as a failure, even
                 when it was within tolerance.
    :returns:    ``1`` if any test failed (or, in strict mode, had a
                 non-zero error), ``0`` otherwise.
    """
    passed, magnitudes = zip(*errors)

    if not all(passed):
        return 1

    if strict:
        exceeded = [m > 0 for m in magnitudes]
        if any(exceeded):
            return 1

    return 0