Commit f63736f4 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

finished building constructor

parent 0c62c469
......@@ -15,6 +15,9 @@ To use this module, import it and instantiate is as you wish:
import os
import numpy as np
import torch
from utils.losses import MSELoss
from utils.data_utils import create_folder
from torch.optim import lr_scheduler
checkpoint_directory = 'checkpoints'
checkpoint_extension = 'path.tar'
......@@ -57,7 +60,7 @@ class Solver():
experiment_name,
optimizer = torch.optim.Adam,
optimizer_arguments = {},
loss_function = loss_function, # Need to define
loss_function = MSELoss(),
model_name = 'BrainMapper',
labels = None,
number_epochs = 10,
......@@ -71,23 +74,53 @@ class Solver():
self.model = model
self.device = device
self.optimizer = optimizer(model.parameters(), **optimizer_arguments)
if torch.cuda.is_available():
self.loss_function = loss_function.cuda(device)
else:
self.loss_function = loss_function
self.model_name = model_name
self.labels = labels
self.number_epochs = number_epochs
self.loss_log_period = loss_log_period
# We use a learning rate scheduler, that decays the LR of each paramter group by gamma every step_size epoch.
self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
step_size = learning_rate_scheduler_step_size,
gamma= learning_rate_scheduler_gamma)
self.use_last_checkpoint = use_last_checkpoint
experiment_directory_path = os.join.path(experiment_directory, experiment_name)
self.experiment_directory_path = experiment_directory_path
create_folder(experiment_directory_path)
create_folder(os.join.path(experiment_directory_path, checkpoint_directory))
self.start_epoch = 1
self.start_iteration = 1
self.best_mean_score = 0
self.best_mean_epoch = 0
if use_last_checkpoint:
self.load_checkpoint()
pass
def train():
def train(self):
pass
def save_model():
def save_model(self):
pass
def save_checkpoint():
def save_checkpoint(self):
pass
def load_checkpoint():
def load_checkpoint(self):
pass
def _load_checkpoint_file():
def _load_checkpoint_file(self):
# Name is private = can't be called outisde of this module
pass
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment