"""Brain Mapper U-Net Architecture

Description:

    This file contains the PyTorch implementation of the core U-Net architecture.
    This architecture predicts functional connectivity (rsfMRI) from structural connectivity information obtained from dMRI.

Usage:

    To use this module, import it and instantiate it as you wish:

        from BrainMapperUNet import BrainMapperUNet
        deep_learning_model = BrainMapperUNet(parameters)

"""

import numpy as np
import torch
import torch.nn as nn
import utils.modules as modules


class BrainMapperUNet(nn.Module):
    """Architecture class BrainMapper U-net.

    This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.

    Args:
        parameters (dict): Contains the relevant architecture parameters
        parameters = {
            'kernel_heigth': 5,
            'kernel_width': 5,
            'kernel_classification': 1,
            'input_channels': 1,
            'output_channels': 64,
            'convolution_stride': 1,
            'dropout': 0.2,
            'pool_kernel_size': 2,
            'pool_stride': 2,
            'up_mode': 'upconv',
            'number_of_classes': 1
        }

    Returns:
        probability_map (torch.tensor): Output tensor of the forward pass through the U-Net block
    """

    def __init__(self, parameters):
        super(BrainMapperUNet, self).__init__()

        # TODO: currently, architecture based on QuickNAT - need to adjust parameter values accordingly!
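        # Note: the same parameters dict is mutated in place below; after the first
        # encoder block, 'input_channels' is updated so that later blocks chain correctly.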

        self.encoderBlock1 = modules.EncoderBlock(parameters)
        parameters['input_channels'] = parameters['output_channels']
        self.encoderBlock2 = modules.EncoderBlock(parameters)
        self.encoderBlock3 = modules.EncoderBlock(parameters)
        self.encoderBlock4 = modules.EncoderBlock(parameters)

        self.bottleneck = modules.ConvolutionalBlock(parameters)

        # Skip connections double the channel count; channel counts must be integers
        parameters['input_channels'] = parameters['output_channels'] * 2
        self.decoderBlock1 = modules.DecoderBlock(parameters)
        self.decoderBlock2 = modules.DecoderBlock(parameters)
        self.decoderBlock3 = modules.DecoderBlock(parameters)
        self.decoderBlock4 = modules.DecoderBlock(parameters)

        parameters['input_channels'] = parameters['output_channels']
        self.classifier = modules.ClassifierBlock(parameters)

    def forward(self, X):
        """Forward pass for U-net

        Function computing the forward pass through the U-Net.
        The input to the function is the dMRI map.

        Args:
            X (torch.tensor): Input dMRI map, shape = (N x C x H x W) 

        Returns:
            probability_map (torch.tensor): Output tensor of the forward pass through the U-Net block
        """

        Y_encoder_1, Y_np1, pool_indices1 = self.encoderBlock1.forward(X)
        Y_encoder_2, Y_np2, pool_indices2 = self.encoderBlock2.forward(
            Y_encoder_1)
        Y_encoder_3, Y_np3, pool_indices3 = self.encoderBlock3.forward(
            Y_encoder_2)
        Y_encoder_4, Y_np4, pool_indices4 = self.encoderBlock4.forward(
            Y_encoder_3)

        Y_bottleNeck = self.bottleneck.forward(Y_encoder_4)

        Y_decoder_1 = self.decoderBlock1.forward(
            Y_bottleNeck, Y_np4, pool_indices4)
        Y_decoder_2 = self.decoderBlock2.forward(
            Y_decoder_1, Y_np3, pool_indices3)
        Y_decoder_3 = self.decoderBlock3.forward(
            Y_decoder_2, Y_np2, pool_indices2)
        Y_decoder_4 = self.decoderBlock4.forward(
            Y_decoder_3, Y_np1, pool_indices1)

        probability_map = self.classifier.forward(Y_decoder_4)

        return probability_map

    def save(self, path):
        """Model Saver

        Function saving the model with all its parameters to a given path.
        The path must end with a *.model extension.

        Args:
            path (str): Path string
        """

        print("Saving Model... {}".format(path))
        torch.save(self, path)
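        # Illustrative note (assumption, not part of the original file): torch.save(self, path)
        # pickles the whole module, so the BrainMapperUNet class definition must be importable
        # when the model is reloaded, e.g.
        #   deep_learning_model = torch.load("path/to/saved.model")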

    @property
    def test_if_cuda(self):
        """Cuda Test

        This function tests if the model parameters are allocated to a CUDA enabled GPU.

        Returns:
            bool: Flag indicating True if the tensor is stored on the GPU and False otherwise
        """

        return next(self.parameters()).is_cuda
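
        # Illustrative usage (assumption, not part of the original file):
        #   deep_learning_model.cuda()
        #   deep_learning_model.test_if_cuda  # True once the parameters live on the GPU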

    def predict(self, X, device=0):
        """Post-training Output Prediction

        This function predicts the output of the U-net post-training.

        Args:
            X (torch.tensor): input dMRI volume
            device (int/str): Device used for the prediction (int - GPU id, str - CPU)

        Returns:
            prediction (ndarray): predicted output after training

        """
        self.eval()  # PyTorch module call setting the network to evaluation mode

        if isinstance(X, np.ndarray):
            X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
        elif isinstance(X, torch.Tensor) and not X.is_cuda:
            X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)

        # The .cuda() call transfers the tensor from the CPU to the GPU if required.
        # The non_blocking argument lets the caller bypass synchronization when possible.

        with torch.no_grad():  # Disables gradient tracking for the enclosed operations
            output = self.forward(X)

        _, idx = torch.max(output, 1)

        # Retrieve the tensor held by idx (.data) and move it to the CPU as an ndarray
        idx = idx.data.cpu().numpy()

        prediction = np.squeeze(idx)

        del X, output, idx

        return prediction
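
        # Illustrative usage (assumption, not part of the original file):
        #   dummy_volume = np.random.rand(1, 1, 64, 64).astype(np.float32)
        #   prediction = deep_learning_model.predict(dummy_volume)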

# if __name__ == '__main__':

#     # For debugging - To be deleted later! TODO

#     parameters = {
#         'kernel_heigth': 5,
#         'kernel_width': 5,
#         'kernel_classification': 1,
#         'input_channels': 1,
#         'output_channels': 64,
#         'convolution_stride': 1,
#         'dropout': 0.2,
#         'pool_kernel_size': 2,
#         'pool_stride': 2,
#         'up_mode': 'upconv',
#         'number_of_classes': 1
#     }
#     network = BrainMapperUNet(parameters)
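
#     # Hedged sketch (assumption, not in the original script): run a dummy forward
#     # pass on a random single-channel map to sanity-check the output shape. The
#     # spatial size (64 x 64) is arbitrary; it is assumed to be divisible by the
#     # pooling factors of the four encoder blocks.
#     dummy_input = torch.randn(1, 1, 64, 64)
#     with torch.no_grad():
#         dummy_output = network(dummy_input)
#     print(dummy_output.shape)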