Skip to content
Snippets Groups Projects
feedsRun 4.43 KiB
#!/usr/bin/env fslpython
"""Test that fslcomplex produces correct outputs. Specifically tests that the
output files have orientation info that matches that of the input files.
"""

import numpy      as np
import nibabel    as nib
import subprocess as sp
import os.path    as op
import               shlex
import               sys
import               traceback

from fsl.utils.tempdir import tempdir
from fsl.data.image    import addExt


def sprun(cmd):
    """Echo a command line, then execute it.

    The command string is tokenised with shell-like quoting rules and run
    directly (no shell is involved). A non-zero exit status raises
    ``subprocess.CalledProcessError``.
    """
    print(f'RUN {cmd}')
    argv = shlex.split(cmd)
    sp.check_call(argv)


def imtest(path):
    """Return ``True`` if ``addExt`` accepts ``path``, ``False`` otherwise.

    Any exception raised by ``addExt`` is interpreted as "no such image".
    """
    try:
        addExt(path)
    except Exception:
        return False
    else:
        return True


def create_test_input_data(radio, seed=1):
    """Write the input images used by the fslcomplex tests.

    Random 10x10x10 real and imaginary volumes are generated, and complex,
    4D complex (the complex volume repeated twice along the fourth axis),
    absolute and phase volumes are derived from them. All images share the
    same affine, have their qform set from their own sform, and are saved
    into the current directory as ``<name>_in.nii.gz``, where ``<name>`` is
    one of ``real``, ``imag``, ``complex``, ``complex4d``, ``abs`` and
    ``phase``.

    radio: True or False - whether or not to add a flip on the affine X
           axis. When True, the data is also flipped along its first axis
           and a different X offset is used.
    seed:  Seed passed to ``np.random.seed`` so the data is reproducible.
    """

    np.random.seed(seed)

    real   = np.random.randint(0, 100, (10, 10, 10)).astype(np.float32)
    imag   = np.random.randint(0, 100, (10, 10, 10)).astype(np.float32)
    affine = np.diag([3, 3, 3, 1])

    if radio:
        affine[0, 0] *= -1
        affine[:3, 3] = [37, 20, 30]
        real          = np.flip(real, 0)
        imag          = np.flip(imag, 0)
    else:
        affine[:3, 3] = [10, 20, 30]

    comp   = real + imag * 1j
    abso   = np.abs(comp)
    phase  = np.arctan2(comp.imag, comp.real)
    comp4d = np.concatenate((comp.reshape(10, 10, 10, 1),
                             comp.reshape(10, 10, 10, 1)), 3)

    volumes = {
        'real':      real,
        'imag':      imag,
        'complex':   comp,
        'complex4d': comp4d,
        'abs':       abso,
        'phase':     phase,
    }

    # Handle every image through the same code path, so that each one has
    # its qform copied from its *own* sform before being saved.
    for name, data in volumes.items():
        image = nib.Nifti1Image(data, affine)
        image.set_qform(*image.get_sform(coded=True))
        image.to_filename(f'{name}_in.nii.gz')


def compare_images(file1, file2):
    """Assert that two images have matching orientation info and data.

    Both arguments are image paths, which are resolved with ``addExt``
    before loading. The sform/qform codes and matrices, the ``pixdim``
    header field, the data shape and the data values are compared.

    Raises ``AssertionError`` on the first mismatch found.
    """
    img1  = nib.load(addExt(file1))
    img2  = nib.load(addExt(file2))
    data1 = np.asanyarray(img1.dataobj)
    data2 = np.asanyarray(img2.dataobj)

    assert img1.header['sform_code'] == img2.header['sform_code'], \
        f'sform_code differs between {file1} and {file2}'
    assert img1.header['qform_code'] == img2.header['qform_code'], \
        f'qform_code differs between {file1} and {file2}'
    assert np.allclose(img1.get_sform(), img2.get_sform()), \
        f'sform differs between {file1} and {file2}'
    assert np.allclose(img1.get_qform(), img2.get_qform()), \
        f'qform differs between {file1} and {file2}'
    assert np.allclose(img1.header['pixdim'], img2.header['pixdim']), \
        f'pixdim differs between {file1} and {file2}'

    # The numpy closeness tests broadcast their arguments, so images with
    # different shapes must be rejected explicitly rather than being
    # compared element-against-the-wrong-element.
    assert data1.shape == data2.shape, \
        f'shape differs between {file1} {data1.shape} and {file2} {data2.shape}'
    assert np.allclose(data1, data2), \
        f'data differs between {file1} and {file2}'


def test_fslcomplex_call(args, infiles, outfiles, radio):
    """Run one fslcomplex call and check its outputs against the inputs.

    Test data is generated inside a temporary directory, ``fslcomplex`` is
    invoked as ``fslcomplex <args> <infiles> <outfiles>``, and every output
    image ``<prefix>_out`` is compared against the generated ``<prefix>_in``.

    args:     fslcomplex mode flag, e.g. ``-realabs``.
    infiles:  Space-separated input arguments.
    outfiles: Space-separated output arguments. Integer arguments (e.g. the
              ``0 0`` passed to ``-complexsplit``) are not treated as images.
    radio:    Passed through to ``create_test_input_data``.

    Raises ``AssertionError`` if a named output image was not created, or if
    it does not match its reference; ``subprocess.CalledProcessError`` if
    fslcomplex exits with a non-zero status.
    """
    with tempdir():
        create_test_input_data(radio)
        sprun(f'fslcomplex {args} {infiles} {outfiles}')

        for outfile in outfiles.split():

            if not imtest(outfile):
                # Integer arguments are not image names - skip them. Any
                # other argument names an image that fslcomplex should have
                # created, so its absence is a failure, not something to
                # silently ignore.
                if outfile.lstrip('-').isdigit():
                    continue
                raise AssertionError(
                    f'fslcomplex did not create expected output {outfile}')

            prefix = outfile.split('_')[0]
            infile = f'{prefix}_in'

            print(f'Comparing {outfile} against {infile}')

            compare_images(infile, outfile)


if __name__ == '__main__':
    # Each case is (fslcomplex mode flag, input arguments, output arguments).
    cases = [
        ('-realabs',       'complex_in',            'abs_out'),
        ('-realphase',     'complex_in',            'phase_out'),
        ('-realpolar',     'complex_in',            'abs_out phase_out'),
        ('-realcartesian', 'complex_in',            'real_out imag_out'),
        ('-complex',       'real_in imag_in',       'complex_out'),
        ('-complexpolar',  'abs_in phase_in',       'complex_out'),
        ('-copyonly',      'complex_in',            'complex_out'),
        ('-complexsplit',  'complex4d_in',          'complex_out 0 0'),
        ('-complexsplit',  'complex4d_in',          'complex_out 1 1'),
        ('-complexmerge',  'complex_in complex_in', 'complex4d_out'),
    ]

    # Run every case with and without the X axis flip. A failing case is
    # reported but does not stop the remaining cases from running; the
    # exit status is non-zero if any case failed.
    exit_code = 0
    for radio in (True, False):
        for args, infiles, outfiles in cases:
            label = f'fslcomplex {args} {infiles} {outfiles} [radio: {radio}]'
            try:
                test_fslcomplex_call(args, infiles, outfiles, radio)
                print(f'\nTest {label} PASSED')
            except Exception:
                print(f'\nTest {label} FAILED')
                traceback.print_exc()
                exit_code = 1

    sys.exit(exit_code)