Commit f7e0e7b2 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added multiple losses,

parent cc8a04ca
......@@ -62,7 +62,9 @@ class Solver():
experiment_name,
optimizer,
optimizer_arguments={},
loss_function=MSELoss(),
# loss_function=MSELoss(),
# loss_function=torch.nn.L1Loss(),
loss_function=torch.nn.CosineEmbeddingLoss(),
model_name='BrainMapper',
labels=None,
number_epochs=10,
......@@ -83,8 +85,12 @@ class Solver():
if torch.cuda.is_available():
self.loss_function = loss_function.cuda(device)
self.MSE = MSELoss().cuda(device)
else:
self.loss_function = loss_function
self.MSE = MSELoss()
# self.loss_function = loss_function
self.model_name = model_name
self.labels = labels
......@@ -152,6 +158,7 @@ class Solver():
previous_checkpoint = None
previous_loss = None
previous_MSE = None
print('****************************************************************')
print('TRAINING IS STARTING!')
......@@ -175,6 +182,7 @@ class Solver():
print('-> Phase: {}'.format(phase))
losses = []
MSEs = []
if phase == 'train':
model.train()
......@@ -188,8 +196,9 @@ class Solver():
# We add an extra dimension (~ number of channels) for the 3D convolutions.
X = torch.unsqueeze(X, dim=1)
y = torch.unsqueeze(y, dim=1)
MNI152_T1_2mm_brain_mask = torch.unsqueeze(torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)
MNI152_T1_2mm_brain_mask = torch.unsqueeze(
torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)
if model.test_if_cuda:
X = X.cuda(self.device, non_blocking=True)
......@@ -201,7 +210,13 @@ class Solver():
y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)
loss = self.loss_function(y_hat, y) # Loss computation
# loss = self.loss_function(y_hat, y) # Loss computation
loss = self.loss_function(
y_hat+1e-4, y+1e-4, torch.tensor(1.0).cuda(self.device, non_blocking=True))
# We also calculate a separate MSE for cost function comparison!
MSE = self.MSE(y_hat, y)
MSEs.append(MSE.item())
if phase == 'train':
optimizer.zero_grad() # Zero the parameter gradients
......@@ -219,7 +234,7 @@ class Solver():
# Clear the memory
del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask
del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask, MSE
torch.cuda.empty_cache()
if phase == 'validation':
......@@ -233,9 +248,14 @@ class Solver():
if phase == 'train':
self.LogWriter.loss_per_epoch(losses, phase, epoch)
self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
elif phase == 'validation':
self.LogWriter.loss_per_epoch(losses, phase, epoch, previous_loss=previous_loss)
self.LogWriter.loss_per_epoch(
losses, phase, epoch, previous_loss=previous_loss)
previous_loss = np.mean(losses)
self.LogWriter.MSE_per_epoch(
MSEs, phase, epoch, previous_loss=previous_MSE)
previous_MSE = np.mean(MSEs)
if phase == 'validation':
early_stop, save_checkpoint = self.EarlyStopping(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment