Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
7c0603da
Commit
7c0603da
authored
May 05, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
fixed data logging bugs
parent
12ce71d4
Changes
3
Show whitespace changes
Inline
Side-by-side
run.py
View file @
7c0603da
...
...
@@ -277,8 +277,7 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva
_
=
evaluations
.
evaluate_dice_score
(
trained_model_path
=
evaluation_parameters
[
'trained_model_path'
],
number_of_classes
=
network_parameters
[
'number_of_classes'
],
data_directory
=
evaluation_parameters
[
'data_directory'
],
targets_directory
=
evaluation_parameters
[
'targets_directory'
],
targets_directory
=
evaluation_parameters
[
'targets_directory'
],
data_list
=
evaluation_parameters
[
'data_list'
],
orientation
=
evaluation_parameters
[
'orientation'
],
prediction_output_path
=
prediction_output_path
,
...
...
solver.py
View file @
7c0603da
...
...
@@ -122,7 +122,8 @@ class Solver():
if
use_last_checkpoint
:
self
.
load_checkpoint
()
self
.
MNI152_T1_2mm_brain_mask
=
torch
.
from_numpy
(
Image
(
'utils/MNI152_T1_2mm_brain_mask.nii.gz'
).
data
)
self
.
MNI152_T1_2mm_brain_mask
=
torch
.
from_numpy
(
Image
(
'utils/MNI152_T1_2mm_brain_mask.nii.gz'
).
data
)
def
train
(
self
,
train_loader
,
validation_loader
):
"""Training Function
...
...
@@ -174,10 +175,6 @@ class Solver():
for
batch_index
,
sampled_batch
in
enumerate
(
dataloaders
[
phase
]):
X
=
sampled_batch
[
0
].
type
(
torch
.
FloatTensor
)
# X = ( X - X.min() ) / ( X.max() - X.min() )
# X = ( X - X.mean() ) / X.std()
y
=
sampled_batch
[
1
].
type
(
torch
.
FloatTensor
)
# We add an extra dimension (~ number of channels) for the 3D convolutions.
...
...
@@ -189,7 +186,8 @@ class Solver():
if
model
.
test_if_cuda
:
X
=
X
.
cuda
(
self
.
device
,
non_blocking
=
True
)
y
=
y
.
cuda
(
self
.
device
,
non_blocking
=
True
)
MNI152_T1_2mm_brain_mask
=
MNI152_T1_2mm_brain_mask
.
cuda
(
self
.
device
,
non_blocking
=
True
)
MNI152_T1_2mm_brain_mask
=
MNI152_T1_2mm_brain_mask
.
cuda
(
self
.
device
,
non_blocking
=
True
)
y_hat
=
model
(
X
)
# Forward pass & Masking
...
...
utils/data_logging_utils.py
View file @
7c0603da
...
...
@@ -82,30 +82,6 @@ class LogWriter():
"{}/{}.log"
.
format
(
os
.
path
.
join
(
logs_directory
,
experiment_name
),
"console_logs"
))
self
.
logger
.
addHandler
(
file_handler
)
def
labels_generator
(
self
,
labels
):
""" Label Generator Function
This function processess an input array of labels.
Args:
labels (arr): Vector/Array of labels (if applicable)
Returns:
label_classes (list): List of processed labels
"""
label_classes
=
[]
for
label
in
labels
:
label_class
=
re
.
sub
(
r
'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))'
,
r
'\1 '
,
label
)
label_class
=
[
'
\n
'
.
join
(
wrap
(
element
,
40
))
for
element
in
label_class
]
label_classes
.
append
(
label_class
)
return
label_classes
def
log
(
self
,
message
):
"""Log function
...
...
@@ -131,7 +107,7 @@ class LogWriter():
print
(
"Loss for Iteration {} is: {}"
.
format
(
batch_index
,
loss_per_iteration
))
self
.
log_writer
[
'train'
].
add_scalar
(
'loss
/
iteration'
,
loss_per_iteration
,
iteration
)
'loss
/
iteration'
,
loss_per_iteration
,
iteration
)
def
loss_per_epoch
(
self
,
losses
,
phase
,
epoch
):
"""Log function
...
...
@@ -150,31 +126,31 @@ class LogWriter():
loss
=
np
.
mean
(
losses
)
print
(
"Loss for Epoch {} of {} is: {}"
.
format
(
epoch
,
phase
,
loss
))
self
.
log_writer
[
phase
].
add_scalar
(
'loss
/ iteration
'
,
loss
,
epoch
)
self
.
log_writer
[
phase
].
add_scalar
(
'loss
/epoch
'
,
loss
,
epoch
)
# Currently, no confusion matrix is required
# TODO: add a confusion matrix per epoch and confusion matrix plot functions if required
def
close
(
self
):
"""Close the log writer
def
dice_score_per_epoch
(
self
,
phase
,
outputs
,
correct_labels
,
epoch
):
"""
Function calculating dice score for each epoch
This function closes the two log writers.
"""
This function computes the dice score for each epoch.
self
.
log_writer
[
'train'
].
close
()
self
.
log_writer
[
'validation'
].
close
()
def
add_graph
(
self
,
model
):
"""Produces network graph
This function produces the network graph
NOTE: Currently, the function suffers from bugs and is not implemented.
Args:
phase (str): Current run mode or phase
outputs (torch.tensor): Tensor of all the network outputs (Y-hat)
correct_labels (torch.tensor): Output ground-truth labelled data (Y)
epoch (int): Current epoch value
model (torch.nn.Module): Model to draw.
"""
print
(
"Dice Score is being calculated..."
,
end
=
''
,
flush
=
True
)
dice_score
=
evaluation
.
dice_score_calculator
(
outputs
,
correct_labels
,
self
.
number_of_classes
)
mean_dice_score
=
torch
.
mean
(
dice_score
)
self
.
plot_dice_score
(
dice_score
,
phase
,
plot_name
=
'dice_score_per_epoch'
,
title
=
'Dice Score'
,
epochs
=
epoch
)
print
(
"Dice Score calculated successfully"
)
return
mean_dice_score
.
item
()
self
.
log_writer
[
'train'
].
add_graph
(
model
)
# DEPRECATED / UNDEBUGGED FUNCTIONS
def
plot_dice_score
(
self
,
dice_score
,
phase
,
plot_name
,
title
,
epochs
=
None
):
"""Function plotting dice score for multiple epochs
...
...
@@ -208,7 +184,26 @@ class LogWriter():
else
:
self
.
log_writer
[
phase
].
add_figure
(
plot_name
+
'/'
+
phase
,
figure
)
# Currently, also no need for an evaluation box plot
def
dice_score_per_epoch
(
self
,
phase
,
outputs
,
correct_labels
,
epoch
):
"""Function calculating dice score for each epoch
This function computes the dice score for each epoch.
Args:
phase (str): Current run mode or phase
outputs (torch.tensor): Tensor of all the network outputs (Y-hat)
correct_labels (torch.tensor): Output ground-truth labelled data (Y)
epoch (int): Current epoch value
"""
print
(
"Dice Score is being calculated..."
,
end
=
''
,
flush
=
True
)
dice_score
=
evaluation
.
dice_score_calculator
(
outputs
,
correct_labels
,
self
.
number_of_classes
)
mean_dice_score
=
torch
.
mean
(
dice_score
)
self
.
plot_dice_score
(
dice_score
,
phase
,
plot_name
=
'dice_score_per_epoch'
,
title
=
'Dice Score'
,
epochs
=
epoch
)
print
(
"Dice Score calculated successfully"
)
return
mean_dice_score
.
item
()
def
sample_image_per_epoch
(
self
,
prediction
,
ground_truth
,
phase
,
epoch
):
"""Function plotting mirrored images
...
...
@@ -240,11 +235,26 @@ class LogWriter():
print
(
"Sample Image successfully loaded!"
)
def
close
(
self
):
"""
Close the log writer
def
labels_generator
(
self
,
labels
):
"""
Label Generator Function
This function closes the two log writers.
This function processess an input array of labels.
Args:
labels (arr): Vector/Array of labels (if applicable)
Returns:
label_classes (list): List of processed labels
"""
self
.
log_writer
[
'train'
].
close
()
self
.
log_writer
[
'validation'
].
close
()
label_classes
=
[]
for
label
in
labels
:
label_class
=
re
.
sub
(
r
'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))'
,
r
'\1 '
,
label
)
label_class
=
[
'
\n
'
.
join
(
wrap
(
element
,
40
))
for
element
in
label_class
]
label_classes
.
append
(
label_class
)
return
label_classes
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment