Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
f63736f4
Commit
f63736f4
authored
Mar 28, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
finished building constructor
parent
0c62c469
Changes
1
Hide whitespace changes
Inline
Side-by-side
solver.py
View file @
f63736f4
...
...
@@ -15,6 +15,9 @@ To use this module, import it and instantiate it as you wish:
import
os
import
numpy
as
np
import
torch
from
utils.losses
import
MSELoss
from
utils.data_utils
import
create_folder
from
torch.optim
import
lr_scheduler
# Sub-directory (inside the experiment folder) where training checkpoints are written.
checkpoint_directory = 'checkpoints'
# File extension appended to checkpoint file names.
# NOTE(review): 'path.tar' looks like a typo for the conventional 'pth.tar'
# used by torch.save checkpoints — confirm against the save/load code.
checkpoint_extension = 'path.tar'
...
...
@@ -57,7 +60,7 @@ class Solver():
experiment_name
,
optimizer
=
torch
.
optim
.
Adam
,
optimizer_arguments
=
{},
loss_function
=
loss_function
,
# Need to define
loss_function
=
MSELoss
(),
model_name
=
'BrainMapper'
,
labels
=
None
,
number_epochs
=
10
,
...
...
@@ -71,23 +74,53 @@ class Solver():
self
.
model
=
model
self
.
device
=
device
self
.
optimizer
=
optimizer
(
model
.
parameters
(),
**
optimizer_arguments
)
if
torch
.
cuda
.
is_available
():
self
.
loss_function
=
loss_function
.
cuda
(
device
)
else
:
self
.
loss_function
=
loss_function
self
.
model_name
=
model_name
self
.
labels
=
labels
self
.
number_epochs
=
number_epochs
self
.
loss_log_period
=
loss_log_period
# We use a learning rate scheduler, that decays the LR of each paramter group by gamma every step_size epoch.
self
.
learning_rate_scheduler
=
lr_scheduler
.
StepLR
(
self
.
optimizer
,
step_size
=
learning_rate_scheduler_step_size
,
gamma
=
learning_rate_scheduler_gamma
)
self
.
use_last_checkpoint
=
use_last_checkpoint
experiment_directory_path
=
os
.
join
.
path
(
experiment_directory
,
experiment_name
)
self
.
experiment_directory_path
=
experiment_directory_path
create_folder
(
experiment_directory_path
)
create_folder
(
os
.
join
.
path
(
experiment_directory_path
,
checkpoint_directory
))
self
.
start_epoch
=
1
self
.
start_iteration
=
1
self
.
best_mean_score
=
0
self
.
best_mean_epoch
=
0
if
use_last_checkpoint
:
self
.
load_checkpoint
()
pass
    def train(self):
        """Placeholder for the training loop — not yet implemented."""
        pass
    def save_model(self):
        """Placeholder for persisting the trained model — not yet implemented."""
        pass
    def save_checkpoint(self):
        """Placeholder for writing a training checkpoint — not yet implemented."""
        pass
    def load_checkpoint(self):
        """Placeholder for restoring training state from a checkpoint — not yet implemented.

        Called from the constructor when ``use_last_checkpoint`` is truthy.
        """
        pass
    def _load_checkpoint_file(self):
        """Placeholder for reading a checkpoint file from disk — not yet implemented."""
        # Leading underscore marks this as internal to the module by
        # convention (Python does not enforce privacy).
        pass
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment