"""Brain Mapper U-Net Solver

Description:

    This module contains the Pytorch implementation of the core U-net solver, used for training the network.

Usage:

    To use this module, import it and instantiate it as you wish:

        from solver import Solver

"""

# Standard library
import os
import glob
from datetime import datetime

# Third-party
import numpy as np
import torch
from torch.optim import lr_scheduler

# Project-local utilities
from utils.losses import MSELoss
from utils.data_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping

# File extension shared by every saved checkpoint file.
checkpoint_extension = 'path.tar'
class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the pytorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device type used for training (int - GPU id, str- CPU)
        number_of_classes (int): Number of classes (forwarded to the LogWriter)
        optimizer (class): Pytorch class of desired optimizer
        optimizer_arguments (dict): Dictionary of keyword arguments passed to the optimizer constructor
        loss_function (func): Function describing the desired loss function; defaults to MSELoss()
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (int): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputing training logs
        checkpoint_directory (str): Sub-directory (inside the experiment directory) holding checkpoints

    Returns:
        trained model

    """

    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer,
                 optimizer_arguments=None,
                 loss_function=None,
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs',
                 checkpoint_directory='checkpoints'
                 ):

        self.model = model
        self.device = device

        # BUGFIX: the defaults used to be `optimizer_arguments={}` and
        # `loss_function=MSELoss()`. Mutable / instantiated defaults are
        # evaluated once at import time and shared across every Solver
        # instance; use None sentinels instead (behaviour is unchanged).
        if optimizer_arguments is None:
            optimizer_arguments = {}
        if loss_function is None:
            loss_function = MSELoss()

        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
        else:
            self.loss_function = loss_function

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # We use a learning rate scheduler, that decays the LR of each paramter group by gamma every step_size epoch.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        # Make sure the whole experiment folder hierarchy exists before any
        # checkpoint is written.
        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        self.EarlyStopping = EarlyStopping()
        self.early_stop = False

        if use_last_checkpoint:
            self.load_checkpoint()

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            validation_loader (class):  Combined dataset and sampler, providing an iterable over the validationing dataset (torch.utils.data.DataLoader)

        Returns:
            validation_loss (float/None): Mean validation loss at the last epoch for which a
                checkpoint was saved, or None if no checkpoint was ever saved.
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        # BUGFIX: `validation_loss` was previously only bound when a checkpoint
        # was saved, so the final `return` could raise UnboundLocalError.
        validation_loss = None

        # Path of the checkpoint saved most recently by THIS run; used to
        # delete the superseded checkpoint file (see BUGFIX note below).
        previous_checkpoint_path = None

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs+1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)

                    y_hat = model(X)   # Forward pass

                    loss = self.loss_function(y_hat, y)  # Loss computation

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory

                    del X, y, y_hat, loss
                    torch.cuda.empty_cache()

                    if phase == 'validation':
                        # Crude textual progress bar for the validation pass.
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():
                    self.LogWriter.loss_per_epoch(losses, phase, epoch)

                    if phase == 'validation':
                        early_stop, save_checkpoint = self.EarlyStopping(
                            np.mean(losses))
                        self.early_stop = early_stop
                        if save_checkpoint:
                            validation_loss = np.mean(losses)
                            checkpoint_path = os.path.join(
                                self.experiment_directory_path, self.checkpoint_directory,
                                'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict()
                                                        },
                                                 filename=checkpoint_path
                                                 )

                            # BUGFIX: the original removed 'checkpoint_epoch_<epoch-1>',
                            # but early stopping only saves on improvement, so the
                            # previous save may be several epochs old and that file
                            # would not exist (FileNotFoundError). Track the last
                            # path saved by this run and remove that instead.
                            if previous_checkpoint_path is not None:
                                os.remove(previous_checkpoint_path)
                            previous_checkpoint_path = checkpoint_path

                if phase == 'train':
                    # StepLR decays the learning rate once per training epoch.
                    learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early Stop Condition

            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                break

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('****************************************************************')

        return validation_loss

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Destination path for the serialized checkpoint file
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Epoch whose checkpoint should be loaded; if None, the
                most recently modified checkpoint file is loaded instead.
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # We will sort through all the files in path to see which one is most recent

            if len(files_in_universal_path) > 0:
                # BUGFIX: previously used getatime (access time), which is
                # unreliable (e.g. filesystems mounted with noatime); the
                # modification time identifies the latest checkpoint.
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getmtime)
                self._checkpoint_reader(checkpoint_file_path)

            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables

        Args:
            checkpoint_file_path (str): path to checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.

        # BUGFIX: these three lines used to ASSIGN the checkpoint dicts to the
        # attribute `load_state_dict` (e.g. `self.model.load_state_dict = ...`),
        # silently overwriting the method and never restoring any state. They
        # must be method CALLS.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])

        # Optimizer state tensors are saved on the device used at save time;
        # move them onto the current training device.
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))