"""Brain Mapper U-Net Solver

Description:
-------------
This module contains the PyTorch implementation of the core U-Net solver used for training the network.

Usage
-------------
To use this module, import it and instantiate it as needed:

    from solver import Solver
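
A minimal usage sketch (all argument values below are illustrative only, not defaults
required by the class):

    solver = Solver(model= model,
                    device= 0,
                    number_of_classes= 1,
                    experiment_name= 'test_experiment')
    solver.train(train_loader, test_loader)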

"""

import os
import numpy as np
import torch
from datetime import datetime
from utils.losses import MSELoss
from utils.data_utils import create_folder
from utils.data_logging_utils import LogWriter
from torch.optim import lr_scheduler

checkpoint_directory = 'checkpoints'
checkpoint_extension = 'path.tar'

class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the PyTorch implementation of the U-Net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device used for training (int = GPU id, str = 'cpu')
        number_of_classes (int): Number of classes
        optimizer (class): PyTorch class of the desired optimizer
        optimizer_arguments (dict): Dictionary of additional arguments passed to the optimizer
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (float): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputting training logs

    Returns:
        None: the Solver instance is configured for training; the trained model is produced by calling train()

    Raises:
        None
    """
    
    def __init__(self,
                model,
                device,
                number_of_classes,
                experiment_name,
                optimizer = torch.optim.Adam,
                optimizer_arguments = {},
                loss_function = MSELoss(),
                model_name = 'BrainMapper',
                labels = None,
                number_epochs = 10,
                loss_log_period = 5,
                learning_rate_scheduler_step_size = 5,
                learning_rate_scheduler_gamma = 0.5,
                use_last_checkpoint = True,
                experiment_directory = 'experiments',
                logs_directory = 'logs'
                ):

        self.model = model
        self.device = device
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
        else:
            self.loss_function = loss_function

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period
        
        # We use a learning rate scheduler that decays the learning rate of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                            step_size = learning_rate_scheduler_step_size,
                                                            gamma= learning_rate_scheduler_gamma)
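        # Worked example (illustrative values only): with an initial learning rate of
        # 1e-3, step_size = 5 and gamma = 0.5, the learning rate becomes 5e-4 after
        # epoch 5, 2.5e-4 after epoch 10, and so on.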
        
        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path
        
        create_folder(experiment_directory_path)
        create_folder(os.path.join(experiment_directory_path, checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1
        self.best_mean_score = 0
        self.best_mean_score_epoch = 0

        if use_last_checkpoint:
            self.load_checkpoint()

        self.LogWriter = LogWriter(number_of_classes= number_of_classes,
                                    logs_directory= logs_directory,
                                    experiment_name= experiment_name,
                                    use_last_checkpoint= use_last_checkpoint,
                                    labels= labels)


    def train(self, train_loader, test_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            test_loader (class):  Combined dataset and sampler, providing an iterable over the testing dataset (torch.utils.data.DataLoader)

        Returns:
            None: the model is trained in place; losses and evaluation metrics are written via the LogWriter

        Raises:
            None
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'test': test_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache() # clear memory
            model.cuda(self.device) # Moving the model to GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        print('Device Type: {}'.format(torch.cuda.get_device_name(self.device) if torch.cuda.is_available() else self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs+1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'test']:
                print('-> Phase: {}'.format(phase))

                losses = []
                outputs = []
                y_values = []

                if phase == 'train':
                    model.train()
                    learning_rate_scheduler.step()
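                    # Note: from PyTorch 1.1 onwards the recommended order is to call
                    # scheduler.step() after optimizer.step(); stepping once per epoch
                    # here (before the batch loop) keeps the original behaviour but may
                    # trigger a UserWarning on newer PyTorch versions.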
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.LongTensor)

                    # Move the batch to the GPU if the model lives there
                    if next(model.parameters()).is_cuda:
                        X = X.cuda(self.device, non_blocking= True)
                        y = y.cuda(self.device, non_blocking= True)

                    y_hat = model(X)   # Forward pass

                    loss = self.loss_function(y_hat, y) # Loss computation

                    if phase == 'train':
                        optimizer.zero_grad() # Zero the parameter gradients
                        loss.backward() # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            
                            self.LogWriter.loss_per_iteration(loss.item(), batch_index, iteration)

                        iteration += 1         

                    losses.append(loss.item()) 
                    outputs.append(torch.max(y_hat, dim=1)[1].cpu())
                    y_values.append(y.cpu())

                    # Clear the memory

                    del X, y, y_hat, loss
                    torch.cuda.empty_cache()

                    if phase == 'test':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():
                    output_array, y_array = torch.cat(outputs), torch.cat(y_values)

                    self.LogWriter.loss_per_epoch(losses, phase, epoch)

                    dice_score_mean = self.LogWriter.dice_score_per_epoch(phase, output_array, y_array, epoch)
                    if phase == 'test':
                        if dice_score_mean > self.best_mean_score:
                            self.best_mean_score = dice_score_mean
                            self.best_mean_score_epoch = epoch

                    index = np.random.choice(len(dataloaders[phase].dataset.X), size=3, replace= False)
                    self.LogWriter.sample_image_per_epoch(prediction= model.predict(dataloaders[phase].dataset.X[index], self.device),
                                                            ground_truth= dataloaders[phase].dataset.y[index],
                                                            phase= phase,
                                                            epoch= epoch)

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))
            self.save_checkpoint() # TODO - write function and save the checkpoint!

        
        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('****************************************************************')

        # TODO: MAKE SURE any log writer function is closed!

    def save_model(self):
        pass
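
        # A minimal sketch of what this could do (an assumption, not the author's
        # final implementation): serialise only the trained weights. The file name
        # used below is purely illustrative.
        torch.save(self.model.state_dict(),
                   os.path.join(self.experiment_directory_path, self.model_name + '.pt'))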

    def save_checkpoint(self):
        pass
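
        # A minimal sketch of this TODO (an assumption, not the author's final
        # implementation): persist model, optimizer and scheduler state so that
        # training can later resume via load_checkpoint(). The file name is illustrative.
        checkpoint = {'model_state_dict': self.model.state_dict(),
                      'optimizer_state_dict': self.optimizer.state_dict(),
                      'scheduler_state_dict': self.learning_rate_scheduler.state_dict()}
        torch.save(checkpoint, os.path.join(self.experiment_directory_path,
                                            checkpoint_directory,
                                            'checkpoint_last.' + checkpoint_extension))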

    def load_checkpoint(self):
        pass
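
        # A minimal sketch of this TODO (an assumption, not the author's final
        # implementation): restore the latest checkpoint, if one exists, so that
        # training can resume. The file name mirrors the save_checkpoint() sketch above.
        checkpoint_path = os.path.join(self.experiment_directory_path,
                                       checkpoint_directory,
                                       'checkpoint_last.' + checkpoint_extension)
        if os.path.isfile(checkpoint_path):
            checkpoint = torch.load(checkpoint_path)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])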

    def _load_checkpoint_file(self):
        # The leading underscore marks this name as private by convention: it is not meant to be called outside this module.
        pass