#!/usr/bin/env fslpython

import sys
import numpy as np
import nibabel as nib
import datetime
import pandas as pd

class EddyHigh_b_FeedsType(object):
    """ The purpose of this class is to make all the comparisons between
        newly estimated and precomputed eddy results for plain vanilla eddy
        on high b-value data.

        The constructor loads a mask and two corrected images, checks that
        their dimensions agree, and computes for every volume the absolute
        difference between the two corrected images averaged over the
        voxels inside the mask (mask > 0).
    """

    def __init__(self,mask,corr,precomp_corr):
        """ Load the images, validate their dimensions and compute statistics.

        Parameters:
            mask:         file name of the mask image (3D, or with a single
                          volume in the fourth dimension)
            corr:         file name of the newly corrected image
            precomp_corr: file name of the precomputed corrected image

        Raises:
            Exception: if an image cannot be read, if the image dimensions
                       do not match, or if the statistics cannot be computed
                       (this includes a mask with no voxels > 0).
        """
        # Read corrected images and make sure dimensions are right
        try:
            self._mask = nib.load(mask,mmap=False)
            self._corr = nib.load(corr,mmap=False)
            self._precomp_corr = nib.load(precomp_corr,mmap=False)
        except Exception as e:
            print(str(e))
            raise Exception('EddyHigh_b_FeedsType:__init__:Error opening corrected image files') from e

        if not (all(self._mask.header['dim'][1:4] == self._corr.header['dim'][1:4]) and
                all(self._mask.header['dim'][1:4] == self._precomp_corr.header['dim'][1:4])):
            raise Exception('EddyHigh_b_FeedsType:__init__:Size mismatch in first three dimensions of corrected images')
        if not (self._mask.header['dim'][4] == 1 and
                self._corr.header['dim'][4] == self._precomp_corr.header['dim'][4]):
            raise Exception('EddyHigh_b_FeedsType:__init__:Size mismatch in fourth dimension of corrected images')

        # Compare images and create statistics of differences
        try:
            self._corrdiffmeans = self._masked_mean_abs_diffs(self._mask.get_fdata(),
                                                              self._corr.get_fdata(),
                                                              self._precomp_corr.get_fdata())
        except Exception as e:
            print(str(e))
            raise Exception('EddyHigh_b_FeedsType:__init__:Error calculating image statistics') from e

    @staticmethod
    def _masked_mean_abs_diffs(mask, corr, precomp_corr):
        """ Return the per-volume mean absolute difference inside a mask.

        Parameters:
            mask:         array whose first three axes are spatial; voxels with
                          a value > 0 are inside the mask. Any further axes
                          must be singleton (e.g. shape (x, y, z, 1)).
            corr:         3D or 4D array, (x, y, z) or (x, y, z, volumes)
            precomp_corr: array with the same shape as corr

        Returns:
            1D numpy array with one entry per volume: the mean of
            |corr - precomp_corr| over the voxels inside the mask.

        Raises:
            ValueError: if the mask contains no voxels > 0. Without this
                        check the average would be 0/0 = NaN, which compares
                        False against any threshold and would silently let
                        the test pass.
        """
        corrdiff = np.abs(np.asarray(corr, dtype=float) - np.asarray(precomp_corr, dtype=float))
        if corrdiff.ndim == 3:
            # A single 3D volume: treat it as a 4D image with one volume
            corrdiff = corrdiff[..., np.newaxis]
        # Drop any trailing singleton axes of the mask, e.g. (x, y, z, 1) -> (x, y, z)
        inmask = np.reshape(np.asarray(mask) > 0, corrdiff.shape[:3])
        nvox = np.count_nonzero(inmask)
        if nvox == 0:
            raise ValueError('Mask contains no voxels > 0')
        # corrdiff[inmask] has shape (nvox, volumes); summing over the masked
        # voxels only also keeps NaN/Inf outside the mask from leaking in.
        return corrdiff[inmask].sum(axis=0) / nvox

    def MeanDiffsOfCorrectedImages(self):
        """ Return the per-volume masked mean absolute differences (1D array). """
        return self._corrdiffmeans

def main(argv):
    """ Test the output from running vanilla eddy on high b-value data.

    Compares newly corrected images against precomputed ones, writes a
    text report to <output_dir>/<prefix>_EddyHigh_b_Report.txt and exits
    with status 0 if the test passed, 1 if it failed or on any error.

    Parameters:
        argv: the command line, i.e. [script, output_dir, prefix, mask,
              corrected, precomputed_corrected,
              allowed_mean_diff_corrected, allowed_max_diff_corrected]
    """
    try:
        if len(argv) != 8:
            print('EddyHigh_b_Feeds.py usage: EddyHigh_b_Feeds.py output_dir prefix mask corrected precomputed_corrected ' \
                  'allowed_mean_diff_corrected allowed_max_diff_corrected')
            sys.exit(1)
        output_dir = argv[1]
        output_prefix = argv[2]
        mask = argv[3]
        corrected = argv[4]
        precomputed_corrected = argv[5]
        try:
            allowed_mean_diff_corrected = float(argv[6])
            allowed_max_diff_corrected = float(argv[7])
        except ValueError as e:
            print(str(e))
            print('main: allowed_mean_diff_corrected and allowed_max_diff_corrected must be numbers')
            sys.exit(1)

        # Try to create EddyHigh_b_FeedsType object (involves reading all files)
        try:
            ef = EddyHigh_b_FeedsType(mask,corrected,precomputed_corrected)
        except Exception as e:
            print(str(e))
            print('main: Error when creating EddyHigh_b_FeedsType object.')
            sys.exit(1)

        # Check pass/fail based on corrected images
        try:
            diffs = ef.MeanDiffsOfCorrectedImages()
            mean_diff = diffs.mean()
            max_diff = diffs.max()
            # Phrased as "<=" so that a NaN statistic (or threshold) makes the
            # test fail; with ">" every NaN comparison is False and it passed.
            passes_test = bool(mean_diff <= allowed_mean_diff_corrected and
                               max_diff <= allowed_max_diff_corrected)
        except Exception as e:
            print(str(e))
            print('main: Failed calculating stats for test.')
            sys.exit(1)

        # Write report
        report_name = output_dir + '/' + output_prefix + '_EddyHigh_b_Report.txt'
        try:
            fp = open(report_name,'w')
        except Exception as e:
            print(str(e))
            print('main: Cannot open report file: ' + report_name)
            sys.exit(1)
        try:
            # "with" guarantees the file is closed even if a write fails
            with fp:
                fp.write('EddyHigh_b_Feeds was run on ' + datetime.datetime.now().strftime("%Y-%m-%d %H:%M") + '\n')
                fp.write('With the command ' + ' '.join(argv) + '\n')
                if passes_test:
                    fp.write('\nOverall the test passed\n')
                else:
                    fp.write('\nOverall the test failed\n')
                # Report on differences in corrected images
                fp.write('\nThe absolute differences, averaged across the mask, for the corrected images were: ' + \
                         ' '.join(["{0:.4f}".format(elem) for elem in diffs]) + '\n')
                fp.write('That gives mean error: ' + "{0:.4f}".format(mean_diff) + \
                         ', and a max error: ' + "{0:.4f}".format(max_diff) + '\n')
                fp.write('The allowed mean of averaged differences for the corrected images was: ' + \
                         "{0:.4f}".format(allowed_mean_diff_corrected) + '\n')
                fp.write('The allowed maximum of averaged differences for the corrected images was: ' + \
                         "{0:.4f}".format(allowed_max_diff_corrected) + '\n')
                if passes_test:
                    fp.write('Based on these criteria the test passed\n')
                else:
                    fp.write('Based on these criteria the test failed\n')
        except Exception as e:
            print(str(e))
            print('main: Problem writing report file: ' + report_name)
            sys.exit(1)
    except Exception as e:
        print(str(e))
        print('main: Unknown problem in body of function')
        sys.exit(1)

    sys.exit(0 if passes_test else 1)


if "__main__" == __name__:
    # Script entry point: pass the whole command line (script name in
    # argv[0]) to main(), which terminates via sys.exit with the test status.
    main(argv=sys.argv)