"""Brain Mapper U-Net Solver

Description:

5
    This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
6

7
Usage:
8

9
10
11
    To use this module, import it and instantiate is as you wish:

        from solver import Solver
12
13
14
15
16
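
    A minimal usage sketch (the model object, loaders and argument values
    below are illustrative assumptions, not fixed defaults):

        import torch
        from solver import Solver

        solver = Solver(model=my_model,
                        device=0,
                        number_of_classes=1,
                        experiment_name='test_run',
                        optimizer=torch.optim.Adam,
                        optimizer_arguments={'lr': 1e-3})
        validation_loss = solver.train(train_loader, validation_loader)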
"""

import os
import numpy as np
import torch
import glob

from fsl.data.image import Image
from fsl.utils.image.roi import roi
from datetime import datetime
from utils.common_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
from torch.optim import lr_scheduler

# Suffix appended to every checkpoint file name ('checkpoint_epoch_<N>.path.tar')
checkpoint_extension = 'path.tar'

29

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the PyTorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device used for training (int - GPU id, str - CPU)
        number_of_classes (int): Number of classes
        optimizer (class): PyTorch class of the desired optimizer
        optimizer_arguments (dict): Dictionary of keyword arguments passed to the optimizer (e.g. learning rate)
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing the loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (float): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputting training logs
        checkpoint_directory (str): Directory for storing training checkpoints
        save_model_directory (str): Directory for storing the final trained model
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed up training

    Returns:
        trained model

    """

    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer,
                 optimizer_arguments={},
                 loss_function=torch.nn.MSELoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs',
                 checkpoint_directory='checkpoints',
                 save_model_directory='saved_models',
                 crop_flag=False
                 ):

        self.model = model
        self.device = device
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
            self.MSE = torch.nn.MSELoss().cuda(device)
        else:
            self.loss_function = loss_function
            self.MSE = torch.nn.MSELoss()

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)
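        # With step_size=5 and gamma=0.5 (the defaults here), the LR follows
        # lr = initial_lr * gamma ** (epoch // step_size): an initial LR of
        # 1e-3 (illustrative) becomes 5e-4 after epoch 5 and 2.5e-4 after epoch 10.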

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        self.early_stop = False

        # Load the MNI152 T1 2mm brain mask used for masking the network output
        if crop_flag:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'),
                    ((9, 81), (10, 100), (0, 77))).data)
        else:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
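        # The roi bounds keep voxels [9, 81) x [10, 100) x [0, 77), i.e. the
        # 72x90x77 crop of the full 91x109x91 volume described in the class docstring.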

        self.save_model_directory = save_model_directory
        self.final_model_output_file = experiment_name + ".pth.tar"

        self.best_score_early_stop = None
        self.counter_early_stop = 0
        self.previous_checkpoint = None
        self.previous_loss = None
        self.previous_MSE = None

        if use_last_checkpoint:
            # Restore the early-stopping state from the checkpoint so that
            # patience counting continues where the previous run stopped.
            self.load_checkpoint()
            self.EarlyStopping = EarlyStopping(patience=2,
                                               min_delta=0,
                                               best_score=self.best_score_early_stop,
                                               counter=self.counter_early_stop)
        else:
            self.EarlyStopping = EarlyStopping(patience=2, min_delta=0)

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            validation_loader (class): Combined dataset and sampler, providing an iterable over the validation dataset (torch.utils.data.DataLoader)

        Returns:
            float: validation loss recorded at the last checkpoint save (None if no checkpoint was saved)
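
        A minimal usage sketch ('train_loader' and 'validation_loader' are
        assumed to be torch.utils.data.DataLoader instances built elsewhere):

            validation_loss = solver.train(train_loader, validation_loader)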
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration
        validation_loss = None  # Ensures a defined return value even if no checkpoint is saved

        for epoch in range(self.start_epoch, self.number_epochs+1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []
                MSEs = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)

                    MNI152_T1_2mm_brain_mask = torch.unsqueeze(
                        torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)

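                    # Shape note: X and y are (batch, 1, X, Y, Z) after the
                    # unsqueeze above; the mask, now (1, 1, X, Y, Z), broadcasts
                    # across the batch when multiplied with the prediction below.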
                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
                            self.device, non_blocking=True)

                    y_hat = model(X)   # Forward pass & Masking

                    y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)

                    loss = self.loss_function(y_hat, y)  # Loss computation
                    # loss = self.loss_function(y_hat+1e-4, y+1e-4, torch.tensor(1.0).cuda(self.device, non_blocking=True))

                    # We also calculate a separate MSE for cost function comparison!
                    MSE = self.MSE(y_hat, y)
                    MSEs.append(MSE.item())

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory
                    del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask, MSE
                    torch.cuda.empty_cache()

                    if phase == 'validation':
                        # Simple textual progress indicator for the validation phase
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():

                    if phase == 'train':
                        self.LogWriter.loss_per_epoch(losses, phase, epoch)
                        self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
                    elif phase == 'validation':
                        self.LogWriter.loss_per_epoch(
                            losses, phase, epoch, previous_loss=self.previous_loss)
                        self.previous_loss = np.mean(losses)
                        self.LogWriter.MSE_per_epoch(
                            MSEs, phase, epoch, previous_loss=self.previous_MSE)
                        self.previous_MSE = np.mean(MSEs)

                    if phase == 'validation':
                        early_stop, save_checkpoint, best_score_early_stop, counter_early_stop = self.EarlyStopping(
                            np.mean(losses))
                        self.early_stop = early_stop
                        self.best_score_early_stop = best_score_early_stop
                        self.counter_early_stop = counter_early_stop
                        if save_checkpoint:
                            validation_loss = np.mean(losses)
                            checkpoint_name = os.path.join(
                                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict(),
                                                        'best_score_early_stop': self.best_score_early_stop,
                                                        'counter_early_stop': self.counter_early_stop,
                                                        'previous_checkpoint': self.previous_checkpoint,
                                                        'previous_loss': self.previous_loss,
                                                        'previous_MSE': self.previous_MSE,
                                                        },
                                                 filename=checkpoint_name
                                                 )

                            # Keep only the most recent checkpoint on disk
                            if self.previous_checkpoint is not None:
                                os.remove(self.previous_checkpoint)
                            self.previous_checkpoint = checkpoint_name

                if phase == 'train':
                    learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early Stop Condition
            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                self.load_checkpoint()
                break

        model_output_path = os.path.join(
            self.save_model_directory, self.final_model_output_file)

        create_folder(self.save_model_directory)

        model.save(model_output_path)

        self.LogWriter.close()
        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('Final Model Saved in: {}'.format(model_output_path))
        print('****************************************************************')

        return validation_loss

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Output path for the checkpoint file
        """

        torch.save(state, filename)
    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Epoch of the checkpoint to be loaded; if None, the most recently used checkpoint is loaded
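
        For example, solver.load_checkpoint(epoch=5) loads
        'checkpoint_epoch_5.path.tar' from the experiment's checkpoint
        directory, while solver.load_checkpoint() resumes from the most
        recently accessed checkpoint.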
        """
        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # Pick the most recently accessed checkpoint file as the resume point

            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
                self._checkpoint_reader(checkpoint_file_path)

            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables

        Args:
            checkpoint_file_path (str): path to checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We do not restore model_name: a pre-trained model may be reloaded and reused under a different name.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_score_early_stop = checkpoint['best_score_early_stop']
        self.counter_early_stop = checkpoint['counter_early_stop']
        self.previous_checkpoint = checkpoint['previous_checkpoint']
        self.previous_loss = checkpoint['previous_loss']
        self.previous_MSE = checkpoint['previous_MSE']

        # Move the optimizer state tensors to the training device
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])
        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))