"""Brain Mapper U-Net Solver

Description:

    This module contains the PyTorch implementation of the core U-Net solver, used for training the network.

Usage:

    To use this module, import it and instantiate it as you wish:

        from solver import Solver
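
    A minimal instantiation might then look like the sketch below (illustrative only: it
    assumes that a BrainMapper model instance, a device, an optimizer class and the data
    loaders already exist under these names):

        solver = Solver(model=model,
                        device=device,
                        number_of_classes=1,
                        experiment_name='example_experiment',
                        optimizer=torch.optim.Adam,
                        optimizer_arguments={'lr': 1e-4})
        validation_loss = solver.train(train_loader, validation_loader)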
"""

import os
import glob

import numpy as np
import torch
from torch.optim import lr_scheduler

from fsl.data.image import Image
from datetime import datetime

from utils.losses import MSELoss
from utils.data_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping

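# Extension used for every checkpoint file written by save_checkpoint() and matched by load_checkpoint().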
checkpoint_extension = 'path.tar'

class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the PyTorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device type used for training (int - GPU id, str - CPU)
        number_of_classes (int): Number of classes
        optimizer (class): PyTorch class of the desired optimizer
        optimizer_arguments (dict): Dictionary of additional arguments passed to the optimizer constructor
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (float): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputting training logs
        checkpoint_directory (str): Directory for saving model checkpoints

    Returns:
        trained model

    """

    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer,
                 optimizer_arguments={},
                 loss_function=MSELoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs',
                 checkpoint_directory='checkpoints'
                 ):

        self.model = model
        self.device = device
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
        else:
            self.loss_function = loss_function

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

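        # Early stopping is driven by the mean validation loss of each epoch (see the
        # validation phase in train()); training halts once self.early_stop becomes True.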
        self.EarlyStopping = EarlyStopping()
        self.early_stop = False

        if use_last_checkpoint:
            self.load_checkpoint()

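        # MNI152 T1 2mm brain mask, used during training to zero out network predictions
        # that fall outside the brain before the loss is computed.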
        self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
            Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
135
            validation_loader (class):  Combined dataset and sampler, providing an iterable over the validationing dataset (torch.utils.data.DataLoader)
136
137

        Returns:
138
            trained model
139
140
141
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration
        # Initialised so that the value returned below is defined even if no checkpoint is ever saved.
        validation_loss = None

        for epoch in range(self.start_epoch, self.number_epochs+1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []

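                # Switch the network between training and evaluation mode, so that layers
                # such as dropout and batch normalisation behave correctly in each phase.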
                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)

                    MNI152_T1_2mm_brain_mask = self.MNI152_T1_2mm_brain_mask

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
                            self.device, non_blocking=True)

                    y_hat = model(X)   # Forward pass & Masking

                    y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)

                    loss = self.loss_function(y_hat, y)  # Loss computation

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory
                    del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask
                    torch.cuda.empty_cache()

                    if phase == 'validation':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

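                # End of the current phase: log the mean epoch loss and, in the validation
                # phase, consult early stopping and save a checkpoint when it requests one.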
                with torch.no_grad():

                    self.LogWriter.loss_per_epoch(losses, phase, epoch)

                    if phase == 'validation':
                        early_stop, save_checkpoint = self.EarlyStopping(
                            np.mean(losses))
                        self.early_stop = early_stop
                        if save_checkpoint:
                            validation_loss = np.mean(losses)
                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict()
                                                        },
                                                 filename=os.path.join(self.experiment_directory_path, self.checkpoint_directory,
                                                                       'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
                                                 )

                            # if epoch != self.start_epoch:
                            #     os.remove(os.path.join(self.experiment_directory_path, self.checkpoint_directory,
                            #                            'checkpoint_epoch_' + str(epoch-1) + '.' + checkpoint_extension))

                if phase == 'train':
                    learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early Stop Condition
            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                break

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('****************************************************************')

        return validation_loss

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Path (including file name) under which the checkpoint is saved
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Epoch number of the checkpoint to load; if None, the most recent checkpoint found is loaded
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # Look through all checkpoint files in the directory and pick the most recently accessed one.
            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
                self._checkpoint_reader(checkpoint_file_path)

            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables

        Args:
            checkpoint_file_path (str): path to checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])

        # Move any optimizer state tensors onto the training device.
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))