Commit df4eeb88 authored by Andrei Roibu
Browse files

updated imports, moved use_last_checkpoint call

parent ee991fd1
@@ -19,7 +19,7 @@ import glob
 from fsl.data.image import Image
 from datetime import datetime
 from utils.losses import MSELoss
-from utils.data_utils import create_folder
+from utils.common_utils import create_folder
 from utils.data_logging_utils import LogWriter
 from utils.early_stopping import EarlyStopping
 from torch.optim import lr_scheduler
@@ -125,15 +125,15 @@ class Solver():
         self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
         self.early_stop = False
-        if use_last_checkpoint:
-            self.load_checkpoint()
         self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
             Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
         self.save_model_directory = save_model_directory
         self.final_model_output_file = final_model_output_file
+        if use_last_checkpoint:
+            self.load_checkpoint()
     def train(self, train_loader, validation_loader):
         """Training Function
@@ -373,12 +373,12 @@ class Solver():
         # We are not loading the model_name as we might want to pre-train a model and then use it.
         self.model.load_state_dict = checkpoint['state_dict']
         self.optimizer.load_state_dict = checkpoint['optimizer']
-        self.learning_rate_scheduler.load_state_dict = checkpoint['scheduler']
         for state in self.optimizer.state.values():
             for key, value in state.items():
                 if torch.is_tensor(value):
                     state[key] = value.to(self.device)
+        self.learning_rate_scheduler.load_state_dict = checkpoint['scheduler']
         self.LogWriter.log(
             "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment