Commit 397a0fc0 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

wrote the forward function for U-net

parent 74c4b384
......@@ -70,10 +70,37 @@ class BrainMapperUNet(nn.Module):
self.classifier = modules.ClassifierBlock(parameters)
def forward(self, X):
"""Forward pass for U-net
Function computing the forward pass through the U-Net
The input to the function is the dMRI map
X (torch.tensor): Input dMRI map, shape = (N x C x H x W)
probability_map (torch.tensor): Output forward passed tensor through the U-net block
return None
Y_encoder_1, Y_np1, pool_indices1 = self.encoderBlock1.forward(X)
Y_encoder_2, Y_np2, pool_indices2 = self.encoderBlock2.forward(Y_encoder_1)
Y_encoder_3, Y_np3, pool_indices3 = self.encoderBlock3.forward(Y_encoder_2)
Y_encoder_4, Y_np4, pool_indices4 = self.encoderBlock4.forward(Y_encoder_3)
Y_bottleNeck = self.bottleneck.forward(Y_encoder_4)
Y_decoder_1 = self.decoderBlock1.forward(Y_bottleNeck, Y_np4, pool_indices4)
Y_decoder_2 = self.decoderBlock2.forward(Y_decoder_1, Y_np3, pool_indices3)
Y_decoder_3 = self.decoderBlock3.forward(Y_decoder_2, Y_np2, pool_indices2)
Y_decoder_4 = self.decoderBlock4.forwrad(Y_decoder_3, Y_np1, pool_indices1)
probability_map = self.classifier.forward(Y_decoder_4)
return probability_map
def enable_test_dropout(self):
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment