Commit 31c05c0c authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

fixed bug with loading mask onto gpu

parent d189eaa8
......@@ -109,8 +109,6 @@ class Solver():
self.start_epoch = 1
self.start_iteration = 1
# self.best_mean_score = 0
# self.best_mean_score_epoch = 0
self.LogWriter = LogWriter(number_of_classes=number_of_classes,
logs_directory=logs_directory,
......@@ -124,8 +122,7 @@ class Solver():
if use_last_checkpoint:
self.load_checkpoint()
self.MNI_152_2mm_mask = torch.from_numpy(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
self.MNI152_T1_2mm_brain_mask = torch.from_numpy(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
def train(self, train_loader, validation_loader):
"""Training Function
......@@ -187,15 +184,16 @@ class Solver():
X = torch.unsqueeze(X, dim=1)
y = torch.unsqueeze(y, dim=1)
MNI152_T1_2mm_brain_mask = self.MNI152_T1_2mm_brain_mask
if model.test_if_cuda:
X = X.cuda(self.device, non_blocking=True)
y = y.cuda(self.device, non_blocking=True)
MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(self.device, non_blocking=True)
y_hat = model(X) # Forward pass
y_hat = model(X) # Forward pass & Masking
### Masking goes here
y_hat = torch.mul(y_hat, self.MNI_152_2mm_mask)
###
y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)
loss = self.loss_function(y_hat, y) # Loss computation
......@@ -215,7 +213,7 @@ class Solver():
# Clear the memory
del X, y, y_hat, loss
del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask
torch.cuda.empty_cache()
if phase == 'validation':
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment