BrainMapper · Commits

Commit af29f33a
authored Mar 28, 2020 by Andrei-Claudiu Roibu

wrote the training function - need to add logger!

parent 16642040
Changes 1
solver.py
...
...
@@ -15,8 +15,10 @@ To use this module, import it and instantiate it as you wish:

import os
import numpy as np
import torch
from datetime import datetime
from utils.losses import MSELoss
from utils.data_utils import create_folder
# from utils.data_logging_utils import ...  # BLA - need to write something first
from torch.optim import lr_scheduler

checkpoint_directory = 'checkpoints'
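
The data_logging_utils import above is left dangling on purpose; per the commit message, the logger still needs to be written. For orientation only, here is a minimal sketch of what utils/data_logging_utils.py might export. The LogWriter name and its methods are assumptions, not part of this commit; only the torch.utils.tensorboard.SummaryWriter calls are real PyTorch API.

    import os
    from torch.utils.tensorboard import SummaryWriter

    class LogWriter:
        """Hypothetical training logger; name and interface are assumptions."""

        def __init__(self, logs_directory, experiment_name):
            # One TensorBoard run directory per experiment
            log_path = os.path.join(logs_directory, experiment_name)
            os.makedirs(log_path, exist_ok=True)
            self.writer = SummaryWriter(log_path)

        def log_loss(self, tag, loss_value, step):
            # Records a scalar under e.g. 'loss/train' at the given step
            self.writer.add_scalar('loss/{}'.format(tag), loss_value, step)

        def close(self):
            self.writer.close()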
...
...
@@ -45,7 +47,6 @@ class Solver():

        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputting training logs

    Returns:
        trained model(?) - working on this!
...
...
@@ -108,8 +109,114 @@ class Solver():

        self.load_checkpoint()

-    def train(self):
-        pass

    def train(self, train_loader, test_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            test_loader (class): Combined dataset and sampler, providing an iterable over the testing dataset (torch.utils.data.DataLoader)

        Returns:
            None: trained model

        Raises:
            None
        """
        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'test': test_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU
        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            # get_device_name requires CUDA, so only query it when a GPU exists
            print('Device Type: {}'.format(torch.cuda.get_device_name(self.device)))

        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')
        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs + 1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'test']:
                print('-> Phase: {}'.format(phase))

                losses = []
                outputs = []
                y_values = []

                if phase == 'train':
                    model.train()
                    learning_rate_scheduler.step()
                else:
                    model.eval()
                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.LongTensor)

                    # Move the batch to the GPU if the model lives there,
                    # checked via its parameters (nn.Module has no is_cuda)
                    if next(model.parameters()).is_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                    y_hat = model(X)  # Forward pass
                    loss = self.loss_function(y_hat, y)  # Loss computation

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            # TODO: NEED A FUNCTION that logs outputs for debugging!
                            # Here, I need it to log the loss, batch id and iteration number
                            iteration += 1
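                            # A hedged example of the missing log call, using the
                            # hypothetical LogWriter sketched after the imports above:
                            #     self.log_writer.log_loss(phase, loss.item(), iteration)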
                    losses.append(loss.item())
                    outputs.append(torch.max(y_hat, dim=1)[1].cpu())
                    y_values.append(y.cpu())

                    # Clear the memory
                    del X, y, y_hat, loss
                    torch.cuda.empty_cache()
                    if phase == 'test':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)
                with torch.no_grad():
                    output_array, y_array = torch.cat(outputs), torch.cat(y_values)

                    # TODO - using log functions, record loss per epoch, maybe generated
                    # images per epoch, dice score and any other relevant metrics?
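                    # A hedged sketch of that per-epoch record, again via the
                    # hypothetical LogWriter (np.mean uses the numpy import above):
                    #     self.log_writer.log_loss('epoch_' + phase, np.mean(losses), epoch)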
            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            self.save_checkpoint()  # TODO - write function and save the checkpoint!

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')

        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('****************************************************************')

        # TODO: MAKE SURE any log writer function is closed!

    def save_model(self):
        pass
...
...
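
A note on the save_checkpoint() TODO above: the commit leaves checkpointing unimplemented. Below is a minimal sketch of what such a method could look like; it is an assumption, not the author's implementation. It reuses create_folder and checkpoint_directory from this file, takes an explicit epoch argument (so the call in train() would become self.save_checkpoint(epoch)), and follows the common PyTorch convention of saving the epoch plus model and optimizer state dicts.

    def save_checkpoint(self, epoch):
        """Hypothetical checkpoint saver; name matches the TODO, contents assumed."""
        create_folder(checkpoint_directory)  # helper imported from utils.data_utils
        checkpoint_path = os.path.join(checkpoint_directory,
                                       'checkpoint_epoch_{}.pth'.format(epoch))
        # Persist everything needed to resume training from this epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, checkpoint_path)

A matching loader would call torch.load on the same path and restore both state dicts; and, per the final TODO in train(), any log writer (for example the hypothetical LogWriter above) should be closed with log_writer.close() once training finishes.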