data_evaluation_utils.py 45.1 KB
Newer Older
1
2
3
4
"""Data Evaluation Functions

Description:

5
    This folder contains several functions which, either on their own or included in larger pieces of software, perform data evaluation tasks.
6

7
8
9
10
11
Usage:

    To use content from this folder, import the functions and instantiate them as you wish to use them:

        from utils.data_evaluation_utils import function_name
12

13
14
TODO: Might be worth adding some information on uncertainty estimation later down the line.

15
16
17
"""

import os
18
import pickle
19
20
import numpy as np
import torch
21
import logging
22
import h5py
23
import utils.data_utils as data_utils
24
from utils.common_utils import create_folder
Andrei Roibu's avatar
Andrei Roibu committed
25
import pandas as pd
26
from fsl.data.image import Image
27
from fsl.utils.image.roi import roi
Andrei Roibu's avatar
Andrei Roibu committed
28
import itertools
29
30
from scipy.spatial.distance import cosine
from scipy.stats import pearsonr, spearmanr
31
32

log = logging.getLogger(__name__)
33

34
35
36
37
38
39
40
def evaluate_data(trained_model_path,
                  data_directory,
                  mapping_data_file,
                  mapping_targets_file,
                  data_list,
                  prediction_output_path,
                  prediction_output_database_name,
                  prediction_output_statistics_name,
                  brain_mask_path,
                  dmri_mean_mask_path,
                  rsfmri_mean_mask_path,
                  regression_factors,
                  mean_regression_flag,
                  mean_regression_all_flag,
                  mean_subtraction_flag,
                  scale_volumes_flag,
                  normalize_flag,
                  minus_one_scaling_flag,
                  negative_flag,
                  outlier_flag,
                  shrinkage_flag,
                  hard_shrinkage_flag,
                  crop_flag,
                  device=0,
                  exit_on_error=False,
                  output_database_flag=False,
                  cross_domain_x2x_flag=False,
                  cross_domain_y2y_flag=False,
                  mode='evaluate'):
    """Model Evaluator

    Runs the trained model over every volume in the data list, compares each
    prediction against its target volume and pickles the per-subject statistics
    as a pandas DataFrame. Optionally, the predicted arrays are also stored in
    an HDF5 database (one group per subject).

    Args:
        trained_model_path (str): Path to the location of the trained model
        data_directory (str): Path to input data directory
        mapping_data_file (str): Path to the input file
        mapping_targets_file (str): Path to the target file
        data_list (str): Path to a .txt file containing the input files for consideration
        prediction_output_path (str): Output prediction path (created if missing)
        prediction_output_database_name (str): Name of the output database, relative to prediction_output_path
        prediction_output_statistics_name (str): Name of the output statistics file, relative to prediction_output_path
        brain_mask_path (str): Path to the MNI brain mask file
        dmri_mean_mask_path (str): Path to the dMRI mean mask (forwarded to the volume generators)
        rsfmri_mean_mask_path (str): Path to the rsfMRI mean mask (forwarded to the volume generators)
        regression_factors (str): Path to the linear regression weights file
        mean_regression_flag (bool): Flag indicating if the volumes should be de-meaned by regression using the mean_mask_path
        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
        mean_subtraction_flag (bool): Flag indicating if the targets should be de-meaned by subtraction using the mean_mask_path
        scale_volumes_flag (bool): Flag indicating if the volumes should be scaled.
        normalize_flag (bool): Flag signaling if the volume should be normalized ([0,1] if True) or scaled to [-1,1] if False.
        minus_one_scaling_flag (bool): Flag signaling if the volume should be scaled to [-1,1] if True
        negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
        outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
        shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
        hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        device (str/int): Device type used for evaluation (int - GPU id, str - CPU)
        exit_on_error (bool): Flag that triggers the re-raising of per-volume exceptions
        output_database_flag (bool): Flag indicating if the output maps should be saved to hdf5 database
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets
        mode (str): Current run mode or phase (not used inside this function)

    Raises:
        FileNotFoundError: Error in reading the provided file! (re-raised only if exit_on_error is set)
        Exception: Error code execution! (re-raised only if exit_on_error is set)
    """

    log.info(
        "Started Evaluation. Check tensorboard for plots (if a LogWriter is provided)")

    # This list is replaced by the one returned from load_file_paths below; the
    # read is kept so that a missing data list fails before the model is loaded.
    with open(data_list) as data_list_file:
        volumes_to_be_used = data_list_file.read().splitlines()

    # Test if cuda is available and attempt to run on GPU

    cuda_available = torch.cuda.is_available()
    if isinstance(device, int):
        if cuda_available:
            model = torch.load(trained_model_path)
            torch.cuda.empty_cache()
            model.cuda(device)
        else:
            log.warning(
                "CUDA not available. Switching to CPU. Investigate behaviour!")
            device = 'cpu'

    if isinstance(device, str) or not cuda_available:
        model = torch.load(trained_model_path,
                           map_location=torch.device(device))

    model.eval()

    # Create the prediction path folder if this is not available

    create_folder(prediction_output_path)

    # Initiate the evaluation

    log.info("rsfMRI Generation Started")
    if cross_domain_y2y_flag:
        # If doing y2y autoencoder, then we load the targets as inputs. In all other cases (x2x & x2y) we load the inputs as inputs.
        file_paths, volumes_to_be_used = data_utils.load_file_paths(
            data_directory, data_list, mapping_data_file=mapping_targets_file, mapping_targets_file=mapping_targets_file)
    elif cross_domain_x2x_flag:
        # For x2x the inputs double as the targets.
        file_paths, volumes_to_be_used = data_utils.load_file_paths(
            data_directory, data_list, mapping_data_file, mapping_targets_file=mapping_data_file)
    else:
        file_paths, volumes_to_be_used = data_utils.load_file_paths(
            data_directory, data_list, mapping_data_file, mapping_targets_file)

    output_database_handle = None
    if output_database_flag:
        output_database_path = os.path.join(prediction_output_path, prediction_output_database_name)
        if os.path.exists(output_database_path):
            os.remove(output_database_path)
        output_database_handle = h5py.File(output_database_path, 'w')

    statistics_columns = ['mse', 'mae', 'cel', 'pearson_r', 'pearson_p',
                          'spearman_r', 'spearman_p', 'regression_w', 'regression_b']
    output_statistics = {}
    output_statistics_path = os.path.join(prediction_output_path, prediction_output_statistics_name)

    # The try/finally guarantees the HDF5 file is closed even when
    # exit_on_error re-raises out of the loop.
    try:
        with torch.no_grad():

            for volume_index, file_path in enumerate(file_paths):
                try:
                    print("Mapping Volume {}/{}".format(volume_index+1, len(file_paths)))

                    subject = volumes_to_be_used[volume_index]

                    # Generate volume & header
                    predicted_complete_volume, predicted_volume, header, xform = _generate_volume_map(
                        file_path,
                        subject,
                        model,
                        device,
                        cuda_available,
                        brain_mask_path,
                        dmri_mean_mask_path,
                        rsfmri_mean_mask_path,
                        regression_factors,
                        mean_regression_flag,
                        mean_regression_all_flag,
                        mean_subtraction_flag,
                        scale_volumes_flag,
                        normalize_flag,
                        minus_one_scaling_flag,
                        negative_flag,
                        outlier_flag,
                        shrinkage_flag,
                        hard_shrinkage_flag,
                        crop_flag,
                        cross_domain_x2x_flag,
                        cross_domain_y2y_flag)

                    target_volume = _generate_target_volume(
                        file_path,
                        subject,
                        dmri_mean_mask_path,
                        rsfmri_mean_mask_path,
                        regression_factors,
                        mean_regression_flag,
                        mean_regression_all_flag,
                        mean_subtraction_flag,
                        crop_flag,
                        cross_domain_x2x_flag)

                    (mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p,
                     regression_w, regression_b) = _statistics_calculator(predicted_volume, target_volume)
                    output_statistics[subject] = [mse, mae, cel, pearson_r, pearson_p,
                                                  spearman_r, spearman_p, regression_w, regression_b]

                    if output_database_handle is not None:
                        group = output_database_handle.create_group(subject)
                        group.create_dataset('predicted_complete_volume', data=predicted_complete_volume)
                        group.create_dataset('predicted_volume', data=predicted_volume)
                        group.create_dataset('header', data=header)
                        group.create_dataset('xform', data=xform)

                    log.info("Processed: %s %s out of %s",
                             subject, volume_index + 1, len(volumes_to_be_used))

                    print("Mapped Volumes saved in: ", prediction_output_path)

                except FileNotFoundError as exception_expression:
                    log.error("Error in reading the provided file!")
                    log.exception(exception_expression)
                    if exit_on_error:
                        raise

                except Exception as exception_expression:
                    log.error("Error code execution!")
                    log.exception(exception_expression)
                    if exit_on_error:
                        raise

        output_statistics_df = pd.DataFrame.from_dict(
            output_statistics, orient='index', columns=statistics_columns)
        # Write to the statistics file path (previously the statistics dict
        # itself was passed as the path, so nothing could be saved).
        output_statistics_df.to_pickle(output_statistics_path)

        log.info("Output Data Generation Complete")

    finally:
        if output_database_handle is not None:
            output_database_handle.close()
231

232

233
def evaluate_mapping(trained_model_path,
                     data_directory,
                     mapping_data_file,
                     mapping_targets_file,
                     data_list,
                     prediction_output_path,
                     brain_mask_path,
                     dmri_mean_mask_path,
                     rsfmri_mean_mask_path,
                     regression_factors,
                     mean_regression_flag,
                     mean_regression_all_flag,
                     mean_subtraction_flag,
                     scale_volumes_flag,
                     normalize_flag,
                     minus_one_scaling_flag,
                     negative_flag,
                     outlier_flag,
                     shrinkage_flag,
                     hard_shrinkage_flag,
                     crop_flag,
                     device=0,
                     exit_on_error=False,
                     cross_domain_x2x_flag=False,
                     cross_domain_y2y_flag=False,
                     mode='evaluate'):
    """Model Evaluator

    Runs the trained model over every volume in the data list and saves each
    prediction as a NIfTI image in prediction_output_path. When
    mean_regression_flag is set, the "complete" prediction (with the mean
    component added back) is saved as well, with a '_complete' suffix.

    Args:
        trained_model_path (str): Path to the location of the trained model
        data_directory (str): Path to input data directory
        mapping_data_file (str): Path to the input file
        mapping_targets_file (str): Path to the target file
        data_list (str): Path to a .txt file containing the input files for consideration
        prediction_output_path (str): Output prediction path (created if missing)
        brain_mask_path (str): Path to the MNI brain mask file
        dmri_mean_mask_path (str): Path to the dMRI mean mask (forwarded to _generate_volume_map)
        rsfmri_mean_mask_path (str): Path to the rsfMRI mean mask (forwarded to _generate_volume_map)
        regression_factors (str): Path to the linear regression weights file
        mean_regression_flag (bool): Flag indicating if the volumes should be de-meaned by regression using the mean_mask_path
        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
        mean_subtraction_flag (bool): Flag indicating if the targets should be de-meaned by subtraction using the mean_mask_path
        scale_volumes_flag (bool): Flag indicating if the volumes should be scaled.
        normalize_flag (bool): Flag signaling if the volume should be normalized ([0,1] if True) or scaled to [-1,1] if False.
        minus_one_scaling_flag (bool): Flag signaling if the volume should be scaled to [-1,1] if True
        negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
        outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
        shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
        hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        device (str/int): Device type used for evaluation (int - GPU id, str - CPU)
        exit_on_error (bool): Flag that triggers the re-raising of per-volume exceptions
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets
        mode (str): Current run mode or phase (not used inside this function)

    Raises:
        FileNotFoundError: Error in reading the provided file! (re-raised only if exit_on_error is set)
        Exception: Error code execution! (re-raised only if exit_on_error is set)
    """

    def _as_nifti_image(volume_data, header, xform):
        # Wrap a predicted array as an fsl Image, using the same rules for the
        # plain and the "complete" prediction.
        if crop_flag:
            cropped_image = Image(volume_data, header=header)
            # NOTE(review): the out-of-range bounds presumably pad the cropped
            # volume back onto the full 91x109x91 grid - confirm against the
            # fslpy roi documentation.
            return roi(cropped_image, ((-9, 82), (-10, 99), (0, 91)))
        if cross_domain_y2y_flag:
            return Image(volume_data, header=header)
        return Image(volume_data, header=header, xform=xform)

    log.info(
        "Started Evaluation. Check tensorboard for plots (if a LogWriter is provided)")

    # This list is replaced by the one returned from load_file_paths below; the
    # read is kept so that a missing data list fails before the model is loaded.
    with open(data_list) as data_list_file:
        volumes_to_be_used = data_list_file.read().splitlines()

    # Test if cuda is available and attempt to run on GPU

    cuda_available = torch.cuda.is_available()
    if isinstance(device, int):
        if cuda_available:
            model = torch.load(trained_model_path)
            torch.cuda.empty_cache()
            model.cuda(device)
        else:
            log.warning(
                "CUDA not available. Switching to CPU. Investigate behaviour!")
            device = 'cpu'

    if isinstance(device, str) or not cuda_available:
        model = torch.load(trained_model_path,
                           map_location=torch.device(device))

    model.eval()

    # Create the prediction path folder if this is not available

    create_folder(prediction_output_path)

    # Initiate the evaluation

    log.info("rsfMRI Generation Started")
    if cross_domain_y2y_flag:
        # If doing y2y autoencoder, then we load the targets as inputs. In all other cases (x2x & x2y) we load the inputs as inputs.
        file_paths, volumes_to_be_used = data_utils.load_file_paths(
            data_directory, data_list, mapping_data_file=mapping_targets_file)
    else:
        file_paths, volumes_to_be_used = data_utils.load_file_paths(
            data_directory, data_list, mapping_data_file)

    with torch.no_grad():

        for volume_index, file_path in enumerate(file_paths):
            try:
                print("Mapping Volume {}/{}".format(volume_index+1, len(file_paths)))

                subject = volumes_to_be_used[volume_index]

                # Generate volume & header
                predicted_complete_volume, predicted_volume, header, xform = _generate_volume_map(
                    file_path,
                    subject,
                    model,
                    device,
                    cuda_available,
                    brain_mask_path,
                    dmri_mean_mask_path,
                    rsfmri_mean_mask_path,
                    regression_factors,
                    mean_regression_flag,
                    mean_regression_all_flag,
                    mean_subtraction_flag,
                    scale_volumes_flag,
                    normalize_flag,
                    minus_one_scaling_flag,
                    negative_flag,
                    outlier_flag,
                    shrinkage_flag,
                    hard_shrinkage_flag,
                    crop_flag,
                    cross_domain_x2x_flag,
                    cross_domain_y2y_flag)

                output_nifti_image = _as_nifti_image(predicted_volume, header, xform)

                output_nifti_path = os.path.join(prediction_output_path, subject)

                if '.nii' not in output_nifti_path:
                    output_nifti_path += '.nii.gz'

                output_nifti_image.save(output_nifti_path)

                if mean_regression_flag:
                    # Built through the same helper as the plain prediction;
                    # the y2y/no-crop case used to assign to the wrong
                    # variable, so the complete image was never saved.
                    output_complete_nifti_image = _as_nifti_image(
                        predicted_complete_volume, header, xform)

                    output_complete_nifti_path = os.path.join(
                        prediction_output_path, subject) + '_complete'

                    if '.nii' not in output_complete_nifti_path:
                        output_complete_nifti_path += '.nii.gz'

                    output_complete_nifti_image.save(
                        output_complete_nifti_path)

                log.info("Processed: %s %s out of %s",
                         subject, volume_index + 1, len(volumes_to_be_used))

                print("Mapped Volumes saved in: ", prediction_output_path)

            except FileNotFoundError as exception_expression:
                log.error("Error in reading the provided file!")
                log.exception(exception_expression)
                if exit_on_error:
                    raise

            except Exception as exception_expression:
                log.error("Error code execution!")
                log.exception(exception_expression)
                if exit_on_error:
                    raise

    log.info("rsfMRI Generation Complete")


Andrei Roibu's avatar
Andrei Roibu committed
422
423
424
425
426
427
428
429
430
def _generate_volume_map(file_path,
                         subject,
                         model,
                         device,
                         cuda_available,
                         brain_mask_path,
                         dmri_mean_mask_path,
                         rsfmri_mean_mask_path,
                         regression_factors,
                         mean_regression_flag,
                         mean_regression_all_flag,
                         mean_subtraction_flag,
                         scale_volumes_flag,
                         normalize_flag,
                         minus_one_scaling_flag,
                         negative_flag,
                         outlier_flag,
                         shrinkage_flag,
                         hard_shrinkage_flag,
                         crop_flag,
                         cross_domain_x2x_flag,
                         cross_domain_y2y_flag
                         ):
    """Output Volume Generator

    This function uses the trained model to generate a new volume. The input
    is loaded, optionally regressed and scaled, passed through the model,
    rescaled, and finally masked with the brain mask. When de-meaning is
    enabled, a second "complete" volume with the mean component added back is
    produced as well.

    Args:
        file_path (str): Path to the desired file
        subject (str): Subject ID, used to look up the subject's regression weight
        model (class): BrainMapper model class
        device (str/int): Device type used for evaluation (int - GPU id, str - CPU)
        cuda_available (bool): Flag indicating if a cuda-enabled GPU is present
        brain_mask_path (str): Path to the MNI brain mask file
        dmri_mean_mask_path (str): Path to the dMRI mean mask; added back to the output when cross_domain_x2x_flag is set
        rsfmri_mean_mask_path (str): Path to the rsfMRI mean mask; its first volume is added back to the output otherwise
        regression_factors (str): Path to the pickled linear regression weights (indexed by subject, columns 'w_dMRI' / 'w_rsfMRI')
        mean_regression_flag (bool): Flag indicating if the volumes should be de-meaned by regression using the mean_mask_path
        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
        mean_subtraction_flag (bool): Flag indicating if the targets should be de-meaned by subtraction using the mean_mask_path
        scale_volumes_flag (bool): Flag indicating if the volumes should be scaled.
        normalize_flag (bool): Flag signaling if the volume should be normalized ([0,1] if True) or scaled to [-1,1] if False.
        minus_one_scaling_flag (bool): Flag signaling if the volume should be scaled to [-1,1] if True
        negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
        outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
        shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
        hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets

    Returns:
        predicted_complete_volume (np.array or None): Brain-masked prediction with the mean component added back; None when neither de-meaning flag is set
        predicted_volume (np.array): Brain-masked rescaled model output
        header (class): 'nibabel.nifti1.Nifti1Header' class object, containing volume metadata
        xform: Transform returned by data_utils.load_and_preprocess_evaluation for the input volume

    Raises:
        ValueError: If neither mean_regression_flag nor mean_subtraction_flag is set, as no scaling parameters are defined for that configuration
    """

    def _load_mask_data(mask_path):
        # Load a mask image, cropping it to the model's grid when crop_flag is set.
        mask_image = Image(mask_path)
        if crop_flag:
            mask_image = roi(mask_image, ((9, 81), (10, 100), (0, 77)))
        return mask_image.data

    volume, header, xform = data_utils.load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag)

    # Hard-coded [min, max, min, max] scaling bounds. _scale_input reads the
    # last pair when cross_domain_y2y_flag is set and the first pair otherwise.
    if mean_regression_flag:
        if mean_regression_all_flag:
            volume = _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_y2y_flag)
            scaling_parameters = [-0.0626, 0.1146, -14.18, 16.9475]
        else:
            scaling_parameters = [0.0, 0.2, -14.18, 16.9475]
    elif mean_subtraction_flag:
        scaling_parameters = [0.0, 0.2, 0.0, 10.0]
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on scaling_parameters; fail with a clear message.
        raise ValueError(
            "No scaling parameters are defined when both mean_regression_flag "
            "and mean_subtraction_flag are False.")

    print('volume range:', np.min(volume), np.max(volume))

    if scale_volumes_flag:
        volume = _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_y2y_flag)

    # The model is fed a 5D array; add two leading axes when they are missing.
    if len(volume.shape) != 5:
        volume = volume[np.newaxis, np.newaxis, :, :, :]

    volume = torch.tensor(volume).type(torch.FloatTensor)

    if cuda_available and isinstance(device, int):
        volume = volume.cuda(device)

    output = model(volume)
    output = (output.cpu().numpy()).astype('float32')
    output = np.squeeze(output)

    print('output range:', np.min(output), np.max(output))

    output = _rescale_output(output, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_x2x_flag)

    print('output rescaled:', np.min(output), np.max(output))

    MNI152_T1_2mm_brain_mask = _load_mask_data(brain_mask_path)

    # One of the de-meaning flags is always set here (see the ValueError above).
    if cross_domain_x2x_flag:
        mean_mask = _load_mask_data(dmri_mean_mask_path)
    else:
        # The rsfMRI mean mask is 4D; only its first volume is used.
        mean_mask = _load_mask_data(rsfmri_mean_mask_path)[:, :, :, 0]

    if mean_regression_flag:
        weight_column = 'w_dMRI' if cross_domain_x2x_flag else 'w_rsfMRI'
        weight = pd.read_pickle(regression_factors).loc[subject][weight_column]
        predicted_complete_volume = np.add(output, np.multiply(weight, mean_mask))

    # Deliberately a separate `if`: when both flags are set, the subtraction
    # variant overrides the regression one, as in the original implementation.
    if mean_subtraction_flag:
        predicted_complete_volume = np.add(output, mean_mask)

    print('predicted_complete_volume', np.min(
        predicted_complete_volume), np.max(predicted_complete_volume))

    predicted_complete_volume = np.multiply(
        predicted_complete_volume, MNI152_T1_2mm_brain_mask)

    print('predicted_complete_volume masked:', np.min(
        predicted_complete_volume), np.max(predicted_complete_volume))

    predicted_volume = np.multiply(output, MNI152_T1_2mm_brain_mask)

    print('predicted_volume masked:', np.min(
        predicted_volume), np.max(predicted_volume))

    return predicted_complete_volume, predicted_volume, header, xform
560
561


562
def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_y2y_flag):
    """Input Scaling

    This function scales an input volume using the supplied scaling parameters.
    Before the range scaling, it can optionally apply shrinkage, set negative values to 0 and clip outliers.

    NOTE: The negative-value zeroing and the outlier clipping assign into the passed array, so the caller's volume is modified in-place.

    Args:
        volume (np.array): Numpy array representing the un-scaled volume.
        scaling_parameters (list): Four scaling parameters. The first two are used as (min, max) by default; the last two are used as (min, max) when cross_domain_y2y_flag is True.
        normalize_flag (bool): Flag signaling if the volume should be normalized to [0,1]. Takes precedence over minus_one_scaling_flag.
        minus_one_scaling_flag (bool): Flag signaling if the volume should be scaled to [-1,1]. Only considered if normalize_flag is not True.
        negative_flag (bool): Flag indicating if all the negative values should be 0-ed. Also forces the minimum used for scaling to 0.
        outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
        shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
        hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets

    Returns:
        scaled_volume (np.array): Scaled volume. If neither normalize_flag nor minus_one_scaling_flag is True, the volume is returned without range scaling (the other flags are still applied).
    """

    if cross_domain_y2y_flag == True:
        _, _, min_value, max_value = scaling_parameters
    else:
        min_value, max_value, _, _ = scaling_parameters

    if shrinkage_flag == True:
        if cross_domain_y2y_flag == True:
            lambd = 3.0
        else:
            lambd = 0.003  # Hard coded, equivalent to the 1p and 99p values across the whole population in UKBB

        if hard_shrinkage_flag == True:
            volume = _hard_shrinkage(volume, lambd)
        elif hard_shrinkage_flag == False:
            volume = _soft_shrinkage(volume, lambd)
            # Soft shrinkage moves every surviving value towards 0 by lambd, so the range bounds move with it
            min_value += lambd
            max_value -= lambd

    if negative_flag == True:
        volume[volume < 0.0] = 0.0
        min_value = 0.0

    if outlier_flag == True:
        volume[volume > max_value] = max_value
        volume[volume < min_value] = min_value

    if normalize_flag == True:
        # Normalization to [0, 1]
        scaled_volume = np.divide(np.subtract(volume, min_value), np.subtract(max_value, min_value))
    elif minus_one_scaling_flag == True:
        # Scaling between [-1, 1]
        scaled_volume = np.add(-1.0, np.multiply(2.0, np.divide(np.subtract(volume, min_value), np.subtract(max_value, min_value))))
    else:
        # No range scaling requested: pass the volume through instead of failing on an unbound name
        scaled_volume = volume

    return scaled_volume


619
def _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_y2y_flag):
    """ Input Regression

    This function regresses the group mean from the input volume using the saved regression weights.
    The subject's weight is multiplied with the group mean volume and the product is subtracted from the input volume.

    TODO: This function represents only a temporary solution. For deployment, a NN needs to be trained which predicts the relevant scaling factors.

    Args:
        volume (np.array): Unregressed volume
        subject (str): Subject ID of the subject volume to be regressed; used as the row label in the regression weights table
        dmri_mean_mask_path (str): Path to the dMRI group mean volume (used when cross_domain_y2y_flag is not True)
        rsfmri_mean_mask_path (str): Path to the rsfMRI group mean volume (used when cross_domain_y2y_flag is True; only index 0 of its 4th axis is kept)
        regression_factors (str): Path to the pickled linear regression weights file, with columns 'w_dMRI' and 'w_rsfMRI'
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets

    Returns:
        regressed_volume (np.array): Linear regressed volume

    """

    targets_as_inputs = cross_domain_y2y_flag == True

    # The weight column and the group mean file always come from the same modality
    if targets_as_inputs:
        weight_column, mean_path = 'w_rsfMRI', rsfmri_mean_mask_path
    else:
        weight_column, mean_path = 'w_dMRI', dmri_mean_mask_path

    weight = pd.read_pickle(regression_factors).loc[subject][weight_column]

    if crop_flag == False:
        mean_image = Image(mean_path)
    elif crop_flag == True:
        mean_image = roi(Image(mean_path), ((9, 81), (10, 100), (0, 77)))

    group_mean = mean_image.data
    if targets_as_inputs:
        group_mean = group_mean[:, :, :, 0]

    return np.subtract(volume, np.multiply(weight, group_mean))

Andrei Roibu's avatar
Andrei Roibu committed
657

658
def _rescale_output(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_x2x_flag):
    """Output Rescaling

    This function maps a scaled volume back to its original value range, inverting the [0,1] normalization or the [-1,1] scaling using the supplied scaling parameters.
    No shrinkage is applied to the volume here; the shrinkage flags only adjust the (min, max) range used for the inversion.

    Args:
        volume (np.array): Scaled volume to be mapped back
        scaling_parameters (list): Four scaling parameters. The last two are used as (min, max) by default; the first two are used as (min, max) when cross_domain_x2x_flag is True.
        normalize_flag (bool): Flag signaling if the volume was normalized to [0,1]. Takes precedence over minus_one_scaling_flag.
        minus_one_scaling_flag (bool): Flag signaling if the volume was scaled to [-1,1]. Only considered if normalize_flag is not True.
        negative_flag (bool): Flag indicating if all the negative values were 0-ed, in which case the minimum is forced to 0.
        shrinkage_flag (bool): Flag indicating if shrinkage was applied.
        hard_shrinkage_flag (bool): Flag indicating if hard shrinkage was applied. If False, soft shrinkage is assumed and the range is narrowed by lambd at both ends.
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs

    Returns:
        rescaled_volume (np.array): Rescaled volume. If neither normalize_flag nor minus_one_scaling_flag is True, the volume is returned unchanged.
    """

    if cross_domain_x2x_flag == True:
        min_value, max_value, _, _ = scaling_parameters
    else:
        _, _, min_value, max_value = scaling_parameters

    if shrinkage_flag == True:
        if cross_domain_x2x_flag == True:
            lambd = 0.003
        else:
            lambd = 3.0

        # Hard shrinkage leaves the range untouched; only soft shrinkage narrows it
        if hard_shrinkage_flag == False:
            min_value += lambd
            max_value -= lambd

    if negative_flag == True:
        min_value = 0.0

    if normalize_flag == True:
        # Inverse of the normalization to [0, 1]
        rescaled_volume = np.add(np.multiply(volume, np.subtract(max_value, min_value)), min_value)
    elif minus_one_scaling_flag == True:
        # Inverse of the scaling between [-1, 1]
        rescaled_volume = np.add(np.multiply(np.divide(np.add(volume, 1), 2), np.subtract(max_value, min_value)), min_value)
    else:
        # No rescaling requested: pass the volume through instead of failing on an unbound name
        rescaled_volume = volume

    return rescaled_volume


708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
def _hard_shrinkage(volume, lambd):
    """ Hard Shrinkage

    This function performs a hard shrinkage on the volumes.
    volume = { x , x > lambd | x < -lambd
                0 , x e [-lambd, lambd]
                }

    The zeroed band is closed: values exactly equal to -lambd or lambd are also set to 0.
    The array is modified in-place and the same array is returned.

    Args:
        volume (np.array): Unshrunken volume
        lambd (float): Threshold parameter

    Returns:
        volume (np.array) : Hard shrunk volume
    """

    inside_band = np.logical_and(volume >= -lambd, volume <= lambd)
    volume[inside_band] = 0

    return volume

728

729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
def _soft_shrinkage(volume, lambd):
    """ Soft Shrinkage

    This function performs a soft shrinkage on the volumes.
    volume = { x + lambd , x < -lambd
                0         , x e [-lambd, lambd]
                x - lambd , x > lambd
                }

    The array is modified in-place and the same array is returned.

    Args:
        volume (np.array): Unshrunken volume
        lambd (float): Threshold parameter

    Returns:
        volume (np.array) : Soft shrunk volume
    """

    # Values inside the closed band collapse to 0
    inside_band = np.logical_and(volume >= -lambd, volume <= lambd)
    volume[inside_band] = 0.0

    # Values below the band are pulled up towards 0 by lambd
    below_band = volume < -lambd
    volume[below_band] = np.add(volume[below_band], lambd)

    # Values above the band are pulled down towards 0 by lambd
    above_band = volume > lambd
    volume[above_band] = np.subtract(volume[above_band], lambd)

    return volume


753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
def _generate_target_volume(file_path,
                            subject,
                            dmri_mean_mask_path,
                            rsfmri_mean_mask_path,
                            regression_factors,
                            mean_regression_flag,
                            mean_regression_all_flag, 
                            mean_subtraction_flag,
                            crop_flag, 
                            cross_domain_x2x_flag
                            ):
    """Target Volume Generator

    This function loads a target volume through data_utils.load_and_preprocess_targets and then de-means it, so that it can be compared with the network predicted volumes.
    Regression takes precedence: mean subtraction is only considered when mean_regression_flag is not True. If neither flag is True, the loaded volume is returned as is.

    Args:
        file_path (str): Path to the desired file
        subject (str): Subject ID of the subject volume to be regressed
        dmri_mean_mask_path (str): Path to the dMRI group mean volume
        rsfmri_mean_mask_path (str): Path to the rsfMRI group mean volume
        regression_factors (str): Path to the linear regression weights file
        mean_regression_flag (bool): Flag indicating if the volume should be de-meaned by regression
        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
        mean_subtraction_flag (bool): Flag indicating if the volume should be de-meaned by subtraction
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs

    Returns:
        volume (np.array): Array containing the information regarding the target volume
    """

    target = data_utils.load_and_preprocess_targets(file_path, crop_flag, cross_domain_x2x_flag)

    if mean_regression_flag == True:
        return _regress_target(volume=target,
                               subject=subject,
                               dmri_mean_mask_path=dmri_mean_mask_path,
                               rsfmri_mean_mask_path=rsfmri_mean_mask_path,
                               regression_factors=regression_factors,
                               crop_flag=crop_flag,
                               cross_domain_x2x_flag=cross_domain_x2x_flag,
                               mean_regression_all_flag=mean_regression_all_flag)

    if mean_subtraction_flag == True:
        return _subtract_target(volume=target,
                                subject=subject,
                                dmri_mean_mask_path=dmri_mean_mask_path,
                                rsfmri_mean_mask_path=rsfmri_mean_mask_path,
                                crop_flag=crop_flag,
                                cross_domain_x2x_flag=cross_domain_x2x_flag)

    return target


def _regress_target(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_x2x_flag, mean_regression_all_flag):
    """ Target Regression

    This function regresses the group mean from the target volume using the saved regression weights.
    The subject's weight is multiplied with the group mean volume and the product is subtracted from the target volume.
    When cross_domain_x2x_flag is True and mean_regression_all_flag is not True, the volume is returned untouched and no file is read.

    TODO: This function represents only a temporary solution. For deployment, a NN needs to be trained which predicts the relevant scaling factors.

    Args:
        volume (np.array): Unregressed volume
        subject (str): Subject ID of the subject volume to be regressed; used as the row label in the regression weights table
        dmri_mean_mask_path (str): Path to the dMRI group mean volume (used when cross_domain_x2x_flag is True)
        rsfmri_mean_mask_path (str): Path to the rsfMRI group mean volume (used when cross_domain_x2x_flag is not True; only index 0 of its 4th axis is kept)
        regression_factors (str): Path to the pickled linear regression weights file, with columns 'w_dMRI' and 'w_rsfMRI'
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.

    Returns:
        regressed_volume (np.array): Linear regressed volume

    """

    inputs_as_targets = cross_domain_x2x_flag == True

    # Input-domain targets are only regressed when regression of all volumes was requested
    if inputs_as_targets and not (mean_regression_all_flag == True):
        return volume

    # The weight column and the group mean file always come from the same modality
    if inputs_as_targets:
        weight_column, mean_path = 'w_dMRI', dmri_mean_mask_path
    else:
        weight_column, mean_path = 'w_rsfMRI', rsfmri_mean_mask_path

    weight = pd.read_pickle(regression_factors).loc[subject][weight_column]

    if crop_flag == False:
        mean_image = Image(mean_path)
    elif crop_flag == True:
        mean_image = roi(Image(mean_path), ((9, 81), (10, 100), (0, 77)))

    group_mean = mean_image.data
    if not inputs_as_targets:
        group_mean = group_mean[:, :, :, 0]

    return np.subtract(volume, np.multiply(weight, group_mean))


def _subtract_target(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, crop_flag, cross_domain_x2x_flag):
    """ Target Subtraction

    This function subtracts the rsfMRI group mean from the target volume.
    When cross_domain_x2x_flag is True, the volume is returned untouched and no file is read.
    The subject and dmri_mean_mask_path arguments are accepted but not used by this function.

    TODO: This function represents only a temporary solution. For deployment, a NN needs to be trained which predicts the relevant scaling factors.

    Args:
        volume (np.array): Volume from which the group mean has not yet been subtracted
        subject (str): Subject ID of the subject volume (unused)
        dmri_mean_mask_path (str): Path to the dMRI group mean volume (unused)
        rsfmri_mean_mask_path (str): Path to the rsfMRI group mean volume; only index 0 of its 4th axis is kept
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs

    Returns:
        subtracted_volume (np.array): Volume with the group mean subtracted

    """

    if cross_domain_x2x_flag == True:
        return volume

    if crop_flag == False:
        mean_image = Image(rsfmri_mean_mask_path)
    elif crop_flag == True:
        mean_image = roi(Image(rsfmri_mean_mask_path), ((9, 81), (10, 100), (0, 77)))

    group_mean = mean_image.data[:, :, :, 0]

    return np.subtract(volume, group_mean)


def _statistics_calculator(volume, target):
    """ Training statistics calculator

    This function calculates the MSE, MAE, cosine distance, Pearson R and P, Spearman R and P and the linear regression W and B for a predicted volume and its target ground truth.
    Both volumes are flattened before any statistic is computed. The regression fits target = W * volume + B by ordinary least squares.

    Args:
        volume (np.array): Predicted volume
        target (np.array): Ground truth volume

    Returns:
        mse (np.float64): The mean squared error between the prediction and the ground truth; The closer to 0, the better
        mae (np.float64): The mean absolute error between the prediction and the ground truth; The closer to 0, the better
        cel (np.float64): The cosine distance between the prediction and the ground truth; The closer to 0, the better
        pearson_r (np.float64): Pearson's correlation coefficient; The closer to 1, the better
        pearson_p (np.float64): Two-tailed p-value for Pearson's correlation coefficient; the closer to 0, the better
        spearman_r (np.float64): Spearman correlation coefficient; The closer to 1, the better
        spearman_p (np.float64): Two-tailed p-value for Spearman's correlation coefficient; the closer to 0, the better
        regression_w (np.float64): Slope of the linear regression line; The closer to 1 the better
        regression_b (np.float64): Intercept of the linear regression line; The closer to 0, the better
    """

    prediction = np.reshape(volume, -1)
    truth = np.reshape(target, -1)

    # Error metrics share the same voxel-wise residual
    residual = np.subtract(prediction, truth)
    mse = np.mean(np.square(residual))
    mae = np.mean(np.abs(residual))
    cel = np.mean(cosine(prediction, truth))

    pearson_r, pearson_p = pearsonr(prediction, truth)
    spearman_r, spearman_p = spearmanr(prediction, truth)

    # Ordinary least squares through the normal equations; the column of ones carries the intercept
    design = np.vstack((np.ones(len(prediction)), prediction)).T
    gram_inverse = np.linalg.inv(design.T.dot(design))
    regression_b, regression_w = gram_inverse.dot(design.T).dot(truth)

    return mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w, regression_b


908
909
910
911
912
913
914
915
916
917
918
919
920
def _pearson_correlation(volume, target):
    """Calculate Pearson Correlation Coefficient

    This function calculates the Pearson correlation coefficient between a predicted volume and the target volume.
    It is computed as the sum of the products of the deviations from the mean, divided by the square root of the product of the two sums of squared deviations.

    Args:
        volume (np.array): The predicted volume
        target (np.array): The target volume

    Returns:
        r (scalar): The Pearson Correlation Coefficient
    """

    volume_deviation = np.subtract(volume, volume.mean())
    target_deviation = np.subtract(target, target.mean())

    cross_sum = np.sum(np.multiply(volume_deviation, target_deviation))
    volume_sum_squares = np.sum(np.power(volume_deviation, 2))
    target_sum_squares = np.sum(np.power(target_deviation, 2))

    r = cross_sum / np.sqrt(np.multiply(volume_sum_squares, target_sum_squares))

    return r
925