Commit 12db841c authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

code cleaning & fixing semantic mistakes

parent 9f21c7b3
......@@ -169,7 +169,7 @@ class Solver():
for batch_index, sampled_batch in enumerate(dataloaders[phase]):
X = sampled_batch[0].type(torch.FloatTensor)
y = sampled_batch[1].type(torch.LondTensor)
y = sampled_batch[1].type(torch.LongTensor)
if model.is_cuda():
X = X.cuda(self.device, non_blocking= True)
......@@ -186,7 +186,7 @@ class Solver():
if batch_index % self.loss_log_period == 0:
self.LogWriter.loss_per_iteration(self, loss.item(), batch_index, iteration)
self.LogWriter.loss_per_iteration(loss.item(), batch_index, iteration)
iteration += 1
......@@ -219,7 +219,7 @@ class Solver():
index = np.random.choice(len(dataloaders[phase].dataset.X), size=3, replace= False)
self.LogWriter.sample_image_per_epoch(prediction= model.predict(dataloaders[phase].dataset.X[index], self.device),
ground_truth= dataloaders[phase].dataset.y[index],
phase= phase
phase= phase,
epoch= epoch)
print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))
......@@ -318,10 +318,10 @@ class Solver():
# We are not loading the model_name as we might want to pre-train a model and then use it.
self.model.load_state_dict = checkpoint['state_dict']
self.optimizer.load_state_dict = checkpoint['optimizer']
self.scheduler.load_state_dict = checkpoint['scheduler']
self.learning_rate_scheduler.load_state_dict = checkpoint['scheduler']
for state in self.optimizer.state.values():
for key, value in state.items{}:
for key, value in state.items():
if torch.is_tensor(value):
state[key] = value.to(self.device)
......
......@@ -20,6 +20,7 @@ import logging
import numpy as np
import re
from textwrap import wrap
import torch
# The SummaryWriter class provides a high-level API to create an event file in a given directory and add summaries and events to it.
# More here: https://tensorboardx.readthedocs.io/en/latest/tensorboard.html
......@@ -139,7 +140,7 @@ class LogWriter():
"""
print("Loss for Iteration {} is: {}".format(batch_index, loss_per_iteration))
self.log_writer['train'].add_scalar(tag= 'loss / iteration', loss_per_iteration, iteration)
self.log_writer['train'].add_scalar('loss / iteration', loss_per_iteration, iteration)
def loss_per_epoch(self, losses, phase, epoch):
"""Log function
......@@ -164,7 +165,7 @@ class LogWriter():
loss = np.mean(losses)
print("Loss for Epoch {} of {} is: {}".format(epoch, phase, loss))
self.log_writer[phase].add_scalar(tag= 'loss / iteration', loss, epoch)
self.log_writer[phase].add_scalar('loss / iteration', loss, epoch)
# Currently, no confusion matrix is required
# TODO: add a confusion matrix per epoch and confusion matrix plot functions if required
......@@ -190,12 +191,11 @@ class LogWriter():
print("Dice Score is being calculated...", end='', flush= True)
dice_score = evaluation.dice_score_calculator(outputs, correct_labels, self.number_of_classes)
mean_dice_score = torch.mean(dice_score)
self.plot_dice_score(dice_score, phase, plot_name='dice_score_per_epoch', title='Dice Score', epoch)
self.log_writer[phase].add_scalar(tag= 'loss / iteration', loss, epoch)
self.plot_dice_score(dice_score, phase, plot_name='dice_score_per_epoch', title='Dice Score', epochs=epoch)
print("Dice Score calculated successfully")
return mean_dice_score
return mean_dice_score.item()
def plot_dice_score(self, dice_score, phase, plot_name, title, epochs):
def plot_dice_score(self, dice_score, phase, plot_name, title, epochs=None):
"""Function plotting dice score for multiple epochs
This function plots the dice score for each epoch.
......@@ -223,7 +223,7 @@ class LogWriter():
ax.set_xticklabels(self.labels)
ax.xaxis.tick_bottom()
if step:
if epochs:
self.log_writer[phase].add_figure(plot_name + '/' + phase, figure, global_step= epochs)
else:
self.log_writer[phase].add_figure(plot_name + '/' + phase, figure)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment