Commit af29f33a authored by Andrei-Claudiu Roibu 🖥

wrote the training function - need to add a logger!

parent 16642040
@@ -15,8 +15,10 @@ To use this module, import it and instantiate it as you wish:
import os
import numpy as np
import torch
from datetime import datetime
from utils.losses import MSELoss
from utils.data_utils import create_folder
# TODO: from utils.data_logging_utils import ... - the logging module still needs to be written (a hypothetical sketch follows below)
from torch.optim import lr_scheduler
checkpoint_directory = 'checkpoints'
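# NOTE: hypothetical sketch of what utils/data_logging_utils.py could
# provide, assuming a thin wrapper around TensorBoard's SummaryWriter;
# the class and method names are placeholders, not the final API.
#
# from torch.utils.tensorboard import SummaryWriter
#
# class LogWriter:
#     def __init__(self, logs_directory, experiment_name):
#         self.log_writer = SummaryWriter(os.path.join(logs_directory, experiment_name))
#
#     def log_loss(self, phase, loss_value, iteration):
#         self.log_writer.add_scalar(phase + '/loss', loss_value, iteration)
#
#     def close(self):
#         self.log_writer.close()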
@@ -45,7 +47,6 @@ class Solver():
            experiment_directory (str): Experiment output directory name
            logs_directory (str): Directory for outputting training logs

        Returns:
            trained model (TODO: finalize what the Solver should return)
@@ -108,8 +109,114 @@ class Solver():
        self.load_checkpoint()
    def train(self, train_loader, test_loader):
        """Training function.

        This function trains the given model using the provided training data and evaluates it on the test data after every epoch.

        Args:
            train_loader (torch.utils.data.DataLoader): combined dataset and sampler, providing an iterable over the training dataset
            test_loader (torch.utils.data.DataLoader): combined dataset and sampler, providing an iterable over the testing dataset

        Returns:
            None: the model is trained in place and persisted via checkpoints

        Raises:
            None
        """
        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'test': test_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        print('Device Type: {}'.format(torch.cuda.get_device_name(self.device)))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')
        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs + 1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'test']:
                print('-> Phase: {}'.format(phase))

                losses = []
                outputs = []
                y_values = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()
                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.LongTensor)

                    # nn.Module has no is_cuda(); check a parameter's device instead
                    if next(model.parameters()).is_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)

                    y_hat = model(X)  # Forward pass
                    loss = self.loss_function(y_hat, y)  # Loss computation

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            # TODO: need a function that logs outputs for debugging!
                            # It should log the loss, batch id and iteration number
                            pass

                        iteration += 1

                    losses.append(loss.item())
                    outputs.append(torch.max(y_hat, dim=1)[1].cpu())
                    y_values.append(y.cpu())

                    # Clear the memory
                    del X, y, y_hat, loss
                    torch.cuda.empty_cache()
                    if phase == 'test':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                if phase == 'train':
                    # Step the scheduler after the epoch's optimizer updates,
                    # the recommended order since PyTorch 1.1
                    learning_rate_scheduler.step()

                with torch.no_grad():
                    output_array, y_array = torch.cat(outputs), torch.cat(y_values)
                    # TODO: using log functions, record loss per epoch, maybe generated images per epoch, dice score and any other relevant metrics
print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))
self.save_checkpoint() # TODO - write function and save the checkpoint!
print('----------------------------------------')
print('TRAINING IS COMPLETE!')
print('=====================')
end_time = datetime.now()
print('Completed At: {}'.format(end_time))
print('Training Duration: {}'.format(end_time - start_time))
print('****************************************************************')
# TODO: MAKE SURE any log writer function is closed!
    def save_model(self):
        pass
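    # NOTE: hypothetical sketch of the save_checkpoint() referenced in
    # train() above, assuming torch.save() on a state dictionary; the
    # file name and dictionary keys are placeholders, not the final design.
    #
    # def save_checkpoint(self):
    #     create_folder(checkpoint_directory)
    #     checkpoint_path = os.path.join(checkpoint_directory, 'checkpoint.pth.tar')
    #     torch.save({
    #         'model_state_dict': self.model.state_dict(),
    #         'optimizer_state_dict': self.optimizer.state_dict(),
    #     }, checkpoint_path)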
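# NOTE: hypothetical usage sketch of the finished Solver; the constructor
# arguments are inferred from the attributes used above and are not the
# actual signature.
#
# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True)
# test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=False)
# solver = Solver(model=model, optimizer=optimizer, number_epochs=10)
# solver.train(train_loader, test_loader)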