Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
7eacf841
Commit
7eacf841
authored
Apr 04, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
updated docstrings, pep8 formated, fixed bugs
parent
c442542d
Changes
9
Expand all
Show whitespace changes
Inline
Side-by-side
BrainMapperUNet.py
View file @
7eacf841
"""Brain Mapper U-Net Architecture
Description:
-------------
This folder contains the Pytorch implementation of the core U-net architecture.
This arcitecture predicts functional connectivity rsfMRI from structural connectivity information from dMRI.
Usage
-------------
To use this module, import it and instantiate is as you wish:
This folder contains the Pytorch implementation of the core U-net architecture.
This arcitecture predicts functional connectivity rsfMRI from structural connectivity information from dMRI.
Usage:
To use this module, import it and instantiate is as you wish:
from BrainMapperUNet import BrainMapperUNet
deep_learning_model = BrainMapperUnet(parameters)
...
...
@@ -19,6 +19,7 @@ import torch
import
torch.nn
as
nn
import
utils.modules
as
modules
class
BrainMapperUNet
(
nn
.
Module
):
"""Architecture class BrainMapper U-net.
...
...
@@ -42,9 +43,6 @@ class BrainMapperUNet(nn.Module):
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
Raises:
None
"""
def
__init__
(
self
,
parameters
):
...
...
@@ -80,23 +78,26 @@ class BrainMapperUNet(nn.Module):
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
Raises:
None
"""
Y_encoder_1
,
Y_np1
,
pool_indices1
=
self
.
encoderBlock1
.
forward
(
X
)
Y_encoder_2
,
Y_np2
,
pool_indices2
=
self
.
encoderBlock2
.
forward
(
Y_encoder_1
)
Y_encoder_3
,
Y_np3
,
pool_indices3
=
self
.
encoderBlock3
.
forward
(
Y_encoder_2
)
Y_encoder_4
,
Y_np4
,
pool_indices4
=
self
.
encoderBlock4
.
forward
(
Y_encoder_3
)
Y_encoder_2
,
Y_np2
,
pool_indices2
=
self
.
encoderBlock2
.
forward
(
Y_encoder_1
)
Y_encoder_3
,
Y_np3
,
pool_indices3
=
self
.
encoderBlock3
.
forward
(
Y_encoder_2
)
Y_encoder_4
,
Y_np4
,
pool_indices4
=
self
.
encoderBlock4
.
forward
(
Y_encoder_3
)
Y_bottleNeck
=
self
.
bottleneck
.
forward
(
Y_encoder_4
)
Y_decoder_1
=
self
.
decoderBlock1
.
forward
(
Y_bottleNeck
,
Y_np4
,
pool_indices4
)
Y_decoder_2
=
self
.
decoderBlock2
.
forward
(
Y_decoder_1
,
Y_np3
,
pool_indices3
)
Y_decoder_3
=
self
.
decoderBlock3
.
forward
(
Y_decoder_2
,
Y_np2
,
pool_indices2
)
Y_decoder_4
=
self
.
decoderBlock4
.
forwrad
(
Y_decoder_3
,
Y_np1
,
pool_indices1
)
Y_decoder_1
=
self
.
decoderBlock1
.
forward
(
Y_bottleNeck
,
Y_np4
,
pool_indices4
)
Y_decoder_2
=
self
.
decoderBlock2
.
forward
(
Y_decoder_1
,
Y_np3
,
pool_indices3
)
Y_decoder_3
=
self
.
decoderBlock3
.
forward
(
Y_decoder_2
,
Y_np2
,
pool_indices2
)
Y_decoder_4
=
self
.
decoderBlock4
.
forwrad
(
Y_decoder_3
,
Y_np1
,
pool_indices1
)
probability_map
=
self
.
classifier
.
forward
(
Y_decoder_4
)
...
...
@@ -110,12 +111,6 @@ class BrainMapperUNet(nn.Module):
Args:
path (str): Path string
Returns:
None
Raises:
None
"""
print
(
"Saving Model... {}"
.
format
(
path
))
...
...
@@ -127,19 +122,13 @@ class BrainMapperUNet(nn.Module):
This function tests if the model parameters are allocated to a CUDA enabled GPU.
Args:
None
Returns:
bool: Flag indicating True if the tensor is stored on the GPU and Flase otherwhise
Raises:
None
"""
return
next
(
self
.
parameters
()).
is_cuda
def
predict
(
self
,
X
,
device
=
0
):
def
predict
(
self
,
X
,
device
=
0
):
"""Post-training Output Prediction
This function predicts the output of the of the U-net post-training
...
...
@@ -151,9 +140,6 @@ class BrainMapperUNet(nn.Module):
Returns:
prediction (ndarray): predicted output after training
Raises:
None
"""
self
.
eval
()
# PyToch module setting network to evaluation mode
...
...
@@ -170,7 +156,8 @@ class BrainMapperUNet(nn.Module):
_
,
idx
=
torch
.
max
(
output
,
1
)
idx
=
idx
.
data
.
cpu
().
numpy
()
# We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
# We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
idx
=
idx
.
data
.
cpu
().
numpy
()
prediction
=
np
.
squeeze
(
idx
)
...
...
run.py
View file @
7eacf841
This diff is collapsed.
Click to expand it.
settings.py
View file @
7eacf841
...
...
@@ -64,9 +64,8 @@ def _parse_values(configurator):
settings_dictionary
=
{}
for
section
in
configurator
.
sections
():
settings_dictionary
[
section
]
=
{}
for
key
,
value
in
configurator
[
section
].
items
()
for
key
,
value
in
configurator
[
section
].
items
()
:
# Safely evaluate an expression node or a Unicode or Latin-1 encoded string containing a Python expression
settings_dictionary
[
section
][
key
]
=
ast
.
literal_eval
(
value
)
return
settings_dictionary
return
settings_dictionary
setup.py
View file @
7eacf841
...
...
@@ -16,6 +16,5 @@ setup(
'torch'
,
'h5py'
,
'tensorboardX'
,
],
)
solver.py
View file @
7eacf841
"""Brain Mapper U-Net Solver
Description:
-------------
This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
Usage
-------------
To use this module, import it and instantiate is as you wish:
This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
from solver import Solver
Usage:
To use this module, import it and instantiate is as you wish:
from solver import Solver
"""
import
os
...
...
@@ -26,6 +25,7 @@ from torch.optim import lr_scheduler
checkpoint_directory
=
'checkpoints'
checkpoint_extension
=
'path.tar'
class
Solver
():
"""Solver class for the BrainMapper U-net.
...
...
@@ -52,8 +52,6 @@ class Solver():
Returns:
trained model(?) - working on this!
Raises:
None
"""
def
__init__
(
self
,
...
...
@@ -61,18 +59,18 @@ class Solver():
device
,
number_of_classes
,
experiment_name
,
optimizer
=
torch
.
optim
.
Adam
,
optimizer_arguments
=
{},
loss_function
=
MSELoss
(),
model_name
=
'BrainMapper'
,
labels
=
None
,
number_epochs
=
10
,
loss_log_period
=
5
,
learning_rate_scheduler_step_size
=
5
,
learning_rate_scheduler_gamma
=
0.5
,
use_last_checkpoint
=
True
,
experiment_directory
=
'experiments'
,
logs_directory
=
'logs'
optimizer
=
torch
.
optim
.
Adam
,
optimizer_arguments
=
{},
loss_function
=
MSELoss
(),
model_name
=
'BrainMapper'
,
labels
=
None
,
number_epochs
=
10
,
loss_log_period
=
5
,
learning_rate_scheduler_step_size
=
5
,
learning_rate_scheduler_gamma
=
0.5
,
use_last_checkpoint
=
True
,
experiment_directory
=
'experiments'
,
logs_directory
=
'logs'
):
self
.
model
=
model
...
...
@@ -91,16 +89,18 @@ class Solver():
# We use a learning rate scheduler, that decays the LR of each paramter group by gamma every step_size epoch.
self
.
learning_rate_scheduler
=
lr_scheduler
.
StepLR
(
self
.
optimizer
,
step_size
=
learning_rate_scheduler_step_size
,
gamma
=
learning_rate_scheduler_gamma
)
step_size
=
learning_rate_scheduler_step_size
,
gamma
=
learning_rate_scheduler_gamma
)
self
.
use_last_checkpoint
=
use_last_checkpoint
experiment_directory_path
=
os
.
join
.
path
(
experiment_directory
,
experiment_name
)
experiment_directory_path
=
os
.
join
.
path
(
experiment_directory
,
experiment_name
)
self
.
experiment_directory_path
=
experiment_directory_path
create_folder
(
experiment_directory_path
)
create_folder
(
os
.
join
.
path
(
experiment_directory_path
,
checkpoint_directory
))
create_folder
(
os
.
join
.
path
(
experiment_directory_path
,
checkpoint_directory
))
self
.
start_epoch
=
1
self
.
start_iteration
=
1
...
...
@@ -110,12 +110,11 @@ class Solver():
if
use_last_checkpoint
:
self
.
load_checkpoint
()
self
.
LogWriter
=
LogWriter
(
number_of_classes
=
number_of_classes
,
logs_directory
=
logs_directory
,
experiment_name
=
experiment_name
,
use_last_checkpoint
=
use_last_checkpoint
,
labels
=
labels
)
self
.
LogWriter
=
LogWriter
(
number_of_classes
=
number_of_classes
,
logs_directory
=
logs_directory
,
experiment_name
=
experiment_name
,
use_last_checkpoint
=
use_last_checkpoint
,
labels
=
labels
)
def
train
(
self
,
train_loader
,
test_loader
):
"""Training Function
...
...
@@ -127,10 +126,7 @@ class Solver():
test_loader (class): Combined dataset and sampler, providing an iterable over the testing dataset (torch.utils.data.DataLoader)
Returns:
None: trained model
Raises:
None
trained model
"""
model
,
optimizer
,
learning_rate_scheduler
=
self
.
model
,
self
.
optimizer
,
self
.
learning_rate_scheduler
...
...
@@ -172,8 +168,8 @@ class Solver():
y
=
sampled_batch
[
1
].
type
(
torch
.
LongTensor
)
if
model
.
is_cuda
():
X
=
X
.
cuda
(
self
.
device
,
non_blocking
=
True
)
y
=
y
.
cuda
(
self
.
device
,
non_blocking
=
True
)
X
=
X
.
cuda
(
self
.
device
,
non_blocking
=
True
)
y
=
y
.
cuda
(
self
.
device
,
non_blocking
=
True
)
y_hat
=
model
(
X
)
# Forward pass
...
...
@@ -186,7 +182,8 @@ class Solver():
if
batch_index
%
self
.
loss_log_period
==
0
:
self
.
LogWriter
.
loss_per_iteration
(
loss
.
item
(),
batch_index
,
iteration
)
self
.
LogWriter
.
loss_per_iteration
(
loss
.
item
(),
batch_index
,
iteration
)
iteration
+=
1
...
...
@@ -206,25 +203,27 @@ class Solver():
print
(
"100%"
,
flush
=
True
)
with
torch
.
no_grad
():
output_array
,
y_array
=
torch
.
cat
(
outputs
),
torch
.
cat
(
y_values
)
output_array
,
y_array
=
torch
.
cat
(
outputs
),
torch
.
cat
(
y_values
)
self
.
LogWriter
.
loss_per_epoch
(
losses
,
phase
,
epoch
)
dice_score_mean
=
self
.
LogWriter
.
dice_score_per_epoch
(
phase
,
output_array
,
y_array
,
epoch
)
dice_score_mean
=
self
.
LogWriter
.
dice_score_per_epoch
(
phase
,
output_array
,
y_array
,
epoch
)
if
phase
==
'test'
:
if
dice_score_mean
>
self
.
best_mean_score
:
self
.
best_mean_score
=
dice_score_mean
self
.
best_mean_score_epoch
=
epoch
index
=
np
.
random
.
choice
(
len
(
dataloaders
[
phase
].
dataset
.
X
),
size
=
3
,
replace
=
False
)
self
.
LogWriter
.
sample_image_per_epoch
(
prediction
=
model
.
predict
(
dataloaders
[
phase
].
dataset
.
X
[
index
],
self
.
device
),
ground_truth
=
dataloaders
[
phase
].
dataset
.
y
[
index
],
phase
=
phase
,
epoch
=
epoch
)
index
=
np
.
random
.
choice
(
len
(
dataloaders
[
phase
].
dataset
.
X
),
size
=
3
,
replace
=
False
)
self
.
LogWriter
.
sample_image_per_epoch
(
prediction
=
model
.
predict
(
dataloaders
[
phase
].
dataset
.
X
[
index
],
self
.
device
),
ground_truth
=
dataloaders
[
phase
].
dataset
.
y
[
index
],
phase
=
phase
,
epoch
=
epoch
)
print
(
"Epoch {}/{} DONE!"
.
format
(
epoch
,
self
.
number_epochs
))
self
.
save_checkpoint
(
state
=
{
'epoch'
:
epoch
+
1
,
'start_iteration'
:
iteration
+
1
,
'arch'
:
self
.
model_name
,
...
...
@@ -232,7 +231,8 @@ class Solver():
'optimizer'
:
optimizer
.
state_dict
(),
'scheduler'
:
learning_rate_scheduler
.
state_dict
()
},
filename
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
checkpoint_directory
,
'checkpoint_epoch_'
+
str
(
epoch
)
+
'.'
+
checkpoint_extension
)
filename
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
checkpoint_directory
,
'checkpoint_epoch_'
+
str
(
epoch
)
+
'.'
+
checkpoint_extension
)
)
self
.
LogWriter
.
close
()
...
...
@@ -252,48 +252,38 @@ class Solver():
Args:
state (dict): Dictionary of all the relevant model components
Returns:
None
Raises:
None
"""
torch
.
save
(
state
,
filename
)
def
load_checkpoint
(
self
,
epoch
=
None
):
def
load_checkpoint
(
self
,
epoch
=
None
):
"""General Checkpoint Loader
This function loads a previous checkpoint for inference and/or resuming training
Args:
epoch (int): Current epoch value
Returns:
None
Raises:
None
"""
if
epoch
is
None
:
checkpoint_file_path
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
checkpoint_directory
,
'checkpoint_epoch_'
+
str
(
epoch
)
+
'.'
+
checkpoint_extension
)
checkpoint_file_path
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
checkpoint_directory
,
'checkpoint_epoch_'
+
str
(
epoch
)
+
'.'
+
checkpoint_extension
)
self
.
_checkpoint_reader
(
checkpoint_file_path
)
else
:
universal_path
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
checkpoint_directory
,
'*.'
+
checkpoint_extension
)
universal_path
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
checkpoint_directory
,
'*.'
+
checkpoint_extension
)
files_in_universal_path
=
glob
.
glob
(
universal_path
)
# We will sort through all the files in path to see which one is most recent
if
len
(
files_in_universal_path
)
>
0
:
checkpoint_file_path
=
max
(
files_in_universal_path
,
key
=
os
.
path
.
getatime
)
checkpoint_file_path
=
max
(
files_in_universal_path
,
key
=
os
.
path
.
getatime
)
self
.
_checkpoint_reader
(
checkpoint_file_path
)
else
:
self
.
LogWriter
.
log
(
"No Checkpoint found at {}"
.
format
(
os
.
path
.
join
(
self
.
experiment_directory_path
,
checkpoint_directory
)))
self
.
LogWriter
.
log
(
"No Checkpoint found at {}"
.
format
(
os
.
path
.
join
(
self
.
experiment_directory_path
,
checkpoint_directory
)))
def
_checkpoint_reader
(
self
,
checkpoint_file_path
):
"""Checkpoint Reader
...
...
@@ -302,15 +292,10 @@ class Solver():
Args:
checkpoint_file_path (str): path to checkpoint file
Returns:
None
Raises:
None
"""
self
.
LogWriter
.
log
(
"Loading Checkpoint {}"
.
format
(
checkpoint_file_path
))
self
.
LogWriter
.
log
(
"Loading Checkpoint {}"
.
format
(
checkpoint_file_path
))
checkpoint
=
torch
.
load
(
checkpoint_file_path
)
self
.
start_epoch
=
checkpoint
[
'epoch'
]
...
...
@@ -325,4 +310,5 @@ class Solver():
if
torch
.
is_tensor
(
value
):
state
[
key
]
=
value
.
to
(
self
.
device
)
self
.
LogWriter
.
log
(
"Checkpoint Loaded {} - epoch {}"
.
format
(
checkpoint_file_path
,
checkpoint
[
'epoch'
]))
\ No newline at end of file
self
.
LogWriter
.
log
(
"Checkpoint Loaded {} - epoch {}"
.
format
(
checkpoint_file_path
,
checkpoint
[
'epoch'
]))
utils/data_evaluation_utils.py
View file @
7eacf841
This diff is collapsed.
Click to expand it.
utils/data_logging_utils.py
View file @
7eacf841
"""Data Logging Functions
Description:
-------------
This folder contains several functions which, either on their own or included in larger pieces of software, perform data logging tasks.
Usage
-------------
To use content from this folder, import the functions and instantiate them as you wish to use them:
This folder contains several functions which, either on their own or included in larger pieces of software, perform data logging tasks.
Usage:
To use content from this folder, import the functions and instantiate them as you wish to use them:
from utils.data_logging_utils import function_name
...
...
@@ -30,6 +30,7 @@ import utils.data_evaluation_utils as evaluation
plt
.
axis
(
'scaled'
)
class
LogWriter
():
"""Log Writer class for the BrainMapper U-net.
...
...
@@ -44,19 +45,15 @@ class LogWriter():
use_last_checkpoint (bool): Flag for loading the previous checkpoint
labels (arr): Vector/Array of labels (if applicable)
confusion_matrix_cmap (class): Colour Map to be used for the Conusion Matrix
Returns:
None
Raises:
None
"""
def
__init__
(
self
,
number_of_classes
,
logs_directory
,
experiment_name
,
use_last_checkpoint
=
False
,
labels
=
None
,
confusion_matrix_cmap
=
plt
.
cm
.
Blues
):
def
__init__
(
self
,
number_of_classes
,
logs_directory
,
experiment_name
,
use_last_checkpoint
=
False
,
labels
=
None
,
confusion_matrix_cmap
=
plt
.
cm
.
Blues
):
self
.
number_of_classes
=
number_of_classes
training_logs_directory
=
os
.
path
.
join
(
logs_directory
,
experiment_name
,
"train"
)
testing_logs_directory
=
os
.
path
.
join
(
logs_directory
,
experiment_name
,
"test"
)
training_logs_directory
=
os
.
path
.
join
(
logs_directory
,
experiment_name
,
"train"
)
testing_logs_directory
=
os
.
path
.
join
(
logs_directory
,
experiment_name
,
"test"
)
# If the logs directory exist, we clear their contents to allow new logs to be created
if
not
use_last_checkpoint
:
...
...
@@ -66,8 +63,8 @@ class LogWriter():
shutil
.
rmtree
(
testing_logs_directory
)
self
.
log_writer
=
{
'train'
:
SummaryWriter
(
logdir
=
training_logs_directory
),
'test:'
:
SummaryWriter
(
logdir
=
testing_logs_directory
)
'train'
:
SummaryWriter
(
logdir
=
training_logs_directory
),
'test:'
:
SummaryWriter
(
logdir
=
testing_logs_directory
)
}
self
.
confusion_matrix_color_map
=
confusion_matrix_cmap
...
...
@@ -77,7 +74,8 @@ class LogWriter():
self
.
labels
=
self
.
labels_generator
(
labels
)
self
.
logger
=
logging
.
getLogger
()
file_handler
=
logging
.
FileHandler
(
"{}/{}.log"
.
format
(
os
.
path
.
join
(
logs_directory
,
experiment_name
),
"console_logs"
))
file_handler
=
logging
.
FileHandler
(
"{}/{}.log"
.
format
(
os
.
path
.
join
(
logs_directory
,
experiment_name
),
"console_logs"
))
self
.
logger
.
addHandler
(
file_handler
)
def
labels_generator
(
self
,
labels
):
...
...
@@ -90,17 +88,16 @@ class LogWriter():
Returns:
label_classes (list): List of processed labels
Raises:
None
"""
label_classes
=
[]
for
label
in
labels
:
label_class
=
re
.
sub
(
r
'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))'
,
r
'\1 '
,
label
)
label_class
=
[
'
\n
'
.
join
(
wrap
(
element
,
40
))
for
element
in
label_class
]
label_class
=
re
.
sub
(
r
'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))'
,
r
'\1 '
,
label
)
label_class
=
[
'
\n
'
.
join
(
wrap
(
element
,
40
))
for
element
in
label_class
]
label_classes
.
append
(
label_class
)
return
label_classes
...
...
@@ -112,15 +109,9 @@ class LogWriter():
Args:
message (str): Message to be logged
Returns:
None
Raises:
None
"""
self
.
logger
.
info
(
msg
=
message
)
self
.
logger
.
info
(
msg
=
message
)
def
loss_per_iteration
(
self
,
loss_per_iteration
,
batch_index
,
iteration
):
"""Log of loss / iteration
...
...
@@ -131,16 +122,12 @@ class LogWriter():
loss_per_iteration (torch.tensor): Value of loss for every iteration step
batch_index (int): Index of current batch
iteration (int): Current iteration value
Returns:
None
Raises:
None
"""
print
(
"Loss for Iteration {} is: {}"
.
format
(
batch_index
,
loss_per_iteration
))
self
.
log_writer
[
'train'
].
add_scalar
(
'loss / iteration'
,
loss_per_iteration
,
iteration
)
print
(
"Loss for Iteration {} is: {}"
.
format
(
batch_index
,
loss_per_iteration
))
self
.
log_writer
[
'train'
].
add_scalar
(
'loss / iteration'
,
loss_per_iteration
,
iteration
)
def
loss_per_epoch
(
self
,
losses
,
phase
,
epoch
):
"""Log function
...
...
@@ -151,12 +138,6 @@ class LogWriter():
losses (list): Values of all the losses recorded during the training epoch
phase (str): Current run mode or phase
epoch (int): Current epoch value
Returns:
None
Raises:
None
"""
if
phase
==
'train'
:
...
...
@@ -180,18 +161,14 @@ class LogWriter():
outputs (torch.tensor): Tensor of all the network outputs (Y-hat)
correct_labels (torch.tensor): Output ground-truth labelled data (Y)
epoch (int): Current epoch value
Returns:
mean_dice_score (torch.tensor): Mean dice score value
Raises
None
"""
print
(
"Dice Score is being calculated..."
,
end
=
''
,
flush
=
True
)
dice_score
=
evaluation
.
dice_score_calculator
(
outputs
,
correct_labels
,
self
.
number_of_classes
)
print
(
"Dice Score is being calculated..."
,
end
=
''
,
flush
=
True
)
dice_score
=
evaluation
.
dice_score_calculator
(
outputs
,
correct_labels
,
self
.
number_of_classes
)
mean_dice_score
=
torch
.
mean
(
dice_score
)
self
.
plot_dice_score
(
dice_score
,
phase
,
plot_name
=
'dice_score_per_epoch'
,
title
=
'Dice Score'
,
epochs
=
epoch
)
self
.
plot_dice_score
(
dice_score
,
phase
,
plot_name
=
'dice_score_per_epoch'
,
title
=
'Dice Score'
,
epochs
=
epoch
)
print
(
"Dice Score calculated successfully"
)
return
mean_dice_score
.
item
()
...
...
@@ -206,12 +183,6 @@ class LogWriter():
plot_name (str): Caption name for later refference
title (str): Plot title
epoch (int): Current epoch value
Returns:
None
Raises
None
"""
figure
=
matplotlib
.
figure
.
Figure
()
# Might add some arguments here later
...
...
@@ -224,7 +195,8 @@ class LogWriter():
ax
.
xaxis
.
tick_bottom
()
if
epochs
:
self
.
log_writer
[
phase
].
add_figure
(
plot_name
+
'/'
+
phase
,
figure
,
global_step
=
epochs
)
self
.
log_writer
[
phase
].
add_figure
(
plot_name
+
'/'
+
phase
,
figure
,
global_step
=
epochs
)
else
:
self
.
log_writer
[
phase
].
add_figure
(
plot_name
+
'/'
+
phase
,
figure
)
...
...
@@ -240,16 +212,10 @@ class LogWriter():
ground_truth (torch.tensor): Labelled ground truth image
phase (str): Current run mode or phase
epoch (int): Current epoch value
Returns:
None
Raises
None
"""
print
(
"Sample Image is being loaded..."
,
end
=
''
,
flush
=
True
)
figure
,
ax
=
plt
.
subplots
(
nrows
=
len
(
prediction
),
ncols
=
2
)
print
(
"Sample Image is being loaded..."
,
end
=
''
,
flush
=
True
)
figure
,
ax
=
plt
.
subplots
(
nrows
=
len
(
prediction
),
ncols
=
2
)
for
i
in
range
(
len
(
prediction
)):
ax
[
i
][
0
].
imshow
(
prediction
[
i
])
...
...
@@ -261,7 +227,8 @@ class LogWriter():
ax
[
i
][
1
].
axis
(
'off'
)
figure
.
set_tight_layout
()
self
.
log_writer
[
phase
].
add_figure
(
'sample_prediction/'
+
phase
,
figure
,
epoch
)
self
.
log_writer
[
phase
].
add_figure
(
'sample_prediction/'
+
phase
,
figure
,
epoch
)
print
(
"Sample Image successfully loaded!"
)
...
...
@@ -269,15 +236,6 @@ class LogWriter():