"""Brain Mapper U-Net Solver

Description:

    This module contains the PyTorch implementation of the core U-net solver, used for training the network.

Usage:

    To use this module, import it and instantiate it as you wish:

        from solver import Solver
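
    A minimal training run might look like the following sketch (the model, the data
    loaders and the experiment name are placeholders supplied by the caller):

        model = BrainMapper()  # hypothetical model instance
        solver = Solver(model=model,
                        device=0,
                        number_of_classes=1,
                        experiment_name='test_experiment',
                        optimizer_arguments={'lr': 1e-3})
        solver.train(train_loader, test_loader)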
"""

import os
import numpy as np
import torch
import glob

from datetime import datetime
from utils.losses import MSELoss
from utils.data_utils import create_folder
from utils.data_logging_utils import LogWriter
from torch.optim import lr_scheduler

checkpoint_directory = 'checkpoints'
checkpoint_extension = 'path.tar'


class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the PyTorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device type used for training (int - GPU id, str - CPU)
        number_of_classes (int): Number of classes
        optimizer (class): Pytorch class of desired optimizer
        optimizer_arguments (dict): Dictionary of keyword arguments passed to the optimizer (e.g. learning rate, weight decay)
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (float): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputting training logs

    Returns:
        Solver object; the trained model is obtained by calling train() and is persisted via per-epoch checkpoints

    """

    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer=torch.optim.Adam,
                 optimizer_arguments={},
                 loss_function=MSELoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs'
                 ):

        self.model = model
        self.device = device
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
        else:
            self.loss_function = loss_function

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
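        # For example, with step_size=5 and gamma=0.5 an initial learning rate of 1e-3
        # drops to 5e-4 after epoch 5 and to 2.5e-4 after epoch 10.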
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1
        self.best_mean_score = 0
        self.best_mean_score_epoch = 0

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        if use_last_checkpoint:
            self.load_checkpoint()


    def train(self, train_loader, test_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            test_loader (class):  Combined dataset and sampler, providing an iterable over the testing dataset (torch.utils.data.DataLoader)

        Returns:
            None: the model is trained in place; per-epoch checkpoints and logs are written to disk.
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'test': test_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs+1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'test']:
                print('-> Phase: {}'.format(phase))

                losses = []
                outputs = []
                y_values = []

                if phase == 'train':
                    model.train()
                    learning_rate_scheduler.step()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra channel dimension, as expected by the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)
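                    # e.g. a batch of 3D volumes of shape (N, D, H, W) becomes (N, 1, D, H, W)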

                    if model.is_cuda():
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)

                    y_hat = model(X)   # Forward pass

                    loss = self.loss_function(y_hat, y)  # Loss computation

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:

                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())
                    outputs.append(torch.max(y_hat, dim=1)[1].cpu())
                    y_values.append(y.cpu())

                    # Clear the memory

                    del X, y, y_hat, loss
                    torch.cuda.empty_cache()

                    if phase == 'test':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():
                    output_array, y_array = torch.cat(
                        outputs), torch.cat(y_values)

                    self.LogWriter.loss_per_epoch(losses, phase, epoch)

                    dice_score_mean = self.LogWriter.dice_score_per_epoch(
                        phase, output_array, y_array, epoch)
                    if phase == 'test':
                        if dice_score_mean > self.best_mean_score:
                            self.best_mean_score = dice_score_mean
                            self.best_mean_score_epoch = epoch

                    index = np.random.choice(
                        len(dataloaders[phase].dataset.X), size=3, replace=False)
                    self.LogWriter.sample_image_per_epoch(prediction=model.predict(dataloaders[phase].dataset.X[index], self.device),
                                                          ground_truth=dataloaders[phase].dataset.y[index],
                                                          phase=phase,
                                                          epoch=epoch)

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            self.save_checkpoint(state={'epoch': epoch + 1,
                                        'start_iteration': iteration + 1,
                                        'arch': self.model_name,
                                        'state_dict': model.state_dict(),
                                        'optimizer': optimizer.state_dict(),
                                        'scheduler': learning_rate_scheduler.state_dict()
                                        },
                                 filename=os.path.join(self.experiment_directory_path, checkpoint_directory,
                                                       'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
                                 )

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('****************************************************************')

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Path where the checkpoint file is saved
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Epoch number of the checkpoint to load; if None, the most recent checkpoint found in the checkpoints folder is loaded
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # Sort through all the checkpoint files in the path and pick the most recently accessed one

            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
                self._checkpoint_reader(checkpoint_file_path)

            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables

        Args:
            checkpoint_file_path (str): path to checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We do not restore model_name, so a pre-trained model can be reused under a different name.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])

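        # Optimizer state tensors restored from disk may live on the CPU; move them to
        # the training device so that subsequent optimizer steps do not mix devices.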
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))