Commit 4f522872 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added masking between network output and loss calculation

parent d12bcf8e
......@@ -16,6 +16,7 @@ import numpy as np
import torch
import glob
from fsl.data.image import Image
from datetime import datetime
from utils.losses import MSELoss
from utils.data_utils import create_folder
......@@ -123,6 +124,9 @@ class Solver():
if use_last_checkpoint:
self.load_checkpoint()
self.MNI_152_2mm_mask = torch.from_numpy(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
def train(self, train_loader, validation_loader):
"""Training Function
......@@ -189,6 +193,10 @@ class Solver():
y_hat = model(X) # Forward pass
### Masking goes here
y_hat = torch.mul(y_hat, self.MNI_152_2mm_mask)
###
loss = self.loss_function(y_hat, y) # Loss computation
if phase == 'train':
......@@ -237,9 +245,9 @@ class Solver():
filename=os.path.join(self.experiment_directory_path, self.checkpoint_directory,
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
)
if epoch != self.start_epoch:
os.remove(os.path.join(self.experiment_directory_path, self.checkpoint_directory,
'checkpoint_epoch_' + str(epoch-1) + '.' + checkpoint_extension))
# if epoch != self.start_epoch:
# os.remove(os.path.join(self.experiment_directory_path, self.checkpoint_directory,
# 'checkpoint_epoch_' + str(epoch-1) + '.' + checkpoint_extension))
if phase == 'train':
learning_rate_scheduler.step()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment