"""Brain Mapper U-Net Solver

Description:

5
    This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
6

7
Usage:
8

9
10
11
    To use this module, import it and instantiate is as you wish:

        from solver import Solver
12
13
14
15
16
"""

import os
import glob
import numpy as np
import torch

from datetime import datetime

from utils.losses import MSELoss
from utils.data_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
from torch.optim import lr_scheduler

checkpoint_directory = 'checkpoints'
checkpoint_extension = 'path.tar'


class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the pytorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device type used for training (int - GPU id, str- CPU)
        number_of_classes (int): Number of classes
        optimizer (class): Pytorch class of desired optimizer
        optimizer_arguments (dict): Dictionary of arguments to be optimized
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (int): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputing training logs

    Returns:
54
        trained model - working on this!
55
56

    """
57

    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer=torch.optim.Adam,
                 optimizer_arguments={},
                 loss_function=MSELoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs'
                 ):

        self.model = model
        self.device = device
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
        else:
            self.loss_function = loss_function

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)
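        # With the default step_size=5 and gamma=0.5, an initial learning rate of e.g. 1e-3
        # would decay to 5e-4 after epoch 5, 2.5e-4 after epoch 10, and so on.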

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1
        # self.best_mean_score = 0
        # self.best_mean_score_epoch = 0

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        self.EarlyStopping = EarlyStopping()
        self.early_stop = False

        if use_last_checkpoint:
            self.load_checkpoint()

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            validation_loader (class): Combined dataset and sampler, providing an iterable over the validation dataset (torch.utils.data.DataLoader)

        Returns:
            trained model
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))

        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs+1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []

                if phase == 'train':
                    model.train()
                    learning_rate_scheduler.step()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)
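                    # e.g. a batch of 3D volumes of shape (batch, x, y, z) becomes (batch, 1, x, y, z),
                    # assuming the loaders yield single-channel volumes.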

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)

                    y_hat = model(X)   # Forward pass

                    loss = self.loss_function(y_hat, y)  # Loss computation

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory
                    del X, y, y_hat, loss
                    torch.cuda.empty_cache()

                    if phase == 'validation':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():

                    self.LogWriter.loss_per_epoch(losses, phase, epoch)

                    if phase == 'validation':
                        early_stop, save_checkpoint = self.EarlyStopping(np.mean(losses))
                        self.early_stop = early_stop
                        if save_checkpoint:
                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict()
                                                        },
                                                filename=os.path.join(self.experiment_directory_path, checkpoint_directory,
                                                                    'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
                                                )
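                            # With the default experiment_directory ('experiments'), this writes e.g.
                            # experiments/<experiment_name>/checkpoints/checkpoint_epoch_<epoch>.path.tar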

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early Stop Condition

            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                break

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('****************************************************************')

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training

        Args:
            state (dict): Dictionary of all the relevant model components
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Current epoch value
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # We sort through all the files in the path to find the most recently accessed one.
            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
                self._checkpoint_reader(checkpoint_file_path)
            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables

        Args:
            checkpoint_file_path (str): path to checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])

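        # The optimizer state (e.g. the running averages kept by Adam) is restored on the device it
        # was saved from, so every tensor in it is moved onto the training device below.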
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))