Commit 31c05c0c authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

fixed bug with loading mask onto gpu

parent d189eaa8
...@@ -109,8 +109,6 @@ class Solver(): ...@@ -109,8 +109,6 @@ class Solver():
self.start_epoch = 1 self.start_epoch = 1
self.start_iteration = 1 self.start_iteration = 1
# self.best_mean_score = 0
# self.best_mean_score_epoch = 0
self.LogWriter = LogWriter(number_of_classes=number_of_classes, self.LogWriter = LogWriter(number_of_classes=number_of_classes,
logs_directory=logs_directory, logs_directory=logs_directory,
...@@ -124,8 +122,7 @@ class Solver(): ...@@ -124,8 +122,7 @@ class Solver():
if use_last_checkpoint: if use_last_checkpoint:
self.load_checkpoint() self.load_checkpoint()
self.MNI_152_2mm_mask = torch.from_numpy(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data) self.MNI152_T1_2mm_brain_mask = torch.from_numpy(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
def train(self, train_loader, validation_loader): def train(self, train_loader, validation_loader):
"""Training Function """Training Function
...@@ -186,16 +183,17 @@ class Solver(): ...@@ -186,16 +183,17 @@ class Solver():
# We add an extra dimension (~ number of channels) for the 3D convolutions. # We add an extra dimension (~ number of channels) for the 3D convolutions.
X = torch.unsqueeze(X, dim=1) X = torch.unsqueeze(X, dim=1)
y = torch.unsqueeze(y, dim=1) y = torch.unsqueeze(y, dim=1)
MNI152_T1_2mm_brain_mask = self.MNI152_T1_2mm_brain_mask
if model.test_if_cuda: if model.test_if_cuda:
X = X.cuda(self.device, non_blocking=True) X = X.cuda(self.device, non_blocking=True)
y = y.cuda(self.device, non_blocking=True) y = y.cuda(self.device, non_blocking=True)
MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(self.device, non_blocking=True)
y_hat = model(X) # Forward pass y_hat = model(X) # Forward pass & Masking
### Masking goes here y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)
y_hat = torch.mul(y_hat, self.MNI_152_2mm_mask)
###
loss = self.loss_function(y_hat, y) # Loss computation loss = self.loss_function(y_hat, y) # Loss computation
...@@ -215,7 +213,7 @@ class Solver(): ...@@ -215,7 +213,7 @@ class Solver():
# Clear the memory # Clear the memory
del X, y, y_hat, loss del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask
torch.cuda.empty_cache() torch.cuda.empty_cache()
if phase == 'validation': if phase == 'validation':
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment