Commit dc68be2b authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added function for sampling y and y_hat images

parent 6ef6d440
...@@ -18,6 +18,7 @@ import matplotlib.pyplot as plt ...@@ -18,6 +18,7 @@ import matplotlib.pyplot as plt
import shutil import shutil
import logging import logging
import numpy as np import numpy as np
import re
# The SummaryWriter class provides a high-level API to create an event file in a given directory and add summaries and events to it. # The SummaryWriter class provides a high-level API to create an event file in a given directory and add summaries and events to it.
# More here: # More here:
...@@ -77,6 +78,23 @@ class LogWriter(): ...@@ -77,6 +78,23 @@ class LogWriter():
file_handler = logging.FileHandler("{}/{}.log".format(os.path.join(logs_directory, experiment_name), "console_logs")) file_handler = logging.FileHandler("{}/{}.log".format(os.path.join(logs_directory, experiment_name), "console_logs"))
self.logger.addHandler(file_handler) self.logger.addHandler(file_handler)
def labels_generator(self, labels):
""" Label Generator Function
This function processess an input array of labels.
labels (arr): Vector/Array of labels (if applicable)
label_classes (list): List of processed labels
return pass
def log(self, message): def log(self, message):
"""Log function """Log function
...@@ -201,17 +219,45 @@ class LogWriter(): ...@@ -201,17 +219,45 @@ class LogWriter():
else: else:
self.log_writer[phase].add_figure(plot_name + '/' + phase, figure) self.log_writer[phase].add_figure(plot_name + '/' + phase, figure)
def plot_evaluation_box(self): # Currently, also no need for an evaluation box plot
def sample_image_per_epoch(self): def sample_image_per_epoch(self, prediction, ground_truth, phase, epoch):
pass """Function plotting mirrored images
This function plots a predicted and a grond truth images side-by-side.
prediction (torch.tensor): Predicted image after passing throught the network
ground_truth (torch.tensor): Labelled ground truth image
phase (str): Current run mode or phase
epoch (int): Current epoch value
print("Sample Image is being loaded...", end='', flush= True)
figure, ax = plt.subplots(nrows = len(prediction), ncols = 2)
for i in range(len(prediction)):
ax[i][0].set_title("Predicted Image")
ax[i][1].set_title('Ground Truth Image')
self.log_writer[phase].add_figure('sample_prediction/'+phase, figure, epoch)
print("Sample Image successfully loaded!")
def graph(self): def graph(self):
pass pass
def close(self): def close(self):
pass pass
def labels_generator(self):
return pass
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment