Commit 054f064b authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

built constructor, defined functions

parent 2b37d7b3
......@@ -13,3 +13,98 @@ To use content from this folder, import the functions and instantiate them as yo
"""
import os
import matplotlib.pyplot as plt
import shutil
import logging
# The SummaryWriter class provides a high-level API to create an event file in a given directory and add summaries and events to it.
# More here: https://tensorboardx.readthedocs.io/en/latest/tensorboard.html
from tensorboardX import SummaryWriter
plt.axis('scaled')
class LogWriter():
"""Log Writer class for the BrainMapper U-net.
This class contains the pytorch implementation of the several logging functions required for the BrainMapper project.
These functions are designed to keep track of progress during training, and also aid debugging.
Args:
number_of_classes (int): Number of classes
logs_directory (str): Directory for outputing training logs
experiment_name (str): Name of the experiment
use_last_checkpoint (bool): Flag for loading the previous checkpoint
labels (arr): Vector/Array of labels (if applicable)
confusion_matrix_cmap (class): Colour Map to be used for the Conusion Matrix
Returns:
None
Raises:
None
"""
def __init__(self, number_of_classes, logs_directory, experiment_name, use_last_checkpoint=False, labels=None, confusion_matrix_cmap= plt.cm.Blues):
self.number_of_classes = number_of_classes
training_logs_directory = os.path.join(logs_directory, experiment_name, "train")
testing_logs_directory = os.path.join(logs_directory, experiment_name, "test")
# If the logs directory exist, we clear their contents to allow new logs to be created
if not use_last_checkpoint:
if os.path.exists(training_logs_directory):
shutil.rmtree(training_logs_directory)
if os.path.exists(testing_logs_directory):
shutil.rmtree(testing_logs_directory)
self.log_writer = {
'train': SummaryWriter(logdir= training_logs_directory),
'test:': SummaryWriter(logdir= testing_logs_directory)
}
self.confusion_matrix_color_map = confusion_matrix_cmap
self.current_iteration = 1
self.labels = self.labels_generator(labels)
self.logger = logging.getLogger()
file_handler = logging.FileHandler("{}/{}.log".format(os.path.join(logs_directory, experiment_name), "console_logs"))
self.logger.addHandler(file_handler)
def log(self):
pass
def loss_per_iteration(self):
pass
def loss_per_epoch(self):
pass
def confusion_matrix_per_epoch(self):
pass
def plot_confusion_matrix(self):
pass
def dice_score_per_epoch(self):
pass
def plot_dice_score(self):
pass
def plot_evaluation_box(self):
pass
def sample_image_per_epoch(self):
pass
def graph(self):
pass
def close(self):
pass
def labels_generator(self):
pass
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment