Commit df4eeb88 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

updated imports, moved use_last_checkpoint call

parent ee991fd1
......@@ -19,7 +19,7 @@ import glob
from import Image
from datetime import datetime
from utils.losses import MSELoss
from utils.data_utils import create_folder
from utils.common_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
from torch.optim import lr_scheduler
......@@ -125,15 +125,15 @@ class Solver():
self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
self.early_stop = False
if use_last_checkpoint:
self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
self.save_model_directory = save_model_directory
self.final_model_output_file = final_model_output_file
if use_last_checkpoint:
def train(self, train_loader, validation_loader):
"""Training Function
......@@ -373,12 +373,12 @@ class Solver():
# We are not loading the model_name as we might want to pre-train a model and then use it.
self.model.load_state_dict = checkpoint['state_dict']
self.optimizer.load_state_dict = checkpoint['optimizer']
self.learning_rate_scheduler.load_state_dict = checkpoint['scheduler']
for state in self.optimizer.state.values():
for key, value in state.items():
if torch.is_tensor(value):
state[key] =
self.learning_rate_scheduler.load_state_dict = checkpoint['scheduler']
"Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment