Commit 4f522872 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added masking between network output and loss calculation

parent d12bcf8e
...@@ -16,6 +16,7 @@ import numpy as np ...@@ -16,6 +16,7 @@ import numpy as np
import torch import torch
import glob import glob
from fsl.data.image import Image
from datetime import datetime from datetime import datetime
from utils.losses import MSELoss from utils.losses import MSELoss
from utils.data_utils import create_folder from utils.data_utils import create_folder
...@@ -123,6 +124,9 @@ class Solver(): ...@@ -123,6 +124,9 @@ class Solver():
if use_last_checkpoint: if use_last_checkpoint:
self.load_checkpoint() self.load_checkpoint()
self.MNI_152_2mm_mask = torch.from_numpy(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
def train(self, train_loader, validation_loader): def train(self, train_loader, validation_loader):
"""Training Function """Training Function
...@@ -189,6 +193,10 @@ class Solver(): ...@@ -189,6 +193,10 @@ class Solver():
y_hat = model(X) # Forward pass y_hat = model(X) # Forward pass
### Masking goes here
y_hat = torch.mul(y_hat, self.MNI_152_2mm_mask)
###
loss = self.loss_function(y_hat, y) # Loss computation loss = self.loss_function(y_hat, y) # Loss computation
if phase == 'train': if phase == 'train':
...@@ -237,9 +245,9 @@ class Solver(): ...@@ -237,9 +245,9 @@ class Solver():
filename=os.path.join(self.experiment_directory_path, self.checkpoint_directory, filename=os.path.join(self.experiment_directory_path, self.checkpoint_directory,
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension) 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
) )
if epoch != self.start_epoch: # if epoch != self.start_epoch:
os.remove(os.path.join(self.experiment_directory_path, self.checkpoint_directory, # os.remove(os.path.join(self.experiment_directory_path, self.checkpoint_directory,
'checkpoint_epoch_' + str(epoch-1) + '.' + checkpoint_extension)) # 'checkpoint_epoch_' + str(epoch-1) + '.' + checkpoint_extension))
if phase == 'train': if phase == 'train':
learning_rate_scheduler.step() learning_rate_scheduler.step()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment