Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
15ffbc18
Commit
15ffbc18
authored
May 07, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
added input/output scaling, fixed overfitting network save, delete old checkpoints
parent
7c0603da
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
.gitignore
View file @
15ffbc18
...
...
@@ -119,4 +119,5 @@ dmypy.json
datasets/
files.txt
jobscript.sge.sh
*.nii.gz
\ No newline at end of file
*.nii.gz
stuff/
\ No newline at end of file
BrainMapperUNet.py
View file @
15ffbc18
...
...
@@ -20,8 +20,8 @@ import torch.nn as nn
import
utils.modules
as
modules
class
BrainMapper
CompRes
UNet3D
(
nn
.
Module
):
"""Architecture class for
Competitive Residual DenseBlock
BrainMapper 3D U-net.
class
BrainMapperUNet3D
(
nn
.
Module
):
"""Architecture class for
Traditional
BrainMapper 3D U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
...
...
@@ -47,25 +47,41 @@ class BrainMapperCompResUNet3D(nn.Module):
"""
def
__init__
(
self
,
parameters
):
super
(
BrainMapper
CompRes
UNet3D
,
self
).
__init__
()
super
(
BrainMapperUNet3D
,
self
).
__init__
()
original_input_channels
=
parameters
[
'input_channels'
]
original_output_channels
=
parameters
[
'output_channels'
]
self
.
encoderBlock1
=
modules
.
InCompDens
EncoderBlock3D
(
parameters
)
self
.
encoderBlock1
=
modules
.
EncoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
self
.
encoderBlock2
=
modules
.
CompDensEncoderBlock3D
(
parameters
)
self
.
encoderBlock3
=
modules
.
CompDensEncoderBlock3D
(
parameters
)
self
.
encoderBlock4
=
modules
.
CompDensEncoderBlock3D
(
parameters
)
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
*
2
self
.
encoderBlock2
=
modules
.
EncoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
*
2
self
.
encoderBlock3
=
modules
.
EncoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
*
2
self
.
encoderBlock4
=
modules
.
EncoderBlock3D
(
parameters
)
self
.
bottleneck
=
modules
.
CompDensBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
*
2
self
.
bottleneck
=
modules
.
ConvolutionalBlock3D
(
parameters
)
self
.
decoderBlock1
=
modules
.
CompDensDecoderBlock3D
(
parameters
)
self
.
decoderBlock2
=
modules
.
CompDensDecoderBlock3D
(
parameters
)
self
.
decoderBlock3
=
modules
.
CompDensDecoderBlock3D
(
parameters
)
self
.
decoderBlock4
=
modules
.
CompDensDecoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
//
2
self
.
decoderBlock1
=
modules
.
DecoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
//
2
self
.
decoderBlock2
=
modules
.
DecoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
//
2
self
.
decoderBlock3
=
modules
.
DecoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
//
2
self
.
decoderBlock4
=
modules
.
DecoderBlock3D
(
parameters
)
self
.
classifier
=
modules
.
DensClassifierBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
self
.
classifier
=
modules
.
ClassifierBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
original_input_channels
parameters
[
'output_channels'
]
=
original_output_channels
...
...
@@ -211,9 +227,8 @@ class BrainMapperCompResUNet3D(nn.Module):
print
(
"Initialized network parameters!"
)
class
BrainMapperResUNet3Dshallow
(
nn
.
Module
):
"""Architecture class for Residual DenseBlock BrainMapper 3D U-net.
class
BrainMapperCompResUNet3D
(
nn
.
Module
):
"""Architecture class for Competitive Residual DenseBlock BrainMapper 3D U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
...
...
@@ -239,25 +254,25 @@ class BrainMapperResUNet3Dshallow(nn.Module):
"""
def
__init__
(
self
,
parameters
):
super
(
BrainMapperResUNet3D
shallow
,
self
).
__init__
()
super
(
BrainMapper
Comp
ResUNet3D
,
self
).
__init__
()
original_input_channels
=
parameters
[
'input_channels'
]
original_output_channels
=
parameters
[
'output_channels'
]
self
.
encoderBlock1
=
modules
.
DensEncoderBlock3D
(
parameters
)
self
.
encoderBlock1
=
modules
.
InComp
DensEncoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
self
.
encoderBlock2
=
modules
.
DensEncoderBlock3D
(
parameters
)
self
.
encoderBlock3
=
modules
.
DensEncoderBlock3D
(
parameters
)
self
.
encoderBlock2
=
modules
.
CompDensEncoderBlock3D
(
parameters
)
self
.
encoderBlock3
=
modules
.
CompDensEncoderBlock3D
(
parameters
)
self
.
encoderBlock4
=
modules
.
CompDensEncoderBlock3D
(
parameters
)
self
.
bottleneck
=
modules
.
DensBlock3D
(
parameters
)
self
.
bottleneck
=
modules
.
Comp
DensBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
*
2
self
.
decoderBlock
1
=
modules
.
DensDecoderBlock3D
(
parameters
)
self
.
decoderBlock
2
=
modules
.
DensDecoderBlock3D
(
parameters
)
self
.
decoderBlock
3
=
modules
.
DensDecoderBlock3D
(
parameters
)
self
.
decoderBlock1
=
modules
.
CompDensDecoderBlock3D
(
parameters
)
self
.
decoderBlock
2
=
modules
.
Comp
DensDecoderBlock3D
(
parameters
)
self
.
decoderBlock
3
=
modules
.
Comp
DensDecoderBlock3D
(
parameters
)
self
.
decoderBlock
4
=
modules
.
Comp
DensDecoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
self
.
classifier
=
modules
.
DensClassifierBlock3D
(
parameters
)
self
.
classifier
=
modules
.
CompDensClassifierBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
original_input_channels
parameters
[
'output_channels'
]
=
original_output_channels
...
...
@@ -286,28 +301,38 @@ class BrainMapperResUNet3Dshallow(nn.Module):
del
Y_encoder_2
Y_bottleNeck
=
self
.
bottleneck
.
forward
(
Y_encoder_3
)
Y_encoder_4
,
Y_np4
,
_
=
self
.
encoderBlock4
.
forward
(
Y_encoder_3
)
del
Y_encoder_3
Y_bottleNeck
=
self
.
bottleneck
.
forward
(
Y_encoder_4
)
del
Y_encoder_4
Y_decoder_1
=
self
.
decoderBlock1
.
forward
(
Y_bottleNeck
,
Y_np
3
)
Y_bottleNeck
,
Y_np
4
)
del
Y_bottleNeck
,
Y_np
3
del
Y_bottleNeck
,
Y_np
4
Y_decoder_2
=
self
.
decoderBlock2
.
forward
(
Y_decoder_1
,
Y_np
2
)
Y_decoder_1
,
Y_np
3
)
del
Y_decoder_1
,
Y_np
2
del
Y_decoder_1
,
Y_np
3
Y_decoder_3
=
self
.
decoderBlock3
.
forward
(
Y_decoder_2
,
Y_np
1
)
Y_decoder_2
,
Y_np
2
)
del
Y_decoder_2
,
Y_np
1
del
Y_decoder_2
,
Y_np
2
probability_map
=
self
.
classifier
.
forward
(
Y_decoder_3
)
Y_decoder_4
=
self
.
decoderBlock4
.
forward
(
Y_decoder_3
,
Y_np1
)
del
Y_decoder_3
del
Y_decoder_3
,
Y_np1
probability_map
=
self
.
classifier
.
forward
(
Y_decoder_4
)
del
Y_decoder_4
return
probability_map
...
...
@@ -393,7 +418,7 @@ class BrainMapperResUNet3Dshallow(nn.Module):
print
(
"Initialized network parameters!"
)
class
BrainMapperResUNet3D
(
nn
.
Module
):
class
BrainMapperResUNet3D
shallow
(
nn
.
Module
):
"""Architecture class for Residual DenseBlock BrainMapper 3D U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
...
...
@@ -420,7 +445,7 @@ class BrainMapperResUNet3D(nn.Module):
"""
def
__init__
(
self
,
parameters
):
super
(
BrainMapperResUNet3D
,
self
).
__init__
()
super
(
BrainMapperResUNet3D
shallow
,
self
).
__init__
()
original_input_channels
=
parameters
[
'input_channels'
]
original_output_channels
=
parameters
[
'output_channels'
]
...
...
@@ -429,7 +454,6 @@ class BrainMapperResUNet3D(nn.Module):
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
self
.
encoderBlock2
=
modules
.
DensEncoderBlock3D
(
parameters
)
self
.
encoderBlock3
=
modules
.
DensEncoderBlock3D
(
parameters
)
self
.
encoderBlock4
=
modules
.
DensEncoderBlock3D
(
parameters
)
self
.
bottleneck
=
modules
.
DensBlock3D
(
parameters
)
...
...
@@ -437,7 +461,6 @@ class BrainMapperResUNet3D(nn.Module):
self
.
decoderBlock1
=
modules
.
DensDecoderBlock3D
(
parameters
)
self
.
decoderBlock2
=
modules
.
DensDecoderBlock3D
(
parameters
)
self
.
decoderBlock3
=
modules
.
DensDecoderBlock3D
(
parameters
)
self
.
decoderBlock4
=
modules
.
DensDecoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
self
.
classifier
=
modules
.
DensClassifierBlock3D
(
parameters
)
...
...
@@ -469,38 +492,28 @@ class BrainMapperResUNet3D(nn.Module):
del
Y_encoder_2
Y_encoder_4
,
Y_np4
,
_
=
self
.
encoderBlock4
.
forward
(
Y_encoder_3
)
Y_bottleNeck
=
self
.
bottleneck
.
forward
(
Y_encoder_3
)
del
Y_encoder_3
Y_bottleNeck
=
self
.
bottleneck
.
forward
(
Y_encoder_4
)
del
Y_encoder_4
Y_decoder_1
=
self
.
decoderBlock1
.
forward
(
Y_bottleNeck
,
Y_np
4
)
Y_bottleNeck
,
Y_np
3
)
del
Y_bottleNeck
,
Y_np
4
del
Y_bottleNeck
,
Y_np
3
Y_decoder_2
=
self
.
decoderBlock2
.
forward
(
Y_decoder_1
,
Y_np
3
)
Y_decoder_1
,
Y_np
2
)
del
Y_decoder_1
,
Y_np
3
del
Y_decoder_1
,
Y_np
2
Y_decoder_3
=
self
.
decoderBlock3
.
forward
(
Y_decoder_2
,
Y_np2
)
del
Y_decoder_2
,
Y_np2
Y_decoder_4
=
self
.
decoderBlock4
.
forward
(
Y_decoder_3
,
Y_np1
)
Y_decoder_2
,
Y_np1
)
del
Y_decoder_
3
,
Y_np1
del
Y_decoder_
2
,
Y_np1
probability_map
=
self
.
classifier
.
forward
(
Y_decoder_
4
)
probability_map
=
self
.
classifier
.
forward
(
Y_decoder_
3
)
del
Y_decoder_
4
del
Y_decoder_
3
return
probability_map
...
...
@@ -586,8 +599,8 @@ class BrainMapperResUNet3D(nn.Module):
print
(
"Initialized network parameters!"
)
class
BrainMapperUNet3D
(
nn
.
Module
):
"""Architecture class for
Traditional
BrainMapper 3D U-net.
class
BrainMapper
Res
UNet3D
(
nn
.
Module
):
"""Architecture class for
Residual DenseBlock
BrainMapper 3D U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
...
...
@@ -613,41 +626,27 @@ class BrainMapperUNet3D(nn.Module):
"""
def
__init__
(
self
,
parameters
):
super
(
BrainMapperUNet3D
,
self
).
__init__
()
super
(
BrainMapper
Res
UNet3D
,
self
).
__init__
()
original_input_channels
=
parameters
[
'input_channels'
]
original_output_channels
=
parameters
[
'output_channels'
]
self
.
encoderBlock1
=
modules
.
EncoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
*
2
self
.
encoderBlock2
=
modules
.
EncoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
*
2
self
.
encoderBlock3
=
modules
.
EncoderBlock3D
(
parameters
)
self
.
encoderBlock1
=
modules
.
DensEncoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
*
2
self
.
encoderBlock4
=
modules
.
EncoderBlock3D
(
parameters
)
self
.
encoderBlock2
=
modules
.
DensEncoderBlock3D
(
parameters
)
self
.
encoderBlock3
=
modules
.
DensEncoderBlock3D
(
parameters
)
self
.
encoderBlock4
=
modules
.
DensEncoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
*
2
self
.
bottleneck
=
modules
.
ConvolutionalBlock3D
(
parameters
)
self
.
bottleneck
=
modules
.
DensBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
//
2
self
.
decoderBlock1
=
modules
.
DecoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
//
2
self
.
decoderBlock2
=
modules
.
DecoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
//
2
self
.
decoderBlock3
=
modules
.
DecoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
//
2
self
.
decoderBlock4
=
modules
.
DecoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
*
2
self
.
decoderBlock1
=
modules
.
DensDecoderBlock3D
(
parameters
)
self
.
decoderBlock2
=
modules
.
DensDecoderBlock3D
(
parameters
)
self
.
decoderBlock3
=
modules
.
DensDecoderBlock3D
(
parameters
)
self
.
decoderBlock4
=
modules
.
DensDecoderBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
self
.
classifier
=
modules
.
ClassifierBlock3D
(
parameters
)
self
.
classifier
=
modules
.
Dens
ClassifierBlock3D
(
parameters
)
parameters
[
'input_channels'
]
=
original_input_channels
parameters
[
'output_channels'
]
=
original_output_channels
...
...
run.py
View file @
15ffbc18
...
...
@@ -149,11 +149,10 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
BrainMapperModel
=
torch
.
load
(
training_parameters
[
'pre_trained_path'
])
else
:
#
BrainMapperModel = BrainMapperUNet3D(network_parameters)
BrainMapperModel
=
BrainMapperUNet3D
(
network_parameters
)
# BrainMapperModel = BrainMapperResUNet3D(network_parameters)
# BrainMapperModel = BrainMapperResUNet3Dshallow(network_parameters)
BrainMapperModel
=
BrainMapperCompResUNet3D
(
network_parameters
)
# BrainMapperModel = BrainMapperCompResUNet3D(network_parameters)
BrainMapperModel
.
reset_parameters
()
...
...
@@ -178,20 +177,13 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
use_last_checkpoint
=
training_parameters
[
'use_last_checkpoint'
],
experiment_directory
=
misc_parameters
[
'experiments_directory'
],
logs_directory
=
misc_parameters
[
'logs_directory'
],
checkpoint_directory
=
misc_parameters
[
'checkpoint_directory'
]
checkpoint_directory
=
misc_parameters
[
'checkpoint_directory'
],
save_model_directory
=
misc_parameters
[
'save_model_directory'
],
final_model_output_file
=
training_parameters
[
'final_model_output_file'
]
)
validation_loss
=
solver
.
train
(
train_loader
,
validation_loader
)
model_output_path
=
os
.
path
.
join
(
misc_parameters
[
'save_model_directory'
],
training_parameters
[
'final_model_output_file'
])
create_folder
(
misc_parameters
[
'save_model_directory'
])
BrainMapperModel
.
save
(
model_output_path
)
print
(
"Final Model Saved in: {}"
.
format
(
model_output_path
))
del
train_data
,
validation_data
,
train_loader
,
validation_loader
,
BrainMapperModel
,
solver
,
optimizer
torch
.
cuda
.
empty_cache
()
...
...
@@ -283,7 +275,7 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva
prediction_output_path
=
prediction_output_path
,
device
=
misc_parameters
[
'device'
],
LogWriter
=
logWriter
)
)
logWriter
.
close
()
...
...
@@ -298,7 +290,6 @@ def evaluate_mapping(mapping_evaluation_parameters):
mapping_evaluation_parameters = {
'trained_model_path': 'path/to/model'
'data_directory': 'path/to/data'
'mapping_data_file': 'path/to/file'
'data_list': 'path/to/datalist.txt/
'prediction_output_path': 'directory-of-saved-predictions'
'batch_size': 2
...
...
@@ -317,6 +308,7 @@ def evaluate_mapping(mapping_evaluation_parameters):
brain_mask_path
=
mapping_evaluation_parameters
[
'brain_mask_path'
]
mean_mask_path
=
mapping_evaluation_parameters
[
'mean_mask_path'
]
mean_reduction
=
mapping_evaluation_parameters
[
'mean_reduction'
]
scaling_factors
=
mapping_evaluation_parameters
[
'scaling_factors'
]
evaluations
.
evaluate_mapping
(
trained_model_path
,
data_directory
,
...
...
@@ -326,6 +318,7 @@ def evaluate_mapping(mapping_evaluation_parameters):
brain_mask_path
,
mean_mask_path
,
mean_reduction
,
scaling_factors
,
device
=
device
,
exit_on_error
=
exit_on_error
)
...
...
@@ -369,12 +362,19 @@ if __name__ == '__main__':
# Here we shuffle the data!
if
data_parameters
[
'data_split_flag'
]
==
True
:
print
(
'Data is shuffling... This could take a few minutes!'
)
if
data_parameters
[
'data_split_flag'
]
==
True
:
if
data_parameters
[
'use_data_file'
]
==
True
:
data_test_train_validation_split
(
data_parameters
[
'data_folder_name'
],
data_parameters
[
'test_percentage'
],
data_parameters
[
'subject_number'
],
data_directory
=
data_parameters
[
'data_directory'
],
train_inputs
=
data_parameters
[
'train_data_file'
],
train_targets
=
data_parameters
[
'train_output_targets'
],
mean_mask_path
=
data_parameters
[
'mean_mask_path'
],
data_file
=
data_parameters
[
'data_file'
],
K_fold
=
data_parameters
[
'k_fold'
]
)
...
...
@@ -383,10 +383,15 @@ if __name__ == '__main__':
data_parameters
[
'test_percentage'
],
data_parameters
[
'subject_number'
],
data_directory
=
data_parameters
[
'data_directory'
],
train_inputs
=
data_parameters
[
'train_data_file'
],
train_targets
=
data_parameters
[
'train_output_targets'
],
mean_mask_path
=
data_parameters
[
'mean_mask_path'
],
K_fold
=
data_parameters
[
'k_fold'
]
)
update_shuffling_flag
(
'settings.ini'
)
print
(
'Data is shuffling... Complete!'
)
if
arguments
.
mode
==
'train'
:
train
(
data_parameters
,
training_parameters
,
network_parameters
,
misc_parameters
)
...
...
settings.ini
View file @
15ffbc18
...
...
@@ -10,6 +10,7 @@ subject_number = None
train_list
=
"datasets/train.txt"
validation_list
=
"datasets/validation.txt"
test_list
=
"datasets/test.txt"
scaling_factors
=
"datasets/scaling_factors.pkl"
train_data_file
=
"dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
train_output_targets
=
"fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
validation_data_file
=
"dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
...
...
settings_evaluation.ini
View file @
15ffbc18
...
...
@@ -6,6 +6,7 @@ data_list = "datasets/test.txt"
prediction_output_path
=
"network_predictions"
brain_mask_path
=
"utils/MNI152_T1_2mm_brain_mask.nii.gz"
mean_mask_path
=
"utils/mean_dr_stage2.nii.gz"
scaling_factors
=
"datasets/scaling_factors.pkl"
mean_reduction
=
True
device
=
0
exit_on_error
=
True
\ No newline at end of file
exit_on_error
=
True
solver.py
View file @
15ffbc18
...
...
@@ -72,7 +72,9 @@ class Solver():
use_last_checkpoint
=
True
,
experiment_directory
=
'experiments'
,
logs_directory
=
'logs'
,
checkpoint_directory
=
'checkpoints'
checkpoint_directory
=
'checkpoints'
,
save_model_directory
=
'saved_models'
,
final_model_output_file
=
'finetuned_alldata.pth.tar'
):
self
.
model
=
model
...
...
@@ -125,6 +127,9 @@ class Solver():
self
.
MNI152_T1_2mm_brain_mask
=
torch
.
from_numpy
(
Image
(
'utils/MNI152_T1_2mm_brain_mask.nii.gz'
).
data
)
self
.
save_model_directory
=
save_model_directory
self
.
final_model_output_file
=
final_model_output_file
def
train
(
self
,
train_loader
,
validation_loader
):
"""Training Function
...
...
@@ -145,6 +150,8 @@ class Solver():
torch
.
cuda
.
empty_cache
()
# clear memory
model
.
cuda
(
self
.
device
)
# Moving the model to GPU
previous_checkpoint
=
None
print
(
'****************************************************************'
)
print
(
'TRAINING IS STARTING!'
)
print
(
'====================='
)
...
...
@@ -231,6 +238,8 @@ class Solver():
self
.
early_stop
=
early_stop
if
save_checkpoint
==
True
:
validation_loss
=
np
.
mean
(
losses
)
checkpoint_name
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
self
.
checkpoint_directory
,
'checkpoint_epoch_'
+
str
(
epoch
)
+
'.'
+
checkpoint_extension
)
self
.
save_checkpoint
(
state
=
{
'epoch'
:
epoch
+
1
,
'start_iteration'
:
iteration
+
1
,
'arch'
:
self
.
model_name
,
...
...
@@ -238,12 +247,14 @@ class Solver():
'optimizer'
:
optimizer
.
state_dict
(),
'scheduler'
:
learning_rate_scheduler
.
state_dict
()
},
filename
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
self
.
checkpoint_directory
,
'checkpoint_epoch_'
+
str
(
epoch
)
+
'.'
+
checkpoint_extension
)
filename
=
checkpoint_name
)
# if epoch != self.start_epoch:
# os.remove(os.path.join(self.experiment_directory_path, self.checkpoint_directory,
# 'checkpoint_epoch_' + str(epoch-1) + '.' + checkpoint_extension))
if
previous_checkpoint
!=
None
:
os
.
remove
(
previous_checkpoint
)
previous_checkpoint
=
checkpoint_name
else
:
previous_checkpoint
=
checkpoint_name
if
phase
==
'train'
:
learning_rate_scheduler
.
step
()
...
...
@@ -254,10 +265,18 @@ class Solver():
if
self
.
early_stop
==
True
:
print
(
"ATTENTION!: Training stopped early to prevent overfitting!"
)
self
.
load_checkpoint
()
break
else
:
continue
model_output_path
=
os
.
path
.
join
(
self
.
save_model_directory
,
self
.
final_model_output_file
)
create_folder
(
self
.
save_model_directory
)
model
.
save
(
model_output_path
)
self
.
LogWriter
.
close
()
print
(
'----------------------------------------'
)
...
...
@@ -266,6 +285,7 @@ class Solver():
end_time
=
datetime
.
now
()
print
(
'Completed At: {}'
.
format
(
end_time
))
print
(
'Training Duration: {}'
.
format
(
end_time
-
start_time
))
print
(
'Final Model Saved in: {}'
.
format
(
model_output_path
))
print
(
'****************************************************************'
)
return
validation_loss
...
...
utils/data_evaluation_utils.py
View file @
15ffbc18
...
...
@@ -15,6 +15,7 @@ TODO: Might be worth adding some information on uncertaintiy estimation, later d
"""
import
os
import
pickle
import
numpy
as
np
import
torch
import
logging
...
...
@@ -209,6 +210,7 @@ def evaluate_mapping(trained_model_path,
brain_mask_path
,
mean_mask_path
,
mean_reduction
,
scaling_factors
,
device
=
0
,
mode
=
'evaluate'
,
exit_on_error
=
False
):
...
...
@@ -225,6 +227,7 @@ def evaluate_mapping(trained_model_path,
brain_mask_path (str): Path to the MNI brain mask file
mean_mask_path (str): Path to the dualreg subject mean mask
mean_reduction (bool): Flag indicating if the targets should be de-meaned using the mean_mask_path
scaling_factors (str): Path to the scaling factor file
device (str/int): Device type used for training (int - GPU id, str- CPU)
mode (str): Current run mode or phase
exit_on_error (bool): Flag that triggers the raising of an exception
...
...
@@ -276,7 +279,7 @@ def evaluate_mapping(trained_model_path,
print
(
"Mapping Volume {}/{}"
.
format
(
volume_index
+
1
,
len
(
file_paths
)))
# Generate volume & header
predicted_complete_volume
,
predicted_volume
,
header
,
xform
=
_generate_volume_map
(
file_path
,
model
,
device
,
cuda_available
,
brain_mask_path
,
mean_mask_path
,
mean_reduction
)
file_path
,
model
,
device
,
cuda_available
,
brain_mask_path
,
mean_mask_path
,
scaling_factors
,
mean_reduction
)
# Generate New Header Affine
...
...
@@ -298,12 +301,14 @@ def evaluate_mapping(trained_model_path,
output_complete_nifti_image
=
Image
(
predicted_complete_volume
,
header
=
header
,
xform
=
xform
)
output_complete_nifti_path
=
output_nifti_path
+
'_complete'
output_complete_nifti_path
=
os
.
path
.
join
(
prediction_output_path
,
volumes_to_be_used
[
volume_index
])
+
'_complete'
if
'.nii'
not
in
output_complete_nifti_path
:
output_complete_nifti_path
+=
'.nii.gz'
output_complete_nifti_image
.
save
(
output_complete_nifti_path
)
output_complete_nifti_image
.
save
(
output_complete_nifti_path
)
log
.
info
(
"Processed: "
+
volumes_to_be_used
[
volume_index
]
+
" "
+
str
(
volume_index
+
1
)
+
" out of "
+
str
(
len
(
volumes_to_be_used
)))
...
...
@@ -323,7 +328,7 @@ def evaluate_mapping(trained_model_path,
log
.
info
(
"rsfMRI Generation Complete"
)
def
_generate_volume_map
(
file_path
,
model
,
device
,
cuda_available
,
brain_mask_path
,
mean_mask_path
,
mean_reduction
=
False
):
def
_generate_volume_map
(
file_path
,
model
,
device
,
cuda_available
,
brain_mask_path
,
mean_mask_path
,
scaling_factors
,
mean_reduction
=
False
):
"""rsfMRI Volume Generator
This function uses the trained model to generate a new volume
...
...
@@ -335,6 +340,7 @@ def _generate_volume_map(file_path, model, device, cuda_available, brain_mask_pa
cuda_available (bool): Flag indicating if a cuda-enabled GPU is present
brain_mask_path (str): Path to the MNI brain mask file
mean_mask_path (str): Path to the dualreg subject mean mask
scaling_factors (str): Path to the scaling factor file
mean_reduction (bool): Flag indicating if the targets should be de-meaned using the mean_mask_path
Returns
...
...
@@ -345,41 +351,87 @@ def _generate_volume_map(file_path, model, device, cuda_available, brain_mask_pa
volume
,
header
,
xform
=
data_utils
.
load_and_preprocess_evaluation
(
file_path
)
if
len
(
volume
.
shape
)
==
4
:
if
len
(
volume
.
shape
)
==
5
:
volume
=
volume
else
:
volume
=
volume
[
np
.
newaxis
,
np
.
newaxis
,
:,
:,
:]
volume
=
_scale_input
(
volume
,
scaling_factors
)
volume
=
torch
.
tensor
(
volume
).
type
(
torch
.
FloatTensor
)
MNI152_T1_2mm_brain_mask
=
torch
.
from_numpy
(
Image
(
brain_mask_path
).
data
)
if
mean_reduction
==
True
:
mean_mask
=
torch
.
from_numpy
(
Image
(
mean_mask_path
).
data
[:,
:,
:,
0
])
if
cuda_available
and
(
type
(
device
)
==
int
):
volume
=
volume
.
cuda
(
device
)
MNI152_T1_2mm_brain_mask
=
MNI152_T1_2mm_brain_mask
.
cuda
(
device
)
if
mean_reduction
==
True
:
mean_mask
=
mean_mask
.
cuda
(
device
)
output
=
model
(
volume
)
output
=
torch
.
mul
(
output
,
MNI152_T1_2mm_brain_mask
)
output
=
(
output
.
cpu
().
numpy
()).
astype
(
'float32'
)
output
=
np
.
squeeze
(
output
)
output
=
_rescale_output
(
output
,
scaling_factors
)
if
mean_reduction
==
True
:
predicted_complete_volume
=
torch
.
add
(
output
,
mean_mask
)
predicted_complete_volume
=
(
predicted_complete_volume
.
cpu
().
numpy
()).
astype
(
'float32'
)
predicted_complete_volume
=
np
.
squeeze
(
predicted_complete_volume
)
MNI152_T1_2mm_brain_mask
=
Image
(
brain_mask_path
).
data
if
mean_reduction
==
True
:
mean_mask
=
Image
(
mean_mask_path
).
data
[:,
:,
:,
0
]