Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
8cbb592b
Commit
8cbb592b
authored
Aug 21, 2020
by
Andrei Roibu
Browse files
added cross-domain x2x y2y x2y options to mapping eval
parent
1c12a511
Changes
3
Hide whitespace changes
Inline
Side-by-side
settings_evaluation.ini
View file @
8cbb592b
...
...
@@ -3,6 +3,7 @@ trained_model_path = "saved_models/VA2-1.pth.tar"
prediction_output_path
=
"VA2-1_predictions"
data_directory
=
"/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
mapping_data_file
=
"dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
mapping_targets_file
=
"fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
data_list
=
"datasets/test.txt"
brain_mask_path
=
"utils/MNI152_T1_2mm_brain_mask.nii.gz"
rsfmri_mean_mask_path
=
"utils/mean_dr_stage2.nii.gz"
...
...
@@ -21,4 +22,6 @@ shrinkage_flag = False
hard_shrinkage_flag
=
False
crop_flag
=
True
device
=
0
exit_on_error
=
True
\ No newline at end of file
exit_on_error
=
True
cross_domain_x2x_flag
=
False
cross_domain_y2y_flag
=
False
\ No newline at end of file
utils/data_evaluation_utils.py
View file @
8cbb592b
...
...
@@ -31,6 +31,7 @@ log = logging.getLogger(__name__)
def
evaluate_mapping
(
trained_model_path
,
data_directory
,
mapping_data_file
,
mapping_targets_file
,
data_list
,
prediction_output_path
,
brain_mask_path
,
...
...
@@ -50,6 +51,8 @@ def evaluate_mapping(trained_model_path,
crop_flag
,
device
=
0
,
exit_on_error
=
False
,
cross_domain_x2x_flag
=
False
,
cross_domain_y2y_flag
=
False
,
mode
=
'evaluate'
):
"""Model Evaluator
...
...
@@ -59,6 +62,7 @@ def evaluate_mapping(trained_model_path,
trained_model_path (str): Path to the location of the trained model
data_directory (str): Path to input data directory
mapping_data_file (str): Path to the input file
mapping_targets_file (str): Path to the target file
data_list (str): Path to a .txt file containing the input files for consideration
prediction_output_path (str): Output prediction path
brain_mask_path (str): Path to the MNI brain mask file
...
...
@@ -79,6 +83,8 @@ def evaluate_mapping(trained_model_path,
device (str/int): Device type used for training (int - GPU id, str- CPU)
mode (str): Current run mode or phase
exit_on_error (bool): Flag that triggers the raising of an exception
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Raises:
FileNotFoundError: Error in reading the provided file!
...
...
@@ -117,8 +123,11 @@ def evaluate_mapping(trained_model_path,
# Initiate the evaluation
log
.
info
(
"rsfMRI Generation Started"
)
file_paths
,
volumes_to_be_used
=
data_utils
.
load_file_paths
(
data_directory
,
data_list
,
mapping_data_file
)
if
cross_domain_y2y_flag
==
True
:
# If doing y2y autoencoder, then we load the targets as inputs. In all other cases (x2x & x2y) we load the inputs as inputs.
file_paths
,
volumes_to_be_used
=
data_utils
.
load_file_paths
(
data_directory
,
data_list
,
mapping_data_file
=
mapping_targets_file
)
else
:
file_paths
,
volumes_to_be_used
=
data_utils
.
load_file_paths
(
data_directory
,
data_list
,
mapping_data_file
)
with
torch
.
no_grad
():
...
...
@@ -148,10 +157,15 @@ def evaluate_mapping(trained_model_path,
outlier_flag
,
shrinkage_flag
,
hard_shrinkage_flag
,
crop_flag
)
crop_flag
,
cross_domain_x2x_flag
,
cross_domain_y2y_flag
)
if
crop_flag
==
False
:
output_nifti_image
=
Image
(
predicted_volume
,
header
=
header
,
xform
=
xform
)
if
cross_domain_y2y_flag
==
True
:
output_nifti_image
=
Image
(
predicted_volume
,
header
=
header
)
else
:
output_nifti_image
=
Image
(
predicted_volume
,
header
=
header
,
xform
=
xform
)
elif
crop_flag
==
True
:
output_nifti_image
=
Image
(
predicted_volume
,
header
=
header
)
output_nifti_image
=
roi
(
output_nifti_image
,
((
-
9
,
82
),(
-
10
,
99
),(
0
,
91
)))
...
...
@@ -166,7 +180,10 @@ def evaluate_mapping(trained_model_path,
if
mean_regression_flag
==
True
:
if
crop_flag
==
False
:
output_complete_nifti_image
=
Image
(
predicted_complete_volume
,
header
=
header
,
xform
=
xform
)
if
cross_domain_y2y_flag
==
True
:
output_nifti_image
=
Image
(
predicted_complete_volume
,
header
=
header
)
else
:
output_complete_nifti_image
=
Image
(
predicted_complete_volume
,
header
=
header
,
xform
=
xform
)
elif
crop_flag
==
True
:
output_complete_nifti_image
=
Image
(
predicted_complete_volume
,
header
=
header
)
output_complete_nifti_image
=
roi
(
output_complete_nifti_image
,
((
-
9
,
82
),(
-
10
,
99
),(
0
,
91
)))
...
...
@@ -220,6 +237,8 @@ def _generate_volume_map(file_path,
shrinkage_flag
,
hard_shrinkage_flag
,
crop_flag
,
cross_domain_x2x_flag
,
cross_domain_y2y_flag
):
"""rsfMRI Volume Generator
...
...
@@ -246,17 +265,19 @@ def _generate_volume_map(file_path,
shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns
predicted_volume (np.array): Array containing the information regarding the generated volume
header (class): 'nibabel.nifti1.Nifti1Header' class object, containing volume metadata
"""
volume
,
header
,
xform
=
data_utils
.
load_and_preprocess_evaluation
(
file_path
,
crop_flag
)
volume
,
header
,
xform
=
data_utils
.
load_and_preprocess_evaluation
(
file_path
,
crop_flag
,
cross_domain_y2y_flag
)
if
mean_regression_flag
==
True
:
if
mean_regression_all_flag
==
True
:
volume
=
_regress_input
(
volume
,
subject
,
dmri_mean_mask_path
,
r
egression_factors
,
crop
_flag
)
volume
=
_regress_input
(
volume
,
subject
,
dmri_mean_mask_path
,
r
sfmri_mean_mask_path
,
regression_factors
,
crop_flag
,
cross_domain_y2y
_flag
)
scaling_parameters
=
[
-
0.0626
,
0.1146
,
-
14.18
,
16.9475
]
else
:
scaling_parameters
=
[
0.0
,
0.2
,
-
14.18
,
16.9475
]
...
...
@@ -266,7 +287,7 @@ def _generate_volume_map(file_path,
print
(
'volume range:'
,
np
.
min
(
volume
),
np
.
max
(
volume
))
if
scale_volumes_flag
==
True
:
volume
=
_scale_input
(
volume
,
scaling_parameters
,
normalize_flag
,
minus_one_scaling_flag
,
negative_flag
,
outlier_flag
,
shrinkage_flag
,
hard_shrinkage_flag
)
volume
=
_scale_input
(
volume
,
scaling_parameters
,
normalize_flag
,
minus_one_scaling_flag
,
negative_flag
,
outlier_flag
,
shrinkage_flag
,
hard_shrinkage_flag
,
cross_domain_y2y_flag
)
if
len
(
volume
.
shape
)
==
5
:
volume
=
volume
...
...
@@ -284,7 +305,7 @@ def _generate_volume_map(file_path,
print
(
'output range:'
,
np
.
min
(
output
),
np
.
max
(
output
))
output
=
_rescale_output
(
output
,
scaling_parameters
,
normalize_flag
,
minus_one_scaling_flag
,
negative_flag
,
shrinkage_flag
,
hard_shrinkage_flag
)
output
=
_rescale_output
(
output
,
scaling_parameters
,
normalize_flag
,
minus_one_scaling_flag
,
negative_flag
,
shrinkage_flag
,
hard_shrinkage_flag
,
cross_domain_x2x_flag
)
print
(
'output rescaled:'
,
np
.
min
(
output
),
np
.
max
(
output
))
...
...
@@ -295,13 +316,22 @@ def _generate_volume_map(file_path,
if
mean_regression_flag
==
True
or
mean_subtraction_flag
==
True
:
if
crop_flag
==
False
:
mean_mask
=
Image
(
rsfmri_mean_mask_path
).
data
[:,
:,
:,
0
]
elif
crop_flag
==
True
:
mean_mask
=
roi
(
Image
(
rsfmri_mean_mask_path
),((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
[:,
:,
:,
0
]
if
cross_domain_x2x_flag
==
True
:
if
crop_flag
==
False
:
mean_mask
=
Image
(
dmri_mean_mask_path
).
data
elif
crop_flag
==
True
:
mean_mask
=
roi
(
Image
(
dmri_mean_mask_path
),((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
else
:
if
crop_flag
==
False
:
mean_mask
=
Image
(
rsfmri_mean_mask_path
).
data
[:,
:,
:,
0
]
elif
crop_flag
==
True
:
mean_mask
=
roi
(
Image
(
rsfmri_mean_mask_path
),((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
[:,
:,
:,
0
]
if
mean_regression_flag
==
True
:
weight
=
pd
.
read_pickle
(
regression_factors
).
loc
[
subject
][
'w_rsfMRI'
]
if
cross_domain_x2x_flag
==
True
:
weight
=
pd
.
read_pickle
(
regression_factors
).
loc
[
subject
][
'w_dMRI'
]
else
:
weight
=
pd
.
read_pickle
(
regression_factors
).
loc
[
subject
][
'w_rsfMRI'
]
predicted_complete_volume
=
np
.
add
(
output
,
np
.
multiply
(
weight
,
mean_mask
))
if
mean_subtraction_flag
==
True
:
...
...
@@ -327,7 +357,7 @@ def _generate_volume_map(file_path,
return
predicted_complete_volume
,
predicted_volume
,
header
,
xform
def
_scale_input
(
volume
,
scaling_parameters
,
normalize_flag
,
minus_one_scaling_flag
,
negative_flag
,
outlier_flag
,
shrinkage_flag
,
hard_shrinkage_flag
):
def
_scale_input
(
volume
,
scaling_parameters
,
normalize_flag
,
minus_one_scaling_flag
,
negative_flag
,
outlier_flag
,
shrinkage_flag
,
hard_shrinkage_flag
,
cross_domain_y2y_flag
):
"""Input Scaling
This function reads the scaling factors from the saved file and then scales the data.
...
...
@@ -341,15 +371,22 @@ def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_f
outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns:
scaled_volume (np.array): Scaled volume
"""
min_value
,
max_value
,
_
,
_
=
scaling_parameters
if
cross_domain_y2y_flag
==
True
:
_
,
_
,
min_value
,
max_value
=
scaling_parameters
else
:
min_value
,
max_value
,
_
,
_
=
scaling_parameters
if
shrinkage_flag
==
True
:
lambd
=
0.003
# Hard coded, equivalent to tht 1p and 99p values across the whole population in UKBB
if
cross_domain_y2y_flag
==
True
:
lambd
=
3.0
else
:
lambd
=
0.003
# Hard coded, equivalent to tht 1p and 99p values across the whole population in UKBB
if
hard_shrinkage_flag
==
True
:
volume
=
_hard_shrinkage
(
volume
,
lambd
)
...
...
@@ -377,7 +414,7 @@ def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_f
return
scaled_volume
def
_regress_input
(
volume
,
subject
,
dmri_mean_mask_path
,
r
egression_factors
,
crop
_flag
):
def
_regress_input
(
volume
,
subject
,
dmri_mean_mask_path
,
r
sfmri_mean_mask_path
,
regression_factors
,
crop_flag
,
cross_domain_y2y
_flag
):
""" Inputn Regression
This function regresse the group mean from the input volume using the saved regression weights.
...
...
@@ -388,26 +425,35 @@ def _regress_input(volume, subject, dmri_mean_mask_path, regression_factors, cro
volume (np.array): Unregressed volume
subject (str): Subject ID of the subject volume to be regressed
dmri_mean_mask_path (str): Path to the group mean volume
rsfmri_mean_mask_path (str): Path to the target group mean volume
regression_factors (str): Path to the linear regression weights file
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns:
regressed_volume (np.array): Linear regressed volume
"""
weight
=
pd
.
read_pickle
(
regression_factors
).
loc
[
subject
][
'w_dMRI'
]
if
crop_flag
==
False
:
group_mean
=
Image
(
dmri_mean_mask_path
).
data
elif
crop_flag
==
True
:
group_mean
=
roi
(
Image
(
dmri_mean_mask_path
),((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
if
cross_domain_y2y_flag
==
True
:
weight
=
pd
.
read_pickle
(
regression_factors
).
loc
[
subject
][
'w_rsfMRI'
]
if
crop_flag
==
False
:
group_mean
=
Image
(
rsfmri_mean_mask_path
).
data
[:,
:,
:,
0
]
elif
crop_flag
==
True
:
group_mean
=
roi
(
Image
(
rsfmri_mean_mask_path
),((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
[:,
:,
:,
0
]
else
:
weight
=
pd
.
read_pickle
(
regression_factors
).
loc
[
subject
][
'w_dMRI'
]
if
crop_flag
==
False
:
group_mean
=
Image
(
dmri_mean_mask_path
).
data
elif
crop_flag
==
True
:
group_mean
=
roi
(
Image
(
dmri_mean_mask_path
),((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
regressed_volume
=
np
.
subtract
(
volume
,
np
.
multiply
(
weight
,
group_mean
))
regressed_volume
=
np
.
subtract
(
volume
,
np
.
multiply
(
weight
,
group_mean
))
return
regressed_volume
def
_rescale_output
(
volume
,
scaling_parameters
,
normalize_flag
,
minus_one_scaling_flag
,
negative_flag
,
shrinkage_flag
,
hard_shrinkage_flag
):
def
_rescale_output
(
volume
,
scaling_parameters
,
normalize_flag
,
minus_one_scaling_flag
,
negative_flag
,
shrinkage_flag
,
hard_shrinkage_flag
,
cross_domain_x2x_flag
):
"""Output Rescaling
This function reads the scaling factors from the saved file and then scales the data.
...
...
@@ -420,15 +466,23 @@ def _rescale_output(volume, scaling_parameters, normalize_flag, minus_one_scalin
negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
Returns:
rescaled_volume (np.array): Rescaled volume
"""
_
,
_
,
min_value
,
max_value
=
scaling_parameters
if
cross_domain_x2x_flag
==
True
:
min_value
,
max_value
,
_
,
_
=
scaling_parameters
else
:
_
,
_
,
min_value
,
max_value
=
scaling_parameters
if
shrinkage_flag
==
True
:
lambd
=
3.0
if
cross_domain_x2x_flag
==
True
:
lambd
=
0.003
else
:
lambd
=
3.0
if
hard_shrinkage_flag
==
True
:
pass
elif
hard_shrinkage_flag
==
False
:
...
...
utils/data_utils.py
View file @
8cbb592b
...
...
@@ -150,7 +150,7 @@ def load_subjects_from_path(data_directory, data_list):
return
volumes_to_be_used
def
load_and_preprocess_evaluation
(
file_path
,
crop_flag
):
def
load_and_preprocess_evaluation
(
file_path
,
crop_flag
,
cross_domain_y2y_flag
):
"""Load & Preprocessing before evaluation
This function loads a nifty file and returns its volume and header information
...
...
@@ -158,6 +158,7 @@ def load_and_preprocess_evaluation(file_path, crop_flag):
Args:
file_path (str): Path to the desired file
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_y2y_flag
Returns:
volume (np.array): Array of training image data of data type dtype.
...
...
@@ -170,15 +171,25 @@ def load_and_preprocess_evaluation(file_path, crop_flag):
original_image
=
Image
(
file_path
[
0
])
if
crop_flag
==
False
:
volume
,
xform
=
resampleToPixdims
(
original_image
,
(
2
,
2
,
2
))
header
=
Image
(
volume
,
header
=
original_image
.
header
,
xform
=
xform
).
header
elif
crop_flag
==
True
:
resampled
,
xform
=
resampleToPixdims
(
original_image
,
(
2
,
2
,
2
))
resampled
=
Image
(
resampled
,
header
=
original_image
.
header
,
xform
=
xform
)
cropped
=
roi
(
resampled
,((
9
,
81
),(
10
,
100
),(
0
,
77
)))
volume
=
cropped
.
data
header
=
cropped
.
header
if
cross_domain_y2y_flag
==
True
:
if
crop_flag
==
False
:
volume
=
original_image
.
data
[:,
:,
:,
0
]
header
=
Image
(
volume
,
header
=
original_image
.
header
).
header
elif
crop_flag
==
True
:
cropped
=
roi
(
original_image
,((
9
,
81
),(
10
,
100
),(
0
,
77
)))
volume
=
cropped
.
data
[:,
:,
:,
0
]
header
=
cropped
.
header
xform
=
None
else
:
if
crop_flag
==
False
:
volume
,
xform
=
resampleToPixdims
(
original_image
,
(
2
,
2
,
2
))
header
=
Image
(
volume
,
header
=
original_image
.
header
,
xform
=
xform
).
header
elif
crop_flag
==
True
:
resampled
,
xform
=
resampleToPixdims
(
original_image
,
(
2
,
2
,
2
))
resampled
=
Image
(
resampled
,
header
=
original_image
.
header
,
xform
=
xform
)
cropped
=
roi
(
resampled
,((
9
,
81
),(
10
,
100
),(
0
,
77
)))
volume
=
cropped
.
data
header
=
cropped
.
header
return
volume
,
header
,
xform
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment