Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
0a75ffb5
Commit
0a75ffb5
authored
Mar 30, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
added loging functionality to the train function
parent
05d17c56
Changes
1
Hide whitespace changes
Inline
Side-by-side
solver.py
View file @
0a75ffb5
...
@@ -18,7 +18,7 @@ import torch
...
@@ -18,7 +18,7 @@ import torch
from
datetime
import
datetime
from
datetime
import
datetime
from
utils.losses
import
MSELoss
from
utils.losses
import
MSELoss
from
utils.data_utils
import
create_folder
from
utils.data_utils
import
create_folder
from
utils.data_logging_utils
import
#BLA - need to write something first
from
utils.data_logging_utils
import
LogWriter
from
torch.optim
import
lr_scheduler
from
torch.optim
import
lr_scheduler
checkpoint_directory
=
'checkpoints'
checkpoint_directory
=
'checkpoints'
...
@@ -103,11 +103,17 @@ class Solver():
...
@@ -103,11 +103,17 @@ class Solver():
self
.
start_epoch
=
1
self
.
start_epoch
=
1
self
.
start_iteration
=
1
self
.
start_iteration
=
1
self
.
best_mean_score
=
0
self
.
best_mean_score
=
0
self
.
best_mean_epoch
=
0
self
.
best_mean_
score_
epoch
=
0
if
use_last_checkpoint
:
if
use_last_checkpoint
:
self
.
load_checkpoint
()
self
.
load_checkpoint
()
self
.
LogWriter
=
LogWriter
(
number_of_classes
=
number_of_classes
,
logs_directory
=
logs_directory
,
experiment_name
=
experiment_name
,
use_last_checkpoint
=
use_last_checkpoint
,
labels
=
labels
)
def
train
(
self
,
train_loader
,
test_loader
):
def
train
(
self
,
train_loader
,
test_loader
):
"""Training Function
"""Training Function
...
@@ -178,8 +184,7 @@ class Solver():
...
@@ -178,8 +184,7 @@ class Solver():
if
batch_index
%
self
.
loss_log_period
==
0
:
if
batch_index
%
self
.
loss_log_period
==
0
:
# TODO: NEED A FUNCTION that logs outputs for debugging!
self
.
LogWriter
.
loss_per_iteration
(
self
,
loss
.
item
(),
batch_index
,
iteration
)
# Here, I need it to log the loss, batch id and iteration number\
iteration
+=
1
iteration
+=
1
...
@@ -201,12 +206,25 @@ class Solver():
...
@@ -201,12 +206,25 @@ class Solver():
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output_array
,
y_array
=
torch
.
cat
(
outputs
),
torch
.
cat
(
y_values
)
output_array
,
y_array
=
torch
.
cat
(
outputs
),
torch
.
cat
(
y_values
)
# TODO - using log functions, record loss per epoch, maybe generated images per epoch, dice score and any other relevant metrics?
self
.
LogWriter
.
loss_per_epoch
(
losses
,
phase
,
epoch
)
dice_score_mean
=
self
.
LogWriter
.
dice_score_per_epoch
(
phase
,
output_array
,
y_array
,
epoch
)
if
phase
===
'test'
:
if
dice_score_mean
>
self
.
best_mean_score
:
self
.
best_mean_score
=
dice_score_mean
self
.
best_mean_score_epoch
=
epoch
index
=
np
.
random
.
choice
(
len
(
dataloaders
[
phase
].
dataset
.
X
),
size
=
3
,
replace
=
False
)
self
.
LogWriter
.
sample_image_per_epoch
(
prediction
=
model
.
predict
(
dataloaders
[
phase
].
dataset
.
X
[
index
],
self
.
device
)
ground_truth
=
dataloaders
[
phase
].
dataset
.
y
[
index
],
\
phase
=
phase
epoch
=
epoch
)
print
(
"Epoch {}/{} DONE!"
.
format
(
epoch
,
self
.
number_epochs
))
print
(
"Epoch {}/{} DONE!"
.
format
(
epoch
,
self
.
number_epochs
))
self
.
save_checkpoint
()
# TODO - write function and save the checkpoint!
self
.
save_checkpoint
()
# TODO - write function and save the checkpoint!
self
.
LogWriter
.
close
()
print
(
'----------------------------------------'
)
print
(
'----------------------------------------'
)
print
(
'TRAINING IS COMPLETE!'
)
print
(
'TRAINING IS COMPLETE!'
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment