import glob
import hashlib
import os.path as op
import sys
from subprocess import check_output

import nibabel as nb
import numpy as np

from feedsUtils import normalisedImageError

def calcHash(value):
    """Return the hexadecimal MD5 digest of the bytes-like *value*.

    The digest is used in this module as a content fingerprint for file
    equality checks, not for any security purpose.
    """
    return hashlib.md5(value).hexdigest()

def hashCompare(file1, file2):
    """Return True if the two files have identical contents.

    Each file is read in full (binary mode) and fingerprinted with
    ``calcHash``; the digests are then compared.

    Parameters
    ----------
    file1, file2 : str
        Paths of the files to compare.

    Raises
    ------
    OSError
        If either file cannot be opened (e.g. FileNotFoundError).
    """
    # Context managers guarantee the handles are closed; the previous
    # open(...).read() form left them to the garbage collector.
    with open(file1, 'rb') as f1:
        hash1 = calcHash(f1.read())
    with open(file2, 'rb') as f2:
        hash2 = calcHash(f2.read())
    return hash1 == hash2

def normImageError(im1, im2):
    """Return mean(|im1 - im2|) / mean(|im1|) for two nibabel-loadable images.

    Parameters
    ----------
    im1 : str
        Path to the reference image; its mean absolute intensity is the
        normalising denominator.
    im2 : str
        Path to the image compared against ``im1``.

    Returns
    -------
    float
        The normalised mean absolute error. If ``mean(|im1|)`` is zero, numpy
        yields inf/nan (with a RuntimeWarning) rather than raising.

    Notes
    -----
    Data are read with ``get_fdata()`` (float64, scaling applied). The former
    ``get_data()`` returned the on-disk dtype for unscaled images, so the
    subtraction could wrap around (unsigned ints) or overflow (e.g. int16).
    Array shapes must be broadcast-compatible or numpy raises ValueError.
    """
    im1data = nb.load(im1).get_fdata()
    im2data = nb.load(im2).get_fdata()
    return np.abs(im1data - im2data).mean() / np.abs(im1data).mean()

def run(cmd):
    """Print and then execute *cmd* through the shell.

    The command string is handed to the shell verbatim (``shell=True``), so it
    must only ever contain trusted text. Its stdout is captured and discarded.

    Raises
    ------
    subprocess.CalledProcessError
        If the command exits with a non-zero status.
    """
    banner = "RUNNING: " + cmd
    print(banner)
    check_output(cmd, shell=True)
    return None

# compare design files (excluding design.fsf) between original (benchmark) and new FEAT dirs
def checkFeatDesigns(origFeatDir,newFeatDir,strict):
    newDesigns = glob.glob(op.join(newFeatDir, "**", "design.*[!fsf]"),recursive=True)
    for d in newDesigns:
        print("TESTING: " + d)
        isEqual = hashCompare(d, d.replace(newFeatDir, origFeatDir))
        if not strict or isEqual:
            continue
        else:
            print("FAIL: Design test failed: {0}".format(d))
            sys.exit(1)

def checkFeatImages(origFeatDir, newFeatDir, strict, tol=0.001, looseTol=0.01):
    """Compare image files between benchmark and new FEAT dirs.

    Every ``*.nii.gz`` found recursively under ``newFeatDir`` is compared with
    the image at the same relative path under ``origFeatDir`` using
    ``normalisedImageError``.

    Parameters
    ----------
    origFeatDir : str
        Benchmark (reference) FEAT directory.
    newFeatDir : str
        Newly generated FEAT directory whose images are tested.
    strict : bool
        If true, an image passes only when its error is below ``tol``. If
        false, errors below ``looseTol`` are also accepted.
    tol : float, optional
        Error threshold that always passes (default 0.001).
    looseTol : float, optional
        Relaxed threshold used when ``strict`` is false (default 0.01).

    Raises
    ------
    SystemExit
        With code 1 on the first image that fails the threshold.
    """
    pattern = op.join(newFeatDir, "**", "*.nii.gz")
    for im in sorted(glob.glob(pattern, recursive=True)):  # deterministic order
        print("TESTING: " + im)
        # Map via the relative path rather than str.replace, which would
        # rewrite every occurrence of newFeatDir in the path.
        origImage = op.join(origFeatDir, op.relpath(im, newFeatDir))
        imgError = normalisedImageError(im, origImage)
        # Should the error be NaN, both comparisons below are False, so the
        # image is reported as a failure.
        if imgError < tol:
            continue
        if not strict and imgError < looseTol:
            continue
        print("FAIL: Image test failed: {0}".format(im))
        sys.exit(1)