Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
815f5223
Commit
815f5223
authored
Sep 08, 2020
by
Andrei Roibu
Browse files
added ability to calculate statistics for test set
parent
01d776d1
Changes
4
Hide whitespace changes
Inline
Side-by-side
run.py
View file @
815f5223
...
...
@@ -411,6 +411,7 @@ def evaluate_data(mapping_evaluation_parameters):
prediction_output_path
=
mapping_evaluation_parameters
[
'prediction_output_path'
]
prediction_output_database_name
=
mapping_evaluation_parameters
[
'prediction_output_database_name'
]
prediction_output_statistics_name
=
mapping_evaluation_parameters
[
'prediction_output_statistics_name'
]
dmri_mean_mask_path
=
mapping_evaluation_parameters
[
'dmri_mean_mask_path'
]
rsfmri_mean_mask_path
=
mapping_evaluation_parameters
[
'rsfmri_mean_mask_path'
]
device
=
mapping_evaluation_parameters
[
'device'
]
...
...
@@ -428,6 +429,7 @@ def evaluate_data(mapping_evaluation_parameters):
shrinkage_flag
=
mapping_evaluation_parameters
[
'shrinkage_flag'
]
hard_shrinkage_flag
=
mapping_evaluation_parameters
[
'hard_shrinkage_flag'
]
crop_flag
=
mapping_evaluation_parameters
[
'crop_flag'
]
output_database_flag
=
mapping_evaluation_parameters
[
'output_database_flag'
]
cross_domain_x2x_flag
=
mapping_evaluation_parameters
[
'cross_domain_x2x_flag'
]
cross_domain_y2y_flag
=
mapping_evaluation_parameters
[
'cross_domain_y2y_flag'
]
...
...
@@ -438,6 +440,7 @@ def evaluate_data(mapping_evaluation_parameters):
data_list
,
prediction_output_path
,
prediction_output_database_name
,
prediction_output_statistics_name
,
brain_mask_path
,
dmri_mean_mask_path
,
rsfmri_mean_mask_path
,
...
...
@@ -455,6 +458,7 @@ def evaluate_data(mapping_evaluation_parameters):
crop_flag
,
device
,
exit_on_error
,
output_database_flag
,
cross_domain_x2x_flag
,
cross_domain_y2y_flag
)
...
...
settings_evaluation.ini
View file @
815f5223
...
...
@@ -2,12 +2,14 @@
trained_model_path
=
"saved_models/VA2-1.pth.tar"
prediction_output_path
=
"VA2-1_predictions"
prediction_output_database_name
=
"output_test_data.h5"
prediction_output_statistics_name
=
"output_statistics.pkl"
data_directory
=
"/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
mapping_data_file
=
"dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
mapping_targets_file
=
"fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
data_list_reduced
=
"datasets/test_reduced.txt"
data_list_all
=
"datasets/test_all.txt"
evaluate_all_data
=
False
output_database_flag
=
False
brain_mask_path
=
"utils/MNI152_T1_2mm_brain_mask.nii.gz"
rsfmri_mean_mask_path
=
"utils/mean_dr_stage2.nii.gz"
dmri_mean_mask_path
=
"utils/mean_tractsNormSummed_downsampled.nii.gz"
...
...
utils/data_evaluation_utils.py
View file @
815f5223
...
...
@@ -26,6 +26,8 @@ import pandas as pd
from
fsl.data.image
import
Image
from
fsl.utils.image.roi
import
roi
import
itertools
from
scipy.spatial.distance
import
cosine
from
scipy.stats
import
pearsonr
,
spearmanr
log
=
logging
.
getLogger
(
__name__
)
...
...
@@ -36,6 +38,7 @@ def evaluate_data(trained_model_path,
data_list
,
prediction_output_path
,
prediction_output_database_name
,
prediction_output_statistics_name
,
brain_mask_path
,
dmri_mean_mask_path
,
rsfmri_mean_mask_path
,
...
...
@@ -53,6 +56,7 @@ def evaluate_data(trained_model_path,
crop_flag
,
device
=
0
,
exit_on_error
=
False
,
output_database_flag
=
False
,
cross_domain_x2x_flag
=
False
,
cross_domain_y2y_flag
=
False
,
mode
=
'evaluate'
):
...
...
@@ -69,6 +73,7 @@ def evaluate_data(trained_model_path,
data_list (str): Path to a .txt file containing the input files for consideration
prediction_output_path (str): Output prediction path
prediction_output_database_name (str): Name of the output database
prediction_output_statistics_name (str): Name of the output statistics database
brain_mask_path (str): Path to the MNI brain mask file
dmri_mean_mask_path (str): Path to the dualreg subject mean mask
rsfmri_mean_mask_path (str): Path to the summed tract mean mask
...
...
@@ -87,6 +92,7 @@ def evaluate_data(trained_model_path,
device (str/int): Device type used for training (int - GPU id, str- CPU)
mode (str): Current run mode or phase
exit_on_error (bool): Flag that triggers the raising of an exception
output_database_flag (bool): Flag indicating if the output maps should be saved to hdf5 database
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
...
...
@@ -129,12 +135,20 @@ def evaluate_data(trained_model_path,
log
.
info
(
"rsfMRI Generation Started"
)
if
cross_domain_y2y_flag
==
True
:
# If doing y2y autoencoder, then we load the targets as inputs. In all other cases (x2x & x2y) we load the inputs as inputs.
file_paths
,
volumes_to_be_used
=
data_utils
.
load_file_paths
(
data_directory
,
data_list
,
mapping_data_file
=
mapping_targets_file
)
file_paths
,
volumes_to_be_used
=
data_utils
.
load_file_paths
(
data_directory
,
data_list
,
mapping_data_file
=
mapping_targets_file
,
mapping_targets_file
=
mapping_targets_file
)
elif
cross_domain_x2x_flag
==
True
:
file_paths
,
volumes_to_be_used
=
data_utils
.
load_file_paths
(
data_directory
,
data_list
,
mapping_data_file
,
mapping_targets_file
=
mapping_data_file
)
else
:
file_paths
,
volumes_to_be_used
=
data_utils
.
load_file_paths
(
data_directory
,
data_list
,
mapping_data_file
)
file_paths
,
volumes_to_be_used
=
data_utils
.
load_file_paths
(
data_directory
,
data_list
,
mapping_data_file
,
mapping_targets_file
)
output_database_path
=
os
.
path
.
join
(
prediction_output_path
,
prediction_output_database_name
)
output_database_handle
=
h5py
.
File
(
output_database_path
,
'w'
)
if
output_database_flag
==
True
:
output_database_path
=
os
.
path
.
join
(
prediction_output_path
,
prediction_output_database_name
)
if
os
.
path
.
exists
(
output_database_path
):
os
.
remove
(
output_database_path
)
output_database_handle
=
h5py
.
File
(
output_database_path
,
'w'
)
output_statistics
=
{}
output_statistics_path
=
os
.
path
.
join
(
prediction_output_path
,
prediction_output_statistics_name
)
with
torch
.
no_grad
():
...
...
@@ -168,12 +182,27 @@ def evaluate_data(trained_model_path,
cross_domain_x2x_flag
,
cross_domain_y2y_flag
)
group
=
output_database_handle
.
create_group
(
subject
)
group
.
create_dataset
(
'predicted_complete_volume'
,
data
=
predicted_complete_volume
)
group
.
create_dataset
(
'predicted_volume'
,
data
=
predicted_volume
)
group
.
create_dataset
(
'header'
,
data
=
header
)
group
.
create_dataset
(
'xform'
,
data
=
xform
)
target_volume
=
_generate_target_volume
(
file_path
,
subject
,
dmri_mean_mask_path
,
rsfmri_mean_mask_path
,
regression_factors
,
mean_regression_flag
,
mean_regression_all_flag
,
mean_subtraction_flag
,
crop_flag
,
cross_domain_x2x_flag
)
mse
,
mae
,
cel
,
pearson_r
,
pearson_p
,
spearman_r
,
spearman_p
,
regression_w
,
regression_b
=
_statistics_calculator
(
predicted_volume
,
target_volume
)
output_statistics
[
subject
]
=
[
mse
,
mae
,
cel
,
pearson_r
,
pearson_p
,
spearman_r
,
spearman_p
,
regression_w
,
regression_b
]
if
output_database_flag
==
True
:
group
=
output_database_handle
.
create_group
(
subject
)
group
.
create_dataset
(
'predicted_complete_volume'
,
data
=
predicted_complete_volume
)
group
.
create_dataset
(
'predicted_volume'
,
data
=
predicted_volume
)
group
.
create_dataset
(
'header'
,
data
=
header
)
group
.
create_dataset
(
'xform'
,
data
=
xform
)
log
.
info
(
"Processed: "
+
volumes_to_be_used
[
volume_index
]
+
" "
+
str
(
volume_index
+
1
)
+
" out of "
+
str
(
len
(
volumes_to_be_used
)))
...
...
@@ -192,10 +221,14 @@ def evaluate_data(trained_model_path,
if
exit_on_error
:
raise
(
exception_expression
)
output_statistics_df
=
pd
.
DataFrame
.
from_dict
(
output_statistics
,
orient
=
'index'
,
columns
=
[
'mse'
,
'mae'
,
'cel'
,
'pearson_r'
,
'pearson_p'
,
'spearman_r'
,
'spearman_p'
,
'regression_w'
,
'regression_b'
])
output_statistics_df
.
to_pickle
(
output_statistics_path
)
log
.
info
(
"Output Data Generation Complete"
)
output_database_handle
.
close
()
def
evaluate_mapping
(
trained_model_path
,
data_directory
,
mapping_data_file
,
...
...
@@ -408,7 +441,7 @@ def _generate_volume_map(file_path,
cross_domain_x2x_flag
,
cross_domain_y2y_flag
):
"""
rsfMRI
Volume Generator
"""
Output
Volume Generator
This function uses the trained model to generate a new volume
...
...
@@ -616,7 +649,7 @@ def _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path,
elif
crop_flag
==
True
:
group_mean
=
roi
(
Image
(
dmri_mean_mask_path
),((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
regressed_volume
=
np
.
subtract
(
volume
,
np
.
multiply
(
weight
,
group_mean
))
regressed_volume
=
np
.
subtract
(
volume
,
np
.
multiply
(
weight
,
group_mean
))
return
regressed_volume
...
...
@@ -716,6 +749,161 @@ def _soft_shrinkage(volume, lambd):
return
volume
def
_generate_target_volume
(
file_path
,
subject
,
dmri_mean_mask_path
,
rsfmri_mean_mask_path
,
regression_factors
,
mean_regression_flag
,
mean_regression_all_flag
,
mean_subtraction_flag
,
crop_flag
,
cross_domain_x2x_flag
):
"""Target Volume Generator
This function loads and preprocesses a target volume for comparing with the network predicted volumes
Args:
file_path (str): Path to the desired file
subject (str): Subject ID of the subject volume to be regressed
dmri_mean_mask_path (str): Path to the group mean volume
rsfmri_mean_mask_path (str): Path to the dualreg subject mean mask
regression_factors (str): Path to the linear regression weights file
mean_regression_flag (bool): Flag indicating if the volumes should be de-meaned by regression using the mean_mask_path
mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
mean_subtraction_flag (bool): Flag indicating if the targets should be de-meaned by subtraction using the mean_mask_path
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
Returns:
volume (np.array): Array containing the information regarding the target volume
"""
volume
=
data_utils
.
load_and_preprocess_targets
(
file_path
,
crop_flag
,
cross_domain_x2x_flag
)
if
mean_regression_flag
==
True
:
volume
=
_regress_target
(
volume
,
subject
,
dmri_mean_mask_path
,
rsfmri_mean_mask_path
,
regression_factors
,
crop_flag
,
cross_domain_x2x_flag
,
mean_regression_all_flag
)
elif
mean_subtraction_flag
==
True
:
volume
=
_subtract_target
(
volume
,
subject
,
dmri_mean_mask_path
,
rsfmri_mean_mask_path
,
crop_flag
,
cross_domain_x2x_flag
)
return
volume
def
_regress_target
(
volume
,
subject
,
dmri_mean_mask_path
,
rsfmri_mean_mask_path
,
regression_factors
,
crop_flag
,
cross_domain_x2x_flag
,
mean_regression_all_flag
):
""" Target Regression
This function regresse the group mean from the target volume using the saved regression weights.
TODO: This function repressents only a temporary solution. For deployment, a NN needs to be trained which predicts the relevant scaling factors.
Args:
volume (np.array): Unregressed volume
subject (str): Subject ID of the subject volume to be regressed
dmri_mean_mask_path (str): Path to the group mean volume
rsfmri_mean_mask_path (str): Path to the target group mean volume
regression_factors (str): Path to the linear regression weights file
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the targets
mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
Returns:
regressed_volume (np.array): Linear regressed volume
"""
if
cross_domain_x2x_flag
==
True
:
if
mean_regression_all_flag
==
True
:
weight
=
pd
.
read_pickle
(
regression_factors
).
loc
[
subject
][
'w_dMRI'
]
if
crop_flag
==
False
:
group_mean
=
Image
(
dmri_mean_mask_path
).
data
elif
crop_flag
==
True
:
group_mean
=
roi
(
Image
(
dmri_mean_mask_path
),((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
regressed_volume
=
np
.
subtract
(
volume
,
np
.
multiply
(
weight
,
group_mean
))
else
:
regressed_volume
=
volume
else
:
weight
=
pd
.
read_pickle
(
regression_factors
).
loc
[
subject
][
'w_rsfMRI'
]
if
crop_flag
==
False
:
group_mean
=
Image
(
rsfmri_mean_mask_path
).
data
[:,
:,
:,
0
]
elif
crop_flag
==
True
:
group_mean
=
roi
(
Image
(
rsfmri_mean_mask_path
),((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
[:,
:,
:,
0
]
regressed_volume
=
np
.
subtract
(
volume
,
np
.
multiply
(
weight
,
group_mean
))
return
regressed_volume
def
_subtract_target
(
volume
,
subject
,
dmri_mean_mask_path
,
rsfmri_mean_mask_path
,
crop_flag
,
cross_domain_x2x_flag
):
""" Target Subtraction
This function subtracts the group mean from the target volume using the saved regression weights.
TODO: This function repressents only a temporary solution. For deployment, a NN needs to be trained which predicts the relevant scaling factors.
Args:
volume (np.array): Unregressed volume
subject (str): Subject ID of the subject volume to be regressed
dmri_mean_mask_path (str): Path to the group mean volume
rsfmri_mean_mask_path (str): Path to the target group mean volume
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns:
regressed_volume (np.array): Linear regressed volume
"""
if
cross_domain_x2x_flag
==
True
:
subtracted_volume
=
volume
else
:
if
crop_flag
==
False
:
group_mean
=
Image
(
rsfmri_mean_mask_path
).
data
[:,
:,
:,
0
]
elif
crop_flag
==
True
:
group_mean
=
roi
(
Image
(
rsfmri_mean_mask_path
),((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
[:,
:,
:,
0
]
subtracted_volume
=
np
.
subtract
(
volume
,
group_mean
)
return
subtracted_volume
def
_statistics_calculator
(
volume
,
target
):
""" Training statistics calculator
This function calculates the MSE, MAE, CEL, Pearson R and P and linear regression W and B for a predicted volume and it's target ground truth.
Args:
volume (np.array): Predicted volume
target (np.array): Ground truth volume
Returns:
mse (np.float64): The mean squared error between the prediction and the ground truth; The closer to 0, the better
mae (np.float64): The mean absolut error between the prediction and the ground truth; The closer to 0, the better
cel (np.float64): The cosine distance between the prediction and the ground truth; The closer to 0, the better
pearson_r (np.float64): Pearson’s correlation coefficient; The closer to 1, the better
pearson_p (np.float64): Two-tailed p-value for Pearson’s correlation coefficient; the closer to 0, the better
spearman_r (np.float64): Spearman correlation coefficient; The closer to 1, the better
spearman_p (np.float64): Two-tailed p-value for Spearman's correlation coefficient; the closer to 0, the better
regression_w (np.float64): Slope of the linear regression line; The closer to 1 the better
regression_b (np.float64): Intersect of the linear regression line; The closer to 0, the better
"""
x
=
np
.
reshape
(
volume
,
-
1
)
y
=
np
.
reshape
(
target
,
-
1
)
mse
=
np
.
square
(
np
.
subtract
(
x
,
y
)).
mean
()
mae
=
np
.
abs
(
np
.
subtract
(
x
,
y
)).
mean
()
cel
=
np
.
mean
(
cosine
(
x
,
y
))
pearson_r
,
pearson_p
=
pearsonr
(
x
,
y
)
spearman_r
,
spearman_p
=
spearmanr
(
x
,
y
)
x_matrix
=
np
.
vstack
((
np
.
ones
(
len
(
x
)),
x
)).
T
regression_b
,
regression_w
=
np
.
linalg
.
inv
(
x_matrix
.
T
.
dot
(
x_matrix
)).
dot
(
x_matrix
.
T
).
dot
(
y
)
return
mse
,
mae
,
cel
,
pearson_r
,
pearson_p
,
spearman_r
,
spearman_p
,
regression_w
,
regression_b
def
_pearson_correlation
(
volume
,
target
):
"""Calculate Pearson Correlation Coefficient
...
...
@@ -733,3 +921,4 @@ def _pearson_correlation(volume, target):
np
.
sum
(
np
.
power
(
np
.
subtract
(
volume
,
volume
.
mean
()),
2
)),
np
.
sum
(
np
.
power
(
np
.
subtract
(
target
,
target
.
mean
()),
2
))))
return
r
utils/data_utils.py
View file @
815f5223
...
...
@@ -96,7 +96,7 @@ def get_datasets(data_parameters, cross_domain_x2x_flag, cross_domain_y2y_flag):
)
def
load_file_paths
(
data_directory
,
data_list
,
mapping_data_file
,
targets_directory
=
None
,
target_file
=
None
):
def
load_file_paths
(
data_directory
,
data_list
,
mapping_data_file
,
mapping_
target
s
_file
=
None
):
"""File Loader
This function returns a list of combined file paths for the input and output data.
...
...
@@ -105,7 +105,7 @@ def load_file_paths(data_directory, data_list, mapping_data_file, targets_direct
data_directory (str): Path to input data directory
data_list (str): Path to a .txt file containing the input files for consideration
mapping_data_file (str): Path to the input files
targets_directory (str): Path to labelled data (Y-equivalent); None if during evaluation.
mapping_targets_file (str): Path to the target files
Returns:
file_paths (list): List containing the input data and target labelled output data
...
...
@@ -117,12 +117,12 @@ def load_file_paths(data_directory, data_list, mapping_data_file, targets_direct
volumes_to_be_used
=
load_subjects_from_path
(
data_directory
,
data_list
)
if
targets_directory
==
None
or
target_file
==
None
:
if
mapping_
target
s
_file
==
None
:
file_paths
=
[[
os
.
path
.
join
(
data_directory
,
volume
,
mapping_data_file
)]
for
volume
in
volumes_to_be_used
]
else
:
file_paths
=
[[
os
.
path
.
join
(
data_directory
,
volume
,
mapping_data_file
),
os
.
path
.
join
(
targets_directory
,
volum
e
)]
for
volume
in
volumes_to_be_used
]
file_paths
=
[[
os
.
path
.
join
(
data_directory
,
volume
,
mapping_data_file
),
os
.
path
.
join
(
data_directory
,
volume
,
mapping_targets_fil
e
)]
for
volume
in
volumes_to_be_used
]
return
file_paths
,
volumes_to_be_used
...
...
@@ -158,10 +158,10 @@ def load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag):
Args:
file_path (str): Path to the desired file
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_y2y_flag
cross_domain_y2y_flag
(bool): Flag indicating if cross-domain training is occuring between the targets
Returns:
volume (np.array): Array of training image data
of data type dtype.
volume (np.array): Array of training image data
header (class): 'nibabel.nifti1.Nifti1Header' class object, containing image metadata
xform (np.array): Array of shape (4, 4), containing the adjusted voxel-to-world transformation for the spatial dimensions of the resampled data
...
...
@@ -194,27 +194,35 @@ def load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag):
return
volume
,
header
,
xform
def
load_and_preprocess_targets
(
target
_path
,
mean_mask_path
):
def
load_and_preprocess_targets
(
file
_path
,
crop_flag
,
cross_domain_x2x_flag
):
"""Load & Preprocessing targets before evaluation
This function loads a nifty file and returns its volume
, a de-meaned volume and header
information
This function loads a nifty file and returns its volume information
Args:
file_path (str): Path to the desired target file
mean_mask_path (str): Path to the dualreg subject mean mask
file_path (str): Path to the desired file
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
Returns:
target (np.array): Array of training image data of data type dtype.
target_demeaned (np.array): Array of training data from which the group mean has been subtracted
volume (np.array): Array of target image intensities.
Raises:
ValueError: "Orientation value is invalid. It must be either >>coronal<<, >>axial<< or >>sagital<< "
"""
"""
original_image
=
Image
(
file_path
[
1
])
target
=
Image
(
target_path
[
0
]).
data
[:,
:,
:,
0
]
target_demeaned
=
np
.
subtract
(
target
,
Image
(
mean_mask_path
).
data
[:,
:,
:,
0
])
if
cross_domain_x2x_flag
==
True
:
if
crop_flag
==
False
:
volume
,
_
=
resampleToPixdims
(
original_image
,
(
2
,
2
,
2
))
elif
crop_flag
==
True
:
resampled
,
xform
=
resampleToPixdims
(
original_image
,
(
2
,
2
,
2
))
volume
=
roi
(
Image
(
resampled
,
header
=
original_image
.
header
,
xform
=
xform
),((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
else
:
if
crop_flag
==
False
:
volume
=
original_image
.
data
[:,
:,
:,
0
]
elif
crop_flag
==
True
:
volume
=
roi
(
original_image
,((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
[:,
:,
:,
0
]
return
target
,
target_demeaned
return
volume
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment