"""Brain Mapper U-Net Solver

Description:

5
    This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
6

7
Usage:
8

9
10
11
    To use this module, import it and instantiate is as you wish:

        from solver import Solver
12
13
14
15
16
"""

import os
import numpy as np
import torch
import glob

from fsl.data.image import Image
from datetime import datetime
from utils.losses import MSELoss
from utils.data_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
from torch.optim import lr_scheduler

checkpoint_extension = 'path.tar'
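# Extension appended to every checkpoint file name (e.g. 'checkpoint_epoch_3.path.tar').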

class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the PyTorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device used for training (int - GPU id, str - CPU)
        number_of_classes (int): Number of classes
        optimizer (class): PyTorch class of the desired optimizer
        optimizer_arguments (dict): Dictionary of arguments passed to the optimizer (e.g. learning rate, weight decay)
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (float): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputting training logs
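        checkpoint_directory (str): Directory (inside the experiment directory) for storing training checkpoints
        save_model_directory (str): Directory where the final trained model is saved
        final_model_output_file (str): File name of the final trained model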

    Returns:
        trained model

    """
    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer,
                 optimizer_arguments={},
                 loss_function=MSELoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs',
                 checkpoint_directory='checkpoints',
                 save_model_directory='saved_models',
                 final_model_output_file='finetuned_alldata.pth.tar'
                 ):

        self.model = model
        self.device = device
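        # 'optimizer' is expected to be an optimizer class (e.g. torch.optim.Adam); it is instantiated here over the model parameters.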
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
        else:
            self.loss_function = loss_function

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
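        # Stop training once the validation loss has failed to improve for 10 consecutive epochs (patience=10, min_delta=0).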
        self.early_stop = False

        if use_last_checkpoint:
            self.load_checkpoint()

        self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
            Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
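        # Binary brain mask in MNI152 2 mm space; predictions are later multiplied by it to zero out non-brain voxels.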

        self.save_model_directory = save_model_directory
        self.final_model_output_file = final_model_output_file

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
140
            validation_loader (class):  Combined dataset and sampler, providing an iterable over the validationing dataset (torch.utils.data.DataLoader)
141
142

        Returns:
143
            trained model
144
145
146
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        previous_checkpoint = None
        previous_loss = None

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs+1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)
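                    # e.g. a (batch, D, H, W) volume becomes (batch, 1, D, H, W) - a single-channel input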

                    print('X range:', torch.min(X), torch.max(X))
                    print('y range:', torch.min(y), torch.max(y))
                    
                    MNI152_T1_2mm_brain_mask = torch.unsqueeze(torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)
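                    # Two leading singleton dimensions let the (D, H, W) mask broadcast over the batch and channel dimensions.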

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
                            self.device, non_blocking=True)

                    y_hat = model(X)   # Forward pass (masking is applied below)

                    print('y_hat range:', torch.min(y_hat), torch.max(y_hat))

                    y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)

                    print('y_hat masked range:', torch.min(y_hat), torch.max(y_hat))

                    loss = self.loss_function(y_hat, y)  # Loss computation

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:

                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory

                    del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask
                    torch.cuda.empty_cache()

                    if phase == 'validation':

                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():

                    if phase == 'train':
                        self.LogWriter.loss_per_epoch(losses, phase, epoch)
                    elif phase == 'validation':
                        self.LogWriter.loss_per_epoch(losses, phase, epoch, previous_loss=previous_loss)
                        previous_loss = np.mean(losses)

                    if phase == 'validation':
                        early_stop, save_checkpoint = self.EarlyStopping(
                            np.mean(losses))
                        self.early_stop = early_stop
                        if save_checkpoint:
                            validation_loss = np.mean(losses)
                            checkpoint_name = os.path.join(
                                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict()
                                                        },
                                                 filename=checkpoint_name
                                                 )

                            # Keep only the most recent checkpoint on disk.
                            if previous_checkpoint is not None:
                                os.remove(previous_checkpoint)
                            previous_checkpoint = checkpoint_name

                if phase == 'train':
                    learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early Stop Condition
            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                self.load_checkpoint()
                break

        model_output_path = os.path.join(
            self.save_model_directory, self.final_model_output_file)

        create_folder(self.save_model_directory)

        model.save(model_output_path)

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('Final Model Saved in: {}'.format(model_output_path))
        print('****************************************************************')

        return validation_loss

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Path of the file where the checkpoint is saved
        """

        torch.save(state, filename)
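        # The saved dictionary is expected to contain the keys read back by _checkpoint_reader:
        # 'epoch', 'start_iteration', 'state_dict', 'optimizer' and 'scheduler'.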

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Epoch whose checkpoint should be loaded; if None, the most recent checkpoint is used
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # We will sort through all the files in path to see which one is most recent

            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
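                # getatime (last access time) is used as a proxy for the most recently used checkpoint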
                self._checkpoint_reader(checkpoint_file_path)

            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables

        Args:
            checkpoint_file_path (str): path to checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])

        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))