#!/usr/bin/env python

# fsl_mrsi - wrapper script for MRSI fitting
#
# Author: Saad Jbabdi
#         William Clarke
#
# Copyright (C) 2019 University of Oxford

import sys
import time
import warnings

from fsl_mrs.auxiliary import configargparse
from fsl_mrs import __version__
from fsl_mrs.utils.splash import splash
from fsl_mrs.utils import fitting, misc, mrs_io, quantify
# NOTE!!!! THERE ARE MORE IMPORTS IN THE CODE BELOW (AFTER ARGPARSING)


def _build_parser():
    """Build and return the command-line/config-file parser for fsl_mrsi.

    Kept separate from :func:`main` so that the (long) option list can be
    read on its own. The returned parser is a ``configargparse.ArgParser``.
    """
    p = configargparse.ArgParser(
        add_config_file_help=False,
        description="FSL Magnetic Resonance Spectroscopy Imaging"
                    " Wrapper Script")

    p.add_argument('-v', '--version', action='version', version=__version__)

    required = p.add_argument_group('required arguments')
    fitting_args = p.add_argument_group('fitting options')
    optional = p.add_argument_group('additional options')

    # REQUIRED ARGUMENTS
    required.add_argument('--data',
                          required=True, type=str, metavar='.NII',
                          help='input NIFTI file')
    required.add_argument('--basis',
                          required=True, type=str, metavar='',
                          help='Basis file or folder')
    required.add_argument('--mask',
                          required=True, type=str, metavar='',
                          help='mask volume')
    required.add_argument('--output',
                          required=True, type=str, metavar='',
                          help='output folder')

    # FITTING ARGUMENTS
    fitting_args.add_argument('--algo', default='Newton', type=str,
                              help='algorithm [Newton (fast, default)'
                                   ' or MH (slow)]')
    fitting_args.add_argument('--ignore', type=str, nargs='+',
                              metavar='METAB',
                              help='ignore certain metabolites [repeatable]')
    fitting_args.add_argument('--keep', type=str, nargs='+', metavar='METAB',
                              help='only keep these metabolites')
    fitting_args.add_argument('--combine', type=str, nargs='+',
                              action='append', metavar='METAB',
                              help='combine certain metabolites [repeatable]')
    fitting_args.add_argument('--ppmlim', default=(.2, 4.2), type=float,
                              nargs=2, metavar=('LOW', 'HIGH'),
                              help='limit the fit to a freq range'
                                   ' (default=(.2, 4.2))')
    fitting_args.add_argument('--h2o', default=None, type=str, metavar='H2O',
                              help='input .H2O file for quantification')
    fitting_args.add_argument('--baseline_order', default=2, type=int,
                              metavar=('ORDER'),
                              help='order of baseline polynomial'
                                   ' (default=2, -1 disables)')
    fitting_args.add_argument('--metab_groups', default=0, nargs='+',
                              type=str_or_int_arg,
                              help="metabolite groups: list of groups or list"
                                   " of names for indept groups.")
    fitting_args.add_argument('--add_MM', action="store_true",
                              help="include default macromolecule peaks")
    fitting_args.add_argument('--lorentzian', action="store_true",
                              help="Enable purely lorentzian broadening"
                                   " (default is Voigt)")
    fitting_args.add_argument('--ind_scale', default=None, type=str,
                              nargs='+',
                              help='List of basis spectra to scale'
                                   ' independently of other basis spectra.')
    fitting_args.add_argument('--disable_MH_priors', action="store_true",
                              help="Disable MH priors.")

    # ADDITIONAL OPTIONAL ARGUMENTS
    optional.add_argument('--TE', type=float, default=None, metavar='TE',
                          help='Echo time for relaxation correction (ms)')
    optional.add_argument('--TR', type=float, default=None, metavar='TR',
                          help='Repetition time for relaxation correction (s)')
    optional.add_argument('--tissue_frac', type=str, nargs=3, default=None,
                          help='Tissue fraction nifti files registered to'
                               ' MRSI volume. Supplied in order: WM, GM, CSF.')
    optional.add_argument('--internal_ref', type=str,
                          default=['Cr', 'PCr'], nargs='+',
                          help='Metabolite(s) used as an internal reference.'
                               ' Defaults to tCr (Cr+PCr).')
    optional.add_argument('--wref_metabolite', type=str, default=None,
                          help='Metabolite(s) used as an the reference for water scaling.'
                               ' Uses internal defaults otherwise.')
    optional.add_argument('--ref_protons', type=int, default=None,
                          help='Number of protons that reference metabolite is equivalent to.'
                               ' No effect without setting --wref_metabolite.')
    optional.add_argument('--ref_int_limits', type=float, default=None,
                          nargs=2,
                          help='Reference spectrum integration limits (low, high).'
                               ' No effect without setting --wref_metabolite.')
    # runvoxel() reads ``args.h2o_scale`` on the water-quantification path,
    # so the option has to exist. It defaults to None, in which case the
    # quantification object is left untouched.
    optional.add_argument('--h2o_scale', type=float, default=None,
                          help='Additional scaling modifier for external'
                               ' water referencing.')
    optional.add_argument('--report', action="store_true",
                          help='output html report')
    optional.add_argument('--verbose', action="store_true",
                          help='spit out verbose info')
    optional.add_argument('--overwrite', action="store_true",
                          help='overwrite existing output folder')
    optional.add_argument('--single_proc', action="store_true",
                          help='Disable parallel processing of voxels.')
    optional.add_argument('--conj_fid', action="store_true",
                          help='Force conjugation of FID')
    optional.add_argument('--no_conj_fid', action="store_true",
                          help='Forbid automatic conjugation of FID')
    optional.add_argument('--conj_basis', action="store_true",
                          help='Force conjugation of basis')
    optional.add_argument('--no_conj_basis', action="store_true",
                          help='Forbid automatic conjugation of basis')
    optional.add_argument('--no_rescale', action="store_true",
                          help='Forbid rescaling of FID/basis/H2O.')
    optional.add('--config', required=False, is_config_file=True,
                 help='configuration file')

    return p


def main():
    """Command-line entry point: fit every masked voxel of an MRSI dataset.

    Steps, in order:

    1. parse the command line / config file;
    2. create the output folder (asking before deleting an existing one
       unless ``--overwrite`` was given) and record the options used;
    3. load the data, basis, optional water reference, mask and optional
       tissue segmentation;
    4. fit the average over voxels to initialise the per-voxel fits;
    5. fit each voxel, sequentially or via a ``multiprocessing`` pool;
    6. write concentration, uncertainty, QC and fit images to the
       output folder.
    """
    # Parse command-line arguments
    p = _build_parser()
    args = p.parse_args()

    # Output kickass splash screen
    if args.verbose:
        splash(logo='mrsi')

    # ######################################################
    # DO THE IMPORTS AFTER PARSING TO SPEED UP HELP DISPLAY
    import os
    import shutil
    import numpy as np
    from fsl_mrs.utils import report
    from fsl_mrs.core import NIFTI_MRS
    import datetime
    import nibabel as nib
    from functools import partial
    import multiprocessing as mp
    # ######################################################

    # Check if output folder exists
    overwrite = args.overwrite
    if os.path.exists(args.output):
        if not overwrite:
            print(f"Folder '{args.output}' exists."
                  " Are you sure you want to delete it? [Y,N]")
            response = input()
            overwrite = response.upper() == "Y"
        if not overwrite:
            print('Early stopping...')
            # sys.exit rather than the ``exit`` helper, which is only
            # injected by the ``site`` module and may be absent.
            sys.exit()
        shutil.rmtree(args.output)
    os.mkdir(args.output)

    # Save chosen arguments
    with open(os.path.join(args.output, "options.txt"), "w") as f:
        f.write(str(args))
        f.write("\n--------\n")
        f.write(p.format_values())

    # ######  Do the work #######

    # Read files
    mrsi_data = mrs_io.read_FID(args.data)
    if args.h2o is not None:
        H2O = mrs_io.read_FID(args.h2o)
    else:
        H2O = None

    mrsi = mrsi_data.mrs(basis_file=args.basis,
                         ref_data=H2O)

    def loadNii(f):
        """Load a NIfTI file as an array, padding 2D data to 3D."""
        nii = np.asanyarray(nib.load(f).dataobj)
        if nii.ndim == 2:
            nii = np.expand_dims(nii, 2)
        return nii

    if args.mask is not None:
        mask = loadNii(args.mask)
        mrsi.set_mask(mask)

    if args.tissue_frac is not None:
        # Files are supplied on the command line as WM, GM, CSF ...
        wm = loadNii(args.tissue_frac[0])
        gm = loadNii(args.tissue_frac[1])
        csf = loadNii(args.tissue_frac[2])
        # ... but passed on in the order CSF, WM, GM.
        mrsi.set_tissue_seg(csf, wm, gm)

    # Set mrs output options from MRSI class object
    mrsi.conj_basis = args.conj_basis
    mrsi.no_conj_basis = args.no_conj_basis
    mrsi.conj_FID = args.conj_fid
    mrsi.no_conj_FID = args.no_conj_fid
    mrsi.rescale = not args.no_rescale
    mrsi.keep = args.keep
    mrsi.ignore = args.ignore

    # ppmlim for fitting
    ppmlim = args.ppmlim
    if ppmlim is not None:
        ppmlim = (ppmlim[0], ppmlim[1])

    # Store info in dictionaries to be passed to MRS and fitting
    Fitargs = {'ppmlim': ppmlim, 'method': args.algo,
               'baseline_order': args.baseline_order}
    if args.lorentzian:
        Fitargs['model'] = 'lorentzian'
    else:
        Fitargs['model'] = 'voigt'

    if args.disable_MH_priors:
        Fitargs['disable_mh_priors'] = True

    # Echo time: --TE is given in ms and converted here by 1E-3
    if args.TE is not None:
        echotime = args.TE * 1E-3
    elif 'TE' in mrsi.header:
        echotime = mrsi.header['TE']
    else:
        echotime = None

    # Repetition time
    if args.TR is not None:
        repetition_time = args.TR
    elif 'RepetitionTime' in mrsi_data.hdr_ext:
        repetition_time = mrsi_data.hdr_ext['RepetitionTime']
    else:
        repetition_time = None

    # Fitting
    if args.verbose:
        print('\n--->> Start fitting\n\n')
        print(f' Algorithm = [{args.algo}]\n')

    # Initialise by fitting the average FID across all voxels
    if args.verbose:
        print(" Initialise with average fit")
    mrs = mrsi.mrs_from_average()
    # The initial fit always uses Newton, whatever --algo says.
    Fitargs_init = Fitargs.copy()
    Fitargs_init['method'] = 'Newton'
    res_init, _ = runvoxel([mrs, 0, None],
                           args,
                           Fitargs_init,
                           echotime,
                           repetition_time)
    # Parameters of the average fit seed every per-voxel fit.
    Fitargs['x0'] = res_init.params

    # quick summary figure
    report.fitting_summary_fig(
        mrs,
        res_init,
        filename=os.path.join(args.output, 'fit_avg.png'))

    # Create interactive HTML report
    if args.report:
        report.create_report(
            mrs,
            res_init,
            filename=os.path.join(args.output, 'report.html'),
            fidfile=args.data,
            basisfile=args.basis,
            h2ofile=args.h2o,
            date=datetime.datetime.now().strftime("%Y-%m-%d %H:%M"))

    # From here on (i.e. for the per-voxel fits) all warnings are silenced.
    warnings.filterwarnings("ignore")

    if args.single_proc:
        if args.verbose:
            print(' Running sequentially (are you sure about that?) ')
        results = []
        for idx, voxel in enumerate(mrsi):
            res = runvoxel(voxel, args, Fitargs, echotime, repetition_time)
            results.append(res)
            if args.verbose:
                print(f'{idx+1}/{mrsi.num_masked_voxels} voxels completed')
    else:
        if args.verbose:
            print(f' Parallelising over {mp.cpu_count()} workers ')
        func = partial(runvoxel,
                       args=args,
                       Fitargs=Fitargs,
                       echotime=echotime,
                       repetition_time=repetition_time)
        # Named ``pool`` so it does not shadow the argument parser ``p``.
        with mp.Pool(processes=mp.cpu_count()) as pool:
            async_results = pool.map_async(func, mrsi)
            if args.verbose:
                track_job(async_results)
            results = async_results.get()

    # Save output files
    if args.verbose:
        print(f'--->> Saving output files to {args.output}\n')

    # Results --> Images
    # Store concentrations, uncertainties, residuals, predictions
    # Output:
    # 1) concs - folders with N_metabs x 3D nifti for each scaling
    #    (raw,internal,molarity,molality)
    # 2) uncertainties - N_metabs x 3D nifti as percentage
    # 3) qc - 2 x N_metabs x 3D nifti for SNR and fwhm
    # 4) fit - predicted, residuals and baseline?

    # Generate the folders
    concs_folder = os.path.join(args.output, 'concs')
    uncer_folder = os.path.join(args.output, 'uncertainties')
    qc_folder = os.path.join(args.output, 'qc')
    fit_folder = os.path.join(args.output, 'fit')

    os.mkdir(concs_folder)
    os.mkdir(uncer_folder)
    os.mkdir(qc_folder)
    os.mkdir(fit_folder)

    # Each entry of ``results`` is the (result, index) pair from runvoxel.
    indicies = [res[1] for res in results]
    first_res = results[0][0]

    # 'raw' is always written; other scalings only if they were computed.
    scalings = ['raw']
    for scaling_name in ('internal', 'molarity', 'molality'):
        if first_res.concScalings[scaling_name] is not None:
            scalings.append(scaling_name)

    def save_img_output(fname, data):
        """Save ``data`` to ``fname`` as NIfTI-MRS or as a plain NIfTI image.

        Arrays with more than three dimensions whose 4th dimension has
        ``mrsi.FID_points`` entries are saved through ``NIFTI_MRS``;
        anything else is saved as a ``Nifti1Image``.
        """
        if data.ndim > 3 and data.shape[3] == mrsi.FID_points:
            NIFTI_MRS(data, header=mrsi_data.header).save(fname)
        else:
            img = nib.Nifti1Image(data, mrsi_data.voxToWorldMat)
            nib.save(img, fname)

    def save_map(fname, values, cleanup, dtype):
        """Turn a per-voxel list into a volume and write it to ``fname``."""
        save_img_output(fname,
                        mrsi.list_to_matched_array(
                            values,
                            indicies=indicies,
                            cleanup=cleanup,
                            dtype=dtype))

    # Concentrations: one sub-folder per scaling, one file per metabolite
    metabs = first_res.metabs
    for scale in scalings:
        cur_fldr = os.path.join(concs_folder, scale)
        os.mkdir(cur_fldr)
        for metab in metabs:
            metab_conc_list = [res[0].getConc(scaling=scale, metab=metab)
                               for res in results]
            save_map(os.path.join(cur_fldr, metab + '.nii.gz'),
                     metab_conc_list, cleanup=True, dtype=float)

    # Uncertainties
    for metab in metabs:
        metab_sd_list = [res[0].getUncertainties(metab=metab)
                         for res in results]
        save_map(os.path.join(uncer_folder, metab + '_sd.nii.gz'),
                 metab_sd_list, cleanup=True, dtype=float)

    # qc - SNR & FWHM
    for metab in first_res.original_metabs:
        # getQCParams output is indexed as [0] = SNR, [1] = FWHM.
        qc_params = [res[0].getQCParams(metab=metab) for res in results]
        metab_fwhm_list = [qc[1] for qc in qc_params]
        save_map(os.path.join(qc_folder, metab + '_fwhm.nii.gz'),
                 metab_fwhm_list, cleanup=True, dtype=float)

        metab_snr_list = [qc[0] for qc in qc_params]
        save_map(os.path.join(qc_folder, metab + '_snr.nii.gz'),
                 metab_snr_list, cleanup=True, dtype=float)

    # fit
    # TODO: check if data has been conjugated, if so conjugate the predictions
    mrs_scale = mrsi.get_scalings_in_order()
    # Each fitted quantity is divided by the voxel's 'FID' scaling
    # before being written out.
    fit_outputs = (('pred', 'fit.nii.gz'),
                   ('residuals', 'residual.nii.gz'),
                   ('baseline', 'baseline.nii.gz'))
    for attr, fname in fit_outputs:
        spec_list = [getattr(res[0], attr) / scale['FID']
                     for res, scale in zip(results, mrs_scale)]
        save_map(os.path.join(fit_folder, fname),
                 spec_list, cleanup=False, dtype=np.complex64)

    if args.verbose:
        print('\n\n\nDone.')


def runvoxel(mrs_in, args, Fitargs, echotime, repetition_time):
    """Fit a single voxel (or the voxel average) and scale concentrations.

    Defined at module level so it can be dispatched to worker processes.

    Parameters
    ----------
    mrs_in : sequence of length 3
        ``(mrs, index, tissue_seg)``: the object to fit, an index that is
        returned unchanged, and a tissue segmentation (falsy when absent).
    args : namespace
        Parsed command-line arguments.
    Fitargs : dict
        Keyword arguments forwarded to ``fitting.fit_FSLModel``.
    echotime : float or None
        Echo time used for quantification; ``None`` if unknown.
    repetition_time : float or None
        Repetition time used for quantification; ``None`` if unknown.

    Returns
    -------
    tuple
        ``(res, index)``: the fit results object and the index taken from
        ``mrs_in``.
    """
    mrs, index, tissue_seg = mrs_in

    # Parse metabolite groups
    metab_groups = misc.parse_metab_groups(mrs, args.metab_groups)

    if args.add_MM:
        n = mrs.add_MM_peaks(gamma=40, sigma=30)
        # Every added MM peak gets its own new group after the existing ones.
        first_new_group = max(metab_groups) + 1
        new_metab_groups = metab_groups \
            + [first_new_group + i for i in range(n)]
    else:
        new_metab_groups = metab_groups.copy()

    res = fitting.fit_FSLModel(mrs, **Fitargs, metab_groups=new_metab_groups)

    # Internal and Water quantification if requested
    if (mrs.H2O is None) or (echotime is None) or (repetition_time is None):
        # The warnings below state that an H2O file was provided, so only
        # emit them when a water reference really is present.
        if mrs.H2O is not None:
            if echotime is None:
                warnings.warn('H2O file provided but could not determine TE:'
                              ' no absolute quantification will be performed.',
                              UserWarning)
            if repetition_time is None:
                warnings.warn('H2O file provided but could not determine TR:'
                              ' no absolute quantification will be performed.',
                              UserWarning)
        res.calculateConcScaling(mrs,
                                 internal_reference=args.internal_ref,
                                 verbose=args.verbose)
    else:
        # Form quantification information
        q_info = quantify.QuantificationInfo(
            echotime,
            repetition_time,
            mrs.names,
            mrs.centralFrequency / 1E6,
            water_ref_metab=args.wref_metabolite,
            water_ref_metab_protons=args.ref_protons,
            water_ref_metab_limits=args.ref_int_limits)

        if tissue_seg:
            q_info.set_fractions(tissue_seg)

        # Read defensively: ``args`` may have been built by a parser that
        # does not define --h2o_scale.
        h2o_scale = getattr(args, 'h2o_scale', None)
        if h2o_scale:
            q_info.add_corr = h2o_scale

        res.calculateConcScaling(mrs,
                                 quant_info=q_info,
                                 internal_reference=args.internal_ref,
                                 verbose=args.verbose)

    # Combine metabolites.
    if args.combine is not None:
        res.combine(args.combine)

    return res, index


def str_or_int_arg(x):
    """Argparse ``type`` helper: return ``int(x)`` if possible, else ``x``."""
    try:
        return int(x)
    except ValueError:
        return x


class PoolProgress:
    """Print the remaining work of a pool job at a fixed interval.

    NOTE(review): not referenced elsewhere in this file (``track_job`` is
    used instead) and it relies on private ``multiprocessing`` attributes
    (``_cache``, ``_job``, ``_number_left``, ``_chunksize``).
    """

    def __init__(self, pool, update_interval=3):
        # pool: the pool whose job cache is polled
        # update_interval: seconds to sleep between two progress messages
        self.pool = pool
        self.update_interval = update_interval

    def track(self, job):
        """Block until ``job`` has no chunks left, printing progress."""
        task = self.pool._cache[job._job]
        while task._number_left > 0:
            print(f"Voxels remaining = {task._number_left * task._chunksize}")
            time.sleep(self.update_interval)


def track_job(job, update_interval=3):
    """Block until ``job`` has no chunks left, printing progress in place.

    ``job`` is the object returned by ``Pool.map_async``. The count is
    ``_number_left * _chunksize`` (both private attributes), and each
    message ends with a carriage return so it overwrites the previous one.
    ``update_interval`` is the number of seconds slept between updates.
    """
    while job._number_left > 0:
        print(f" {job._number_left * job._chunksize} Voxels remaining ",
              end='\r', flush=True)
        time.sleep(update_interval)


if __name__ == '__main__':
    main()