Commit 319f99f6 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

improved early stopping logging

parent 75530d88
......@@ -118,7 +118,7 @@ class Solver():
use_last_checkpoint=use_last_checkpoint,
labels=labels)
self.EarlyStopping = EarlyStopping()
self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
self.early_stop = False
if use_last_checkpoint:
......@@ -151,6 +151,7 @@ class Solver():
model.cuda(self.device) # Moving the model to GPU
previous_checkpoint = None
previous_loss = None
print('****************************************************************')
print('TRAINING IS STARTING!')
......@@ -188,7 +189,10 @@ class Solver():
X = torch.unsqueeze(X, dim=1)
y = torch.unsqueeze(y, dim=1)
MNI152_T1_2mm_brain_mask = self.MNI152_T1_2mm_brain_mask
print('X range:', torch.min(X), torch.max(X))
print('y range:', torch.min(y), torch.max(y))
MNI152_T1_2mm_brain_mask = torch.unsqueeze(torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)
if model.test_if_cuda:
X = X.cuda(self.device, non_blocking=True)
......@@ -198,8 +202,12 @@ class Solver():
y_hat = model(X) # Forward pass & Masking
print('y_hat range:', torch.min(y_hat), torch.max(y_hat))
y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)
print('y_hat masked range:', torch.min(y_hat), torch.max(y_hat))
loss = self.loss_function(y_hat, y) # Loss computation
if phase == 'train':
......@@ -230,7 +238,11 @@ class Solver():
with torch.no_grad():
self.LogWriter.loss_per_epoch(losses, phase, epoch)
if phase == 'train':
self.LogWriter.loss_per_epoch(losses, phase, epoch)
elif phase == 'validation':
self.LogWriter.loss_per_epoch(losses, phase, epoch, previous_loss=previous_loss)
previous_loss = np.mean(losses)
if phase == 'validation':
early_stop, save_checkpoint = self.EarlyStopping(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment