Commit dc68be2b authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added function for sampling y and y_hat images

parent 6ef6d440
......@@ -18,6 +18,7 @@ import matplotlib.pyplot as plt
import shutil
import logging
import numpy as np
import re
# The SummaryWriter class provides a high-level API to create an event file in a given directory and add summaries and events to it.
# More here: https://tensorboardx.readthedocs.io/en/latest/tensorboard.html
......@@ -77,6 +78,23 @@ class LogWriter():
file_handler = logging.FileHandler("{}/{}.log".format(os.path.join(logs_directory, experiment_name), "console_logs"))
self.logger.addHandler(file_handler)
def labels_generator(self, labels):
""" Label Generator Function
This function processess an input array of labels.
Args:
labels (arr): Vector/Array of labels (if applicable)
Returns:
label_classes (list): List of processed labels
Raises:
None
"""
return pass
def log(self, message):
"""Log function
......@@ -201,17 +219,45 @@ class LogWriter():
else:
self.log_writer[phase].add_figure(plot_name + '/' + phase, figure)
def plot_evaluation_box(self):
pass
# Currently, also no need for an evaluation box plot
def sample_image_per_epoch(self):
pass
def sample_image_per_epoch(self, prediction, ground_truth, phase, epoch):
"""Function plotting mirrored images
This function plots a predicted and a grond truth images side-by-side.
Args:
prediction (torch.tensor): Predicted image after passing throught the network
ground_truth (torch.tensor): Labelled ground truth image
phase (str): Current run mode or phase
epoch (int): Current epoch value
Returns:
None
Raises
None
"""
print("Sample Image is being loaded...", end='', flush= True)
figure, ax = plt.subplots(nrows = len(prediction), ncols = 2)
for i in range(len(prediction)):
ax[i][0].imshow(prediction[i])
ax[i][0].set_title("Predicted Image")
ax[i][0].axis('off')
ax[i][1].imshow(ground_truth[i])
ax[i][1].set_title('Ground Truth Image')
ax[i][1].axis('off')
figure.set_tight_layout()
self.log_writer[phase].add_figure('sample_prediction/'+phase, figure, epoch)
print("Sample Image successfully loaded!")
def graph(self):
pass
def close(self):
pass
def labels_generator(self):
return pass
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment