Commit 7c0603da authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

fixed data logging bugs

parent 12ce71d4
...@@ -277,13 +277,12 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva ...@@ -277,13 +277,12 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva
_ = evaluations.evaluate_dice_score(trained_model_path=evaluation_parameters['trained_model_path'], _ = evaluations.evaluate_dice_score(trained_model_path=evaluation_parameters['trained_model_path'],
number_of_classes=network_parameters['number_of_classes'], number_of_classes=network_parameters['number_of_classes'],
data_directory=evaluation_parameters['data_directory'], data_directory=evaluation_parameters['data_directory'],
targets_directory=evaluation_parameters[ targets_directory=evaluation_parameters['targets_directory'],
'targets_directory'], data_list=evaluation_parameters['data_list'],
data_list=evaluation_parameters['data_list'], orientation=evaluation_parameters['orientation'],
orientation=evaluation_parameters['orientation'], prediction_output_path=prediction_output_path,
prediction_output_path=prediction_output_path, device=misc_parameters['device'],
device=misc_parameters['device'], LogWriter=logWriter
LogWriter=logWriter
) )
logWriter.close() logWriter.close()
......
...@@ -122,7 +122,8 @@ class Solver(): ...@@ -122,7 +122,8 @@ class Solver():
if use_last_checkpoint: if use_last_checkpoint:
self.load_checkpoint() self.load_checkpoint()
self.MNI152_T1_2mm_brain_mask = torch.from_numpy(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data) self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
def train(self, train_loader, validation_loader): def train(self, train_loader, validation_loader):
"""Training Function """Training Function
...@@ -174,22 +175,19 @@ class Solver(): ...@@ -174,22 +175,19 @@ class Solver():
for batch_index, sampled_batch in enumerate(dataloaders[phase]): for batch_index, sampled_batch in enumerate(dataloaders[phase]):
X = sampled_batch[0].type(torch.FloatTensor) X = sampled_batch[0].type(torch.FloatTensor)
# X = ( X - X.min() ) / ( X.max() - X.min() )
# X = ( X - X.mean() ) / X.std()
y = sampled_batch[1].type(torch.FloatTensor) y = sampled_batch[1].type(torch.FloatTensor)
# We add an extra dimension (~ number of channels) for the 3D convolutions. # We add an extra dimension (~ number of channels) for the 3D convolutions.
X = torch.unsqueeze(X, dim=1) X = torch.unsqueeze(X, dim=1)
y = torch.unsqueeze(y, dim=1) y = torch.unsqueeze(y, dim=1)
MNI152_T1_2mm_brain_mask = self.MNI152_T1_2mm_brain_mask MNI152_T1_2mm_brain_mask = self.MNI152_T1_2mm_brain_mask
if model.test_if_cuda: if model.test_if_cuda:
X = X.cuda(self.device, non_blocking=True) X = X.cuda(self.device, non_blocking=True)
y = y.cuda(self.device, non_blocking=True) y = y.cuda(self.device, non_blocking=True)
MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(self.device, non_blocking=True) MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
self.device, non_blocking=True)
y_hat = model(X) # Forward pass & Masking y_hat = model(X) # Forward pass & Masking
......
...@@ -82,30 +82,6 @@ class LogWriter(): ...@@ -82,30 +82,6 @@ class LogWriter():
"{}/{}.log".format(os.path.join(logs_directory, experiment_name), "console_logs")) "{}/{}.log".format(os.path.join(logs_directory, experiment_name), "console_logs"))
self.logger.addHandler(file_handler) self.logger.addHandler(file_handler)
def labels_generator(self, labels):
""" Label Generator Function
This function processess an input array of labels.
Args:
labels (arr): Vector/Array of labels (if applicable)
Returns:
label_classes (list): List of processed labels
"""
label_classes = []
for label in labels:
label_class = re.sub(
r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', label)
label_class = ['\n'.join(wrap(element, 40))
for element in label_class]
label_classes.append(label_class)
return label_classes
def log(self, message): def log(self, message):
"""Log function """Log function
...@@ -131,7 +107,7 @@ class LogWriter(): ...@@ -131,7 +107,7 @@ class LogWriter():
print("Loss for Iteration {} is: {}".format( print("Loss for Iteration {} is: {}".format(
batch_index, loss_per_iteration)) batch_index, loss_per_iteration))
self.log_writer['train'].add_scalar( self.log_writer['train'].add_scalar(
'loss / iteration', loss_per_iteration, iteration) 'loss/iteration', loss_per_iteration, iteration)
def loss_per_epoch(self, losses, phase, epoch): def loss_per_epoch(self, losses, phase, epoch):
"""Log function """Log function
...@@ -150,31 +126,31 @@ class LogWriter(): ...@@ -150,31 +126,31 @@ class LogWriter():
loss = np.mean(losses) loss = np.mean(losses)
print("Loss for Epoch {} of {} is: {}".format(epoch, phase, loss)) print("Loss for Epoch {} of {} is: {}".format(epoch, phase, loss))
self.log_writer[phase].add_scalar('loss / iteration', loss, epoch) self.log_writer[phase].add_scalar('loss/epoch', loss, epoch)
# Currently, no confusion matrix is required def close(self):
# TODO: add a confusion matrix per epoch and confusion matrix plot functions if required """Close the log writer
def dice_score_per_epoch(self, phase, outputs, correct_labels, epoch): This function closes the two log writers.
"""Function calculating dice score for each epoch """
This function computes the dice score for each epoch. self.log_writer['train'].close()
self.log_writer['validation'].close()
def add_graph(self, model):
"""Produces network graph
This function produces the network graph
NOTE: Currently, the function suffers from bugs and is not implemented.
Args: Args:
phase (str): Current run mode or phase model (torch.nn.Module): Model to draw.
outputs (torch.tensor): Tensor of all the network outputs (Y-hat)
correct_labels (torch.tensor): Output ground-truth labelled data (Y)
epoch (int): Current epoch value
""" """
print("Dice Score is being calculated...", end='', flush=True) self.log_writer['train'].add_graph(model)
dice_score = evaluation.dice_score_calculator(
outputs, correct_labels, self.number_of_classes) # DEPRECATED / UNDEBUGGED FUNCTIONS
mean_dice_score = torch.mean(dice_score)
self.plot_dice_score(
dice_score, phase, plot_name='dice_score_per_epoch', title='Dice Score', epochs=epoch)
print("Dice Score calculated successfully")
return mean_dice_score.item()
def plot_dice_score(self, dice_score, phase, plot_name, title, epochs=None): def plot_dice_score(self, dice_score, phase, plot_name, title, epochs=None):
"""Function plotting dice score for multiple epochs """Function plotting dice score for multiple epochs
...@@ -208,7 +184,26 @@ class LogWriter(): ...@@ -208,7 +184,26 @@ class LogWriter():
else: else:
self.log_writer[phase].add_figure(plot_name + '/' + phase, figure) self.log_writer[phase].add_figure(plot_name + '/' + phase, figure)
# Currently, also no need for an evaluation box plot def dice_score_per_epoch(self, phase, outputs, correct_labels, epoch):
"""Function calculating dice score for each epoch
This function computes the dice score for each epoch.
Args:
phase (str): Current run mode or phase
outputs (torch.tensor): Tensor of all the network outputs (Y-hat)
correct_labels (torch.tensor): Output ground-truth labelled data (Y)
epoch (int): Current epoch value
"""
print("Dice Score is being calculated...", end='', flush=True)
dice_score = evaluation.dice_score_calculator(
outputs, correct_labels, self.number_of_classes)
mean_dice_score = torch.mean(dice_score)
self.plot_dice_score(
dice_score, phase, plot_name='dice_score_per_epoch', title='Dice Score', epochs=epoch)
print("Dice Score calculated successfully")
return mean_dice_score.item()
def sample_image_per_epoch(self, prediction, ground_truth, phase, epoch): def sample_image_per_epoch(self, prediction, ground_truth, phase, epoch):
"""Function plotting mirrored images """Function plotting mirrored images
...@@ -240,11 +235,26 @@ class LogWriter(): ...@@ -240,11 +235,26 @@ class LogWriter():
print("Sample Image successfully loaded!") print("Sample Image successfully loaded!")
def close(self): def labels_generator(self, labels):
"""Close the log writer """ Label Generator Function
This function closes the two log writers. This function processess an input array of labels.
Args:
labels (arr): Vector/Array of labels (if applicable)
Returns:
label_classes (list): List of processed labels
""" """
self.log_writer['train'].close() label_classes = []
self.log_writer['validation'].close()
for label in labels:
label_class = re.sub(
r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', label)
label_class = ['\n'.join(wrap(element, 40))
for element in label_class]
label_classes.append(label_class)
return label_classes
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment