"""Data Evaluation Functions

Description:

    This module contains several functions which, either on their own or included in larger pieces of software, perform data evaluation tasks.

Usage:

    To use content from this module, import the functions and call them as needed:

        from utils.data_evaluation_utils import function_name

TODO: Might be worth adding some information on uncertainty estimation, later down the line

"""

import os
18
import pickle
19
20
import numpy as np
import torch
21
import logging
22
import h5py
23
import utils.data_utils as data_utils
24
from utils.common_utils import create_folder
Andrei Roibu's avatar
Andrei Roibu committed
25
import pandas as pd
26
from fsl.data.image import Image
27
from fsl.utils.image.roi import roi
Andrei Roibu's avatar
Andrei Roibu committed
28
import itertools
29
30
from scipy.spatial.distance import cosine
from scipy.stats import pearsonr, spearmanr
31
32

log = logging.getLogger(__name__)
33

34
35
36
37
38
39
40
def evaluate_data(trained_model_path,
                  data_directory,
                  mapping_data_file,
                  mapping_targets_file,
                  data_list,
                  prediction_output_path,
                  prediction_output_database_name,
                  prediction_output_statistics_name,
                  brain_mask_path,
                  dmri_mean_mask_path,
                  rsfmri_mean_mask_path,
                  regression_factors,
                  mean_regression_flag,
                  mean_regression_all_flag,
                  mean_subtraction_flag,
                  scale_volumes_flag,
                  normalize_flag,
                  minus_one_scaling_flag,
                  negative_flag,
                  outlier_flag,
                  shrinkage_flag,
                  hard_shrinkage_flag,
                  crop_flag,
                  device=0,
                  exit_on_error=False,
                  output_database_flag=False,
                  cross_domain_x2x_flag=False,
                  cross_domain_y2y_flag=False,
                  mode='evaluate'):
    """Model Evaluator (statistics)

    Runs the trained model over every subject listed in ``data_list``, compares each
    prediction with its target volume and writes the per-subject statistics
    (mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w,
    regression_b) to a CSV file. Optionally, the predicted maps are also stored in an
    HDF5 database, one group per subject.

    Args:
        trained_model_path (str): Path to the location of the trained model
        data_directory (str): Path to input data directory
        mapping_data_file (str): Path to the input file
        mapping_targets_file (str): Path to the target file
        data_list (str): Path to a .txt file containing the input files for consideration
        prediction_output_path (str): Output prediction path (created if missing)
        prediction_output_database_name (str): Name of the output HDF5 database (overwritten if it exists)
        prediction_output_statistics_name (str): Name of the output statistics CSV file
        brain_mask_path (str): Path to the MNI brain mask file
        dmri_mean_mask_path (str): Path to the dMRI mean mask (used as the mean mask when cross_domain_x2x_flag is True)
        rsfmri_mean_mask_path (str): Path to the rsfMRI mean mask (its first volume is used otherwise)
        regression_factors (str): Path to the pickled linear regression weights table
        mean_regression_flag (bool): Flag indicating if the volumes should be de-meaned by regression using the mean mask
        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
        mean_subtraction_flag (bool): Flag indicating if the targets should be de-meaned by subtraction using the mean mask
        scale_volumes_flag (bool): Flag indicating if the volumes should be scaled.
        normalize_flag (bool): Flag signaling if the volume should be normalized ([0,1] if True) or scaled to [-1,1] if False.
        minus_one_scaling_flag (bool): Flag signaling if the volume should be scaled to [-1,1] if True
        negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
        outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
        shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
        hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        device (str/int): Device type used for evaluation (int - GPU id, str - CPU/device string)
        exit_on_error (bool): If True, per-subject errors are re-raised instead of only logged
        output_database_flag (bool): Flag indicating if the output maps should be saved to an hdf5 database
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets
        mode (str): Current run mode or phase (not used by this function)

    Raises:
        FileNotFoundError: If data_list cannot be opened, or a subject file cannot be read and exit_on_error is True.
        Exception: Any other per-subject error, re-raised only when exit_on_error is True.
    """

    log.info(
        "Started Evaluation. Check tensorboard for plots (if a LogWriter is provided)")

    # Fails early if the subject list is missing; the list actually used below is the
    # one returned by data_utils.load_file_paths.
    with open(data_list) as data_list_file:
        volumes_to_be_used = data_list_file.read().splitlines()

    # Load the model on the requested GPU, falling back to the CPU when CUDA is missing.
    cuda_available = torch.cuda.is_available()
    if isinstance(device, int) and not cuda_available:
        log.warning(
            "CUDA not available. Switching to CPU. Investigate behaviour!")
        device = 'cpu'

    if isinstance(device, int):
        model = torch.load(trained_model_path)
        torch.cuda.empty_cache()
        model.cuda(device)
    else:
        model = torch.load(trained_model_path,
                           map_location=torch.device(device))

    model.eval()

    create_folder(prediction_output_path)

    log.info("rsfMRI Generation Started")

    # The y2y autoencoder loads the targets as inputs, the x2x autoencoder uses the
    # inputs as targets, and the regular x2y mapping uses both files as given.
    if cross_domain_y2y_flag:
        input_file, target_file = mapping_targets_file, mapping_targets_file
    elif cross_domain_x2x_flag:
        input_file, target_file = mapping_data_file, mapping_data_file
    else:
        input_file, target_file = mapping_data_file, mapping_targets_file

    file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory,
                                                                data_list,
                                                                mapping_data_file=input_file,
                                                                mapping_targets_file=target_file)

    output_database_handle = None
    if output_database_flag:
        output_database_path = os.path.join(prediction_output_path, prediction_output_database_name)
        if os.path.exists(output_database_path):
            os.remove(output_database_path)
        output_database_handle = h5py.File(output_database_path, 'w')

    output_statistics = {}
    output_statistics_path = os.path.join(prediction_output_path, prediction_output_statistics_name)
    statistics_columns = ['mse', 'mae', 'cel', 'pearson_r', 'pearson_p',
                          'spearman_r', 'spearman_p', 'regression_w', 'regression_b']

    # try/finally guarantees the HDF5 file is closed even when exit_on_error re-raises.
    try:
        with torch.no_grad():

            for volume_index, file_path in enumerate(file_paths):
                try:
                    print("Mapping Volume {}/{}".format(volume_index + 1, len(file_paths)))

                    subject = volumes_to_be_used[volume_index]

                    predicted_complete_volume, predicted_volume, header, xform = _generate_volume_map(file_path,
                                                                                                    subject,
                                                                                                    model,
                                                                                                    device,
                                                                                                    cuda_available,
                                                                                                    brain_mask_path,
                                                                                                    dmri_mean_mask_path,
                                                                                                    rsfmri_mean_mask_path,
                                                                                                    regression_factors,
                                                                                                    mean_regression_flag,
                                                                                                    mean_regression_all_flag,
                                                                                                    mean_subtraction_flag,
                                                                                                    scale_volumes_flag,
                                                                                                    normalize_flag,
                                                                                                    minus_one_scaling_flag,
                                                                                                    negative_flag,
                                                                                                    outlier_flag,
                                                                                                    shrinkage_flag,
                                                                                                    hard_shrinkage_flag,
                                                                                                    crop_flag,
                                                                                                    cross_domain_x2x_flag,
                                                                                                    cross_domain_y2y_flag)

                    if crop_flag:
                        # These bounds span 91x109x91 voxels, i.e. they expand the cropped
                        # 72x90x77 prediction back onto the uncropped grid.
                        predicted_volume = roi(Image(predicted_volume, header=header),
                                               ((-9, 82), (-10, 99), (0, 91))).data

                    target_volume = _generate_target_volume(file_path,
                                                            subject,
                                                            dmri_mean_mask_path,
                                                            rsfmri_mean_mask_path,
                                                            regression_factors,
                                                            mean_regression_flag,
                                                            mean_regression_all_flag,
                                                            mean_subtraction_flag,
                                                            crop_flag,
                                                            cross_domain_x2x_flag
                                                            )

                    mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w, regression_b = _statistics_calculator(predicted_volume, target_volume)
                    output_statistics[subject] = [mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w, regression_b]

                    if output_database_handle is not None:
                        group = output_database_handle.create_group(subject)
                        # _generate_volume_map can return None for the complete volume;
                        # h5py cannot create a dataset from None, so skip it in that case.
                        if predicted_complete_volume is not None:
                            group.create_dataset('predicted_complete_volume', data=predicted_complete_volume)
                        group.create_dataset('predicted_volume', data=predicted_volume)
                        group.create_dataset('header', data=header)
                        group.create_dataset('xform', data=xform)

                    log.info("Processed: %s %d out of %d",
                             subject, volume_index + 1, len(volumes_to_be_used))

                    print("Mapped Volumes saved in: ", prediction_output_path)

                except FileNotFoundError as exception_expression:
                    log.error("Error in reading the provided file!")
                    log.exception(exception_expression)
                    if exit_on_error:
                        raise

                except Exception as exception_expression:
                    log.error("Error code execution!")
                    log.exception(exception_expression)
                    if exit_on_error:
                        raise

            output_statistics_df = pd.DataFrame.from_dict(output_statistics, orient='index', columns=statistics_columns)
            output_statistics_df.to_csv(output_statistics_path)

        log.info("Output Data Generation Complete")

    finally:
        if output_database_handle is not None:
            output_database_handle.close()


def evaluate_mapping(trained_model_path,
                     data_directory,
                     mapping_data_file,
                     mapping_targets_file,
                     data_list,
                     prediction_output_path,
                     brain_mask_path,
                     dmri_mean_mask_path,
                     rsfmri_mean_mask_path,
                     regression_factors,
                     mean_regression_flag,
                     mean_regression_all_flag,
                     mean_subtraction_flag,
                     scale_volumes_flag,
                     normalize_flag,
                     minus_one_scaling_flag,
                     negative_flag,
                     outlier_flag,
                     shrinkage_flag,
                     hard_shrinkage_flag,
                     crop_flag,
                     device=0,
                     exit_on_error=False,
                     cross_domain_x2x_flag=False,
                     cross_domain_y2y_flag=False,
                     mode='evaluate'):
    """Model Evaluator (NIfTI maps)

    Runs the trained model over every subject listed in ``data_list`` and saves each
    predicted map as a NIfTI file named after the subject in ``prediction_output_path``
    (``.nii.gz`` is appended when the name has no ``.nii`` in it). When
    ``mean_regression_flag`` is set, the map with the regressed mean added back is
    also saved, with a ``_complete`` suffix.

    Args:
        trained_model_path (str): Path to the location of the trained model
        data_directory (str): Path to input data directory
        mapping_data_file (str): Path to the input file
        mapping_targets_file (str): Path to the target file (used as the input file when cross_domain_y2y_flag is True)
        data_list (str): Path to a .txt file containing the input files for consideration
        prediction_output_path (str): Output prediction path (created if missing)
        brain_mask_path (str): Path to the MNI brain mask file
        dmri_mean_mask_path (str): Path to the dMRI mean mask (used as the mean mask when cross_domain_x2x_flag is True)
        rsfmri_mean_mask_path (str): Path to the rsfMRI mean mask (its first volume is used otherwise)
        regression_factors (str): Path to the pickled linear regression weights table
        mean_regression_flag (bool): Flag indicating if the volumes should be de-meaned by regression using the mean mask
        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
        mean_subtraction_flag (bool): Flag indicating if the targets should be de-meaned by subtraction using the mean mask
        scale_volumes_flag (bool): Flag indicating if the volumes should be scaled.
        normalize_flag (bool): Flag signaling if the volume should be normalized ([0,1] if True) or scaled to [-1,1] if False.
        minus_one_scaling_flag (bool): Flag signaling if the volume should be scaled to [-1,1] if True
        negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
        outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
        shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
        hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        device (str/int): Device type used for evaluation (int - GPU id, str - CPU/device string)
        exit_on_error (bool): If True, per-subject errors are re-raised instead of only logged
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets
        mode (str): Current run mode or phase (not used by this function)

    Raises:
        FileNotFoundError: If data_list cannot be opened, or a subject file cannot be read and exit_on_error is True.
        Exception: Any other per-subject error, re-raised only when exit_on_error is True.
    """

    log.info(
        "Started Evaluation. Check tensorboard for plots (if a LogWriter is provided)")

    # Fails early if the subject list is missing; the list actually used below is the
    # one returned by data_utils.load_file_paths.
    with open(data_list) as data_list_file:
        volumes_to_be_used = data_list_file.read().splitlines()

    # Load the model on the requested GPU, falling back to the CPU when CUDA is missing.
    cuda_available = torch.cuda.is_available()
    if isinstance(device, int) and not cuda_available:
        log.warning(
            "CUDA not available. Switching to CPU. Investigate behaviour!")
        device = 'cpu'

    if isinstance(device, int):
        model = torch.load(trained_model_path)
        torch.cuda.empty_cache()
        model.cuda(device)
    else:
        model = torch.load(trained_model_path,
                           map_location=torch.device(device))

    model.eval()

    create_folder(prediction_output_path)

    log.info("rsfMRI Generation Started")

    # The y2y autoencoder loads the targets as inputs; x2x and x2y load the inputs.
    if cross_domain_y2y_flag:
        file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file=mapping_targets_file)
    else:
        file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file)

    # These bounds span 91x109x91 voxels, i.e. they expand a cropped 72x90x77
    # prediction back onto the uncropped grid.
    uncrop_bounds = ((-9, 82), (-10, 99), (0, 91))

    with torch.no_grad():

        for volume_index, file_path in enumerate(file_paths):
            try:
                print("Mapping Volume {}/{}".format(volume_index + 1, len(file_paths)))

                subject = volumes_to_be_used[volume_index]

                predicted_complete_volume, predicted_volume, header, xform = _generate_volume_map(file_path,
                                                                                                subject,
                                                                                                model,
                                                                                                device,
                                                                                                cuda_available,
                                                                                                brain_mask_path,
                                                                                                dmri_mean_mask_path,
                                                                                                rsfmri_mean_mask_path,
                                                                                                regression_factors,
                                                                                                mean_regression_flag,
                                                                                                mean_regression_all_flag,
                                                                                                mean_subtraction_flag,
                                                                                                scale_volumes_flag,
                                                                                                normalize_flag,
                                                                                                minus_one_scaling_flag,
                                                                                                negative_flag,
                                                                                                outlier_flag,
                                                                                                shrinkage_flag,
                                                                                                hard_shrinkage_flag,
                                                                                                crop_flag,
                                                                                                cross_domain_x2x_flag,
                                                                                                cross_domain_y2y_flag)

                # Cropped outputs are expanded back to the full grid; y2y outputs are
                # written with the header only, all others also with the returned xform.
                if crop_flag:
                    output_nifti_image = roi(Image(predicted_volume, header=header), uncrop_bounds)
                elif cross_domain_y2y_flag:
                    output_nifti_image = Image(predicted_volume, header=header)
                else:
                    output_nifti_image = Image(predicted_volume, header=header, xform=xform)

                output_nifti_path = os.path.join(prediction_output_path, subject)
                if '.nii' not in output_nifti_path:
                    output_nifti_path += '.nii.gz'

                output_nifti_image.save(output_nifti_path)

                if mean_regression_flag:
                    # Same image construction as above, applied to the complete volume.
                    # Every branch must bind output_complete_nifti_image, which is saved below.
                    if crop_flag:
                        output_complete_nifti_image = roi(Image(predicted_complete_volume, header=header), uncrop_bounds)
                    elif cross_domain_y2y_flag:
                        output_complete_nifti_image = Image(predicted_complete_volume, header=header)
                    else:
                        output_complete_nifti_image = Image(predicted_complete_volume, header=header, xform=xform)

                    output_complete_nifti_path = os.path.join(prediction_output_path, subject) + '_complete'
                    if '.nii' not in output_complete_nifti_path:
                        output_complete_nifti_path += '.nii.gz'

                    output_complete_nifti_image.save(output_complete_nifti_path)

                log.info("Processed: %s %d out of %d",
                         subject, volume_index + 1, len(volumes_to_be_used))

                print("Mapped Volumes saved in: ", prediction_output_path)

            except FileNotFoundError as exception_expression:
                log.error("Error in reading the provided file!")
                log.exception(exception_expression)
                if exit_on_error:
                    raise

            except Exception as exception_expression:
                log.error("Error code execution!")
                log.exception(exception_expression)
                if exit_on_error:
                    raise

    log.info("rsfMRI Generation Complete")


def _generate_volume_map(file_path,
                         subject,
                         model,
                         device,
                         cuda_available,
                         brain_mask_path,
                         dmri_mean_mask_path,
                         rsfmri_mean_mask_path,
                         regression_factors,
                         mean_regression_flag,
                         mean_regression_all_flag,
                         mean_subtraction_flag,
                         scale_volumes_flag,
                         normalize_flag,
                         minus_one_scaling_flag,
                         negative_flag,
                         outlier_flag,
                         shrinkage_flag,
                         hard_shrinkage_flag,
                         crop_flag,
                         cross_domain_x2x_flag,
                         cross_domain_y2y_flag
                         ):
    """Output Volume Generator

    Uses the trained model to generate a new volume for one subject. Steps:
    load and preprocess the input, optionally regress and scale it, run the model,
    rescale the output, add back the mean (regression or subtraction), and apply the
    brain mask.

    Args:
        file_path (str): Path to the desired file
        subject (str): Subject ID of the subject volume to be regressed
        model (class): BrainMapper model class
        device (str/int): Device type used for evaluation (int - GPU id, str - CPU/device string)
        cuda_available (bool): Flag indicating if a cuda-enabled GPU is present
        brain_mask_path (str): Path to the MNI brain mask file
        dmri_mean_mask_path (str): Path to the dMRI mean mask (used as the mean mask when cross_domain_x2x_flag is True)
        rsfmri_mean_mask_path (str): Path to the rsfMRI mean mask (its first volume is used otherwise)
        regression_factors (str): Path to the pickled regression weights table, indexed by subject with 'w_dMRI' / 'w_rsfMRI' columns
        mean_regression_flag (bool): Flag indicating if the volumes should be de-meaned by regression using the mean mask
        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
        mean_subtraction_flag (bool): Flag indicating if the targets should be de-meaned by subtraction using the mean mask
        scale_volumes_flag (bool): Flag indicating if the input volume should be scaled.
        normalize_flag (bool): Flag signaling if the volume should be normalized ([0,1] if True) or scaled to [-1,1] if False.
        minus_one_scaling_flag (bool): Flag signaling if the volume should be scaled to [-1,1] if True
        negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
        outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
        shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
        hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
        crop_flag (bool): Flag indicating if the volumes are cropped from 91x109x91 to 72x90x77
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets

    Returns:
        predicted_complete_volume (np.array): Model output with the mean added back, masked by the brain mask
        predicted_volume (np.array): Model output masked by the brain mask
        header: Header returned by data_utils.load_and_preprocess_evaluation (described elsewhere as a 'nibabel.nifti1.Nifti1Header')
        xform: Transform returned by data_utils.load_and_preprocess_evaluation

    Raises:
        ValueError: If neither mean_regression_flag nor mean_subtraction_flag is set, because no scaling parameters are defined for that case.
    """

    volume, header, xform = data_utils.load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag)

    # scaling_parameters = [min_a, max_a, min_b, max_b]: _scale_input uses the first
    # pair for regular inputs and the second pair when cross_domain_y2y_flag is set.
    if mean_regression_flag:
        if mean_regression_all_flag:
            volume = _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_y2y_flag)
            scaling_parameters = [-0.0626, 0.1146, -14.18, 16.9475]
        else:
            scaling_parameters = [0.0, 0.2, -14.18, 16.9475]
    elif mean_subtraction_flag:
        scaling_parameters = [0.0, 0.2, 0.0, 10.0]
    else:
        # Previously this fell through with scaling_parameters unbound and crashed
        # with an UnboundLocalError in _rescale_output (which always runs).
        raise ValueError(
            "No scaling parameters are defined: either mean_regression_flag or mean_subtraction_flag must be True.")

    log.debug('volume range: %s %s', np.min(volume), np.max(volume))

    if scale_volumes_flag:
        volume = _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_y2y_flag)

    # The model takes a 5D tensor; a single volume gets two leading singleton axes
    # (presumably batch and channel).
    if len(volume.shape) != 5:
        volume = volume[np.newaxis, np.newaxis, :, :, :]

    volume = torch.tensor(volume).type(torch.FloatTensor)

    if cuda_available and isinstance(device, int):
        volume = volume.cuda(device)

    output = model(volume)
    output = np.squeeze(output.cpu().numpy().astype('float32'))

    log.debug('output range: %s %s', np.min(output), np.max(output))

    # NOTE(review): the output is rescaled even when scale_volumes_flag is False;
    # confirm the model was trained on scaled targets in that configuration.
    output = _rescale_output(output, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_x2x_flag)

    log.debug('output rescaled: %s %s', np.min(output), np.max(output))

    # These bounds keep 72x90x77 voxels, matching the crop applied by crop_flag.
    crop_bounds = ((9, 81), (10, 100), (0, 77))

    if crop_flag:
        brain_mask = roi(Image(brain_mask_path), crop_bounds).data
    else:
        brain_mask = Image(brain_mask_path).data

    # x2x adds back the dMRI mean; all other modes use the first volume of the rsfMRI mean mask.
    mean_mask_image = Image(dmri_mean_mask_path) if cross_domain_x2x_flag else Image(rsfmri_mean_mask_path)
    if crop_flag:
        mean_mask_image = roi(mean_mask_image, crop_bounds)
    mean_mask = mean_mask_image.data
    if not cross_domain_x2x_flag:
        mean_mask = mean_mask[:, :, :, 0]

    if mean_regression_flag:
        # Note: pd.read_pickle executes pickle data; only use trusted regression files.
        weight_column = 'w_dMRI' if cross_domain_x2x_flag else 'w_rsfMRI'
        weight = pd.read_pickle(regression_factors).loc[subject][weight_column]
        predicted_complete_volume = np.add(output, np.multiply(weight, mean_mask))

    # When both flags are set, mean subtraction takes precedence (as before).
    if mean_subtraction_flag:
        predicted_complete_volume = np.add(output, mean_mask)

    log.debug('predicted_complete_volume %s %s',
              np.min(predicted_complete_volume), np.max(predicted_complete_volume))

    predicted_complete_volume = np.multiply(predicted_complete_volume, brain_mask)

    log.debug('predicted_complete_volume masked: %s %s',
              np.min(predicted_complete_volume), np.max(predicted_complete_volume))

    predicted_volume = np.multiply(output, brain_mask)

    log.debug('predicted_volume masked: %s %s',
              np.min(predicted_volume), np.max(predicted_volume))

    return predicted_complete_volume, predicted_volume, header, xform


def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_y2y_flag):
    """Input Scaling

    This function scales an input volume using pre-computed scaling parameters.

    The optional pre-processing steps (shrinkage, zeroing of negative values and
    outlier clipping) are applied, in that order, before the final min-max scaling.
    The caller's array is never modified: all in-place steps operate on a copy.

    Args:
        volume (np.array): Numpy array representing the un-scaled volume.
        scaling_parameters (list): Four scaling parameters. The first two are used as (min, max) by default,
            the last two are used as (min, max) when cross_domain_y2y_flag is True.
        normalize_flag (bool): Flag signaling if the volume should be normalized to [0, 1].
        minus_one_scaling_flag (bool): Flag signaling if the volume should be scaled to [-1, 1]. Only used if normalize_flag is not True.
        negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
        outlier_flag (bool): Flag indicating if outliers should be clipped to the min/max values.
        shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
        hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets

    Returns:
        scaled_volume (np.array): Scaled volume. If neither normalize_flag nor minus_one_scaling_flag is True,
            the volume is returned without min-max scaling (the other pre-processing steps still apply).
    """

    if cross_domain_y2y_flag == True:
        _, _, min_value, max_value = scaling_parameters
    else:
        min_value, max_value, _, _ = scaling_parameters

    # The steps below assign into the array in place; work on a copy so the caller's data is untouched.
    volume = np.array(volume, copy=True)

    if shrinkage_flag == True:
        if cross_domain_y2y_flag == True:
            lambd = 3.0
        else:
            lambd = 0.003  # Hard coded, equivalent to the 1p and 99p values across the whole population in UKBB

        if hard_shrinkage_flag == True:
            volume = _hard_shrinkage(volume, lambd)
        elif hard_shrinkage_flag == False:
            volume = _soft_shrinkage(volume, lambd)
            # Soft shrinkage moves every surviving value towards 0 by lambd, so the expected range narrows too.
            min_value += lambd
            max_value -= lambd

    if negative_flag == True:
        volume[volume < 0.0] = 0.0
        min_value = 0.0

    if outlier_flag == True:
        volume[volume > max_value] = max_value
        volume[volume < min_value] = min_value

    value_range = np.subtract(max_value, min_value)

    if normalize_flag == True:
        # Normalization to [0, 1]
        scaled_volume = np.divide(np.subtract(volume, min_value), value_range)
    elif minus_one_scaling_flag == True:
        # Scaling between [-1, 1]
        scaled_volume = np.add(-1.0, np.multiply(2.0, np.divide(np.subtract(volume, min_value), value_range)))
    else:
        # No min-max scaling requested; previously this path raised UnboundLocalError.
        scaled_volume = volume

    return scaled_volume


622
def _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_y2y_flag):
    """ Input Regression

    This function regresses the group mean out of the input volume using the saved regression weights:
    regressed_volume = volume - w * group_mean

    When cross_domain_y2y_flag is True the rsfMRI weight ('w_rsfMRI') and the first volume of the rsfMRI
    group mean are used; otherwise the dMRI weight ('w_dMRI') and the dMRI group mean are used.

    TODO: This function represents only a temporary solution. For deployment, a NN needs to be trained which predicts the relevant scaling factors.

    Args:
        volume (np.array): Unregressed volume
        subject (str): Subject ID of the subject volume to be regressed
        dmri_mean_mask_path (str): Path to the dMRI group mean volume
        rsfmri_mean_mask_path (str): Path to the rsfMRI group mean volume
        regression_factors (str): Path to the pickled linear regression weights file
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets

    Returns:
        regressed_volume (np.array): Linear regressed volume
    """

    if cross_domain_y2y_flag == True:
        weight_column = 'w_rsfMRI'
        mean_path = rsfmri_mean_mask_path
    else:
        weight_column = 'w_dMRI'
        mean_path = dmri_mean_mask_path

    # NOTE: pd.read_pickle unpickles arbitrary objects; only point it at trusted weight files.
    weight = pd.read_pickle(regression_factors).loc[subject][weight_column]

    # Exhaustive branch: previously a non-bool crop_flag left group_mean unbound.
    if crop_flag == True:
        group_mean = roi(Image(mean_path), ((9, 81), (10, 100), (0, 77))).data
    else:
        group_mean = Image(mean_path).data

    if cross_domain_y2y_flag == True:
        # The rsfMRI group mean is indexed as a 4D image; only its first volume is used.
        group_mean = group_mean[:, :, :, 0]

    regressed_volume = np.subtract(volume, np.multiply(weight, group_mean))

    return regressed_volume

Andrei Roibu's avatar
Andrei Roibu committed
660

661
def _rescale_output(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_x2x_flag):
    """Output Rescaling

    This function inverts the min-max scaling of a volume, mapping it back to the original value range
    described by the supplied scaling parameters.

    Args:
        volume (np.array): Scaled volume
        scaling_parameters (list): Four scaling parameters. The last two are used as (min, max) by default,
            the first two are used as (min, max) when cross_domain_x2x_flag is True.
        normalize_flag (bool): Flag signaling if the volume was normalized to [0, 1].
        minus_one_scaling_flag (bool): Flag signaling if the volume was scaled to [-1, 1]. Only used if normalize_flag is not True.
        negative_flag (bool): Flag indicating if the negative values were 0-ed (the minimum is then taken as 0).
        shrinkage_flag (bool): Flag indicating if shrinkage was applied.
        hard_shrinkage_flag (bool): Flag indicating if hard shrinkage was applied. If False, soft shrinkage was applied.
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs

    Returns:
        rescaled_volume (np.array): Rescaled volume. If neither normalize_flag nor minus_one_scaling_flag is True,
            the volume is returned unchanged.
    """

    if cross_domain_x2x_flag == True:
        min_value, max_value, _, _ = scaling_parameters
    else:
        _, _, min_value, max_value = scaling_parameters

    if shrinkage_flag == True:
        if cross_domain_x2x_flag == True:
            lambd = 0.003
        else:
            lambd = 3.0

        # Only soft shrinkage narrows the value range; hard shrinkage leaves min/max unchanged.
        if hard_shrinkage_flag == False:
            min_value += lambd
            max_value -= lambd

    if negative_flag == True:
        min_value = 0.0

    value_range = np.subtract(max_value, min_value)

    if normalize_flag == True:
        # Inverse of the [0, 1] normalization
        rescaled_volume = np.add(np.multiply(volume, value_range), min_value)
    elif minus_one_scaling_flag == True:
        # Inverse of the [-1, 1] scaling
        rescaled_volume = np.add(np.multiply(np.divide(np.add(volume, 1), 2), value_range), min_value)
    else:
        # No rescaling requested; previously this path raised UnboundLocalError.
        rescaled_volume = volume

    return rescaled_volume


711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
def _hard_shrinkage(volume, lambd):
    """ Hard Shrinkage

    This function performs a hard shrinkage on the volume, in place.
    volume = { x , x > lambd | x < -lambd
                0 , x e [-lambd, lambd]
                }

    The dead zone is the closed interval [-lambd, lambd], as in the formula above and in _soft_shrinkage
    (previously the boundary values +/-lambd were kept, contradicting the formula).

    Args:
        volume (np.array): Unshrunken volume; it is modified in place.
        lambd (float): Threshold parameter

    Returns:
        volume (np.array) : Hard shrunk volume (the same array object that was passed in)
    """

    volume[np.logical_and(volume >= -lambd, volume <= lambd)] = 0

    return volume

731

732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
def _soft_shrinkage(volume, lambd):
    """ Soft Shrinkage

    Apply soft thresholding to a volume, in place:
    volume = { x + lambd , x < -lambd
                0         , x e [-lambd, lambd]
                x - lambd , x > lambd
                }

    Args:
        volume (np.array): Unshrunken volume; it is modified in place.
        lambd (float): Threshold parameter

    Returns:
        volume (np.array) : Soft shrunk volume (the same array object that was passed in)
    """

    # Step 1: collapse the closed dead zone [-lambd, lambd] onto zero.
    dead_zone = (volume >= -lambd) & (volume <= lambd)
    volume[dead_zone] = 0.0

    # Step 2: pull the negative tail up towards zero.
    # Each mask is evaluated on the array as left by the previous step.
    lower_tail = volume < -lambd
    volume[lower_tail] = volume[lower_tail] + lambd

    # Step 3: pull the positive tail down towards zero.
    upper_tail = volume > lambd
    volume[upper_tail] = volume[upper_tail] - lambd

    return volume


756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
def _generate_target_volume(file_path,
                            subject,
                            dmri_mean_mask_path,
                            rsfmri_mean_mask_path,
                            regression_factors,
                            mean_regression_flag,
                            mean_regression_all_flag,
                            mean_subtraction_flag,
                            crop_flag,
                            cross_domain_x2x_flag
                            ):
    """Target Volume Generator

    Load a target volume and, optionally, remove the group mean from it, so that it can be compared
    with the network-predicted volumes. Regression takes precedence over subtraction when both flags are set.

    Args:
        file_path (str): Path to the desired file
        subject (str): Subject ID of the subject volume to be regressed
        dmri_mean_mask_path (str): Path to the group mean volume
        rsfmri_mean_mask_path (str): Path to the dualreg subject mean mask
        regression_factors (str): Path to the linear regression weights file
        mean_regression_flag (bool): Flag indicating if the volumes should be de-meaned by regression using the mean_mask_path
        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
        mean_subtraction_flag (bool): Flag indicating if the targets should be de-meaned by subtraction using the mean_mask_path
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs

    Returns:
        volume (np.array): Array containing the information regarding the target volume
    """

    # NOTE(review): the target is always loaded with crop_flag=False, while crop_flag is forwarded to
    # _regress_target/_subtract_target, which crop the group mean when it is True. Unless
    # data_utils.load_and_preprocess_targets crops on its own, the shapes may not match - confirm.
    target = data_utils.load_and_preprocess_targets(file_path, cross_domain_x2x_flag, crop_flag=False)

    if mean_regression_flag == True:
        return _regress_target(target, subject, dmri_mean_mask_path, rsfmri_mean_mask_path,
                               regression_factors, crop_flag, cross_domain_x2x_flag, mean_regression_all_flag)

    if mean_subtraction_flag == True:
        return _subtract_target(target, subject, dmri_mean_mask_path, rsfmri_mean_mask_path,
                                crop_flag, cross_domain_x2x_flag)

    return target


def _regress_target(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_x2x_flag, mean_regression_all_flag):
    """ Target Regression

    This function regresses the group mean out of the target volume using the saved regression weights:
    regressed_volume = volume - w * group_mean

    When cross_domain_x2x_flag is True, the dMRI weight ('w_dMRI') and dMRI group mean are used, and only if
    mean_regression_all_flag is True; otherwise the volume is returned unchanged. When cross_domain_x2x_flag
    is not True, the rsfMRI weight ('w_rsfMRI') and the first volume of the rsfMRI group mean are used.

    TODO: This function represents only a temporary solution. For deployment, a NN needs to be trained which predicts the relevant scaling factors.

    Args:
        volume (np.array): Unregressed volume
        subject (str): Subject ID of the subject volume to be regressed
        dmri_mean_mask_path (str): Path to the dMRI group mean volume
        rsfmri_mean_mask_path (str): Path to the rsfMRI group mean volume
        regression_factors (str): Path to the pickled linear regression weights file
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.

    Returns:
        regressed_volume (np.array): Linear regressed volume (the input object itself when no regression is applied)
    """

    if cross_domain_x2x_flag == True:
        if mean_regression_all_flag != True:
            # Nothing to regress in this configuration.
            return volume
        weight_column = 'w_dMRI'
        mean_path = dmri_mean_mask_path
    else:
        weight_column = 'w_rsfMRI'
        mean_path = rsfmri_mean_mask_path

    # NOTE: pd.read_pickle unpickles arbitrary objects; only point it at trusted weight files.
    weight = pd.read_pickle(regression_factors).loc[subject][weight_column]

    # Exhaustive branch: previously a non-bool crop_flag left group_mean unbound.
    if crop_flag == True:
        group_mean = roi(Image(mean_path), ((9, 81), (10, 100), (0, 77))).data
    else:
        group_mean = Image(mean_path).data

    if cross_domain_x2x_flag != True:
        # The rsfMRI group mean is indexed as a 4D image; only its first volume is used.
        group_mean = group_mean[:, :, :, 0]

    regressed_volume = np.subtract(volume, np.multiply(weight, group_mean))

    return regressed_volume


def _subtract_target(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, crop_flag, cross_domain_x2x_flag):
    """ Target Subtraction

    This function subtracts the rsfMRI group mean (its first volume) from the target volume.
    No regression weights are involved. When cross_domain_x2x_flag is True the volume is returned unchanged.

    TODO: This function represents only a temporary solution. For deployment, a NN needs to be trained which predicts the relevant scaling factors.

    Args:
        volume (np.array): Volume before subtraction
        subject (str): Subject ID of the subject volume (unused; kept for interface compatibility)
        dmri_mean_mask_path (str): Path to the dMRI group mean volume (unused; kept for interface compatibility)
        rsfmri_mean_mask_path (str): Path to the rsfMRI group mean volume
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs

    Returns:
        subtracted_volume (np.array): Mean-subtracted volume (the input object itself when nothing is subtracted)
    """

    if cross_domain_x2x_flag == True:
        return volume

    # Exhaustive branch: previously a non-bool crop_flag left group_mean unbound.
    if crop_flag == True:
        group_mean = roi(Image(rsfmri_mean_mask_path), ((9, 81), (10, 100), (0, 77))).data[:, :, :, 0]
    else:
        group_mean = Image(rsfmri_mean_mask_path).data[:, :, :, 0]

    subtracted_volume = np.subtract(volume, group_mean)

    return subtracted_volume


def _statistics_calculator(volume, target):
    """ Training statistics calculator

    This function calculates the MSE, MAE, cosine distance, Pearson and Spearman correlations and the
    least-squares regression line (target ~ w * prediction + b) for a predicted volume and its ground truth.
    Both arrays are flattened before comparison.

    Args:
        volume (np.array): Predicted volume
        target (np.array): Ground truth volume

    Returns:
        mse (np.float64): The mean squared error between the prediction and the ground truth; The closer to 0, the better
        mae (np.float64): The mean absolute error between the prediction and the ground truth; The closer to 0, the better
        cel (np.float64): The cosine distance between the prediction and the ground truth; The closer to 0, the better
        pearson_r (np.float64): Pearson's correlation coefficient; The closer to 1, the better
        pearson_p (np.float64): Two-tailed p-value for Pearson's correlation coefficient; the closer to 0, the better
        spearman_r (np.float64): Spearman correlation coefficient; The closer to 1, the better
        spearman_p (np.float64): Two-tailed p-value for Spearman's correlation coefficient; the closer to 0, the better
        regression_w (np.float64): Slope of the linear regression line; The closer to 1 the better
        regression_b (np.float64): Intercept of the linear regression line; The closer to 0, the better
    """

    x = np.reshape(volume, -1)
    y = np.reshape(target, -1)

    difference = np.subtract(x, y)
    mse = np.square(difference).mean()
    mae = np.abs(difference).mean()
    # cosine() already returns a single scalar distance; averaging it was a no-op.
    cel = np.float64(cosine(x, y))
    pearson_r, pearson_p = pearsonr(x, y)
    spearman_r, spearman_p = spearmanr(x, y)

    # Least-squares fit of y = regression_w * x + regression_b. lstsq avoids explicitly inverting the
    # normal-equation matrix, which is numerically unstable for nearly constant x and raises for constant x.
    design_matrix = np.vstack((np.ones(len(x)), x)).T
    coefficients = np.linalg.lstsq(design_matrix, y, rcond=None)[0]
    regression_b, regression_w = coefficients

    return mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w, regression_b


911
912
913
914
915
916
917
918
919
920
921
922
923
def _pearson_correlation(volume, target):
    """Calculate Pearson Correlation Coefficient

    Compute the Pearson correlation coefficient between a predicted volume and the target volume:
    r = sum((v - mean(v)) * (t - mean(t))) / sqrt(sum((v - mean(v))^2) * sum((t - mean(t))^2))

    Args:
        volume (np.array): The predicted volume
        target (np.array): The target volume

    Returns:
        r (np.floating): The Pearson Correlation Coefficient
    """

    volume_centered = np.subtract(volume, volume.mean())
    target_centered = np.subtract(target, target.mean())

    cross_sum = np.sum(np.multiply(volume_centered, target_centered))
    volume_energy = np.sum(np.power(volume_centered, 2))
    target_energy = np.sum(np.power(target_centered, 2))

    r = cross_sum / np.sqrt(np.multiply(volume_energy, target_energy))
    return r
928