"""Brain Mapper U-Net Solver

Description:

5
    This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
6

7
Usage:
8

9
10
11
    To use this module, import it and instantiate is as you wish:

        from solver import Solver
12
13
14
15
16
"""

import os
import numpy as np
import torch
import glob

from fsl.data.image import Image
from fsl.utils.image.roi import roi
from datetime import datetime
from utils.common_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
from torch.optim import lr_scheduler

checkpoint_extension = 'path.tar'
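# The extension above is used for checkpoint files, which are written as
# 'checkpoint_epoch_<N>.<checkpoint_extension>' inside each experiment's
# checkpoint directory (see save_checkpoint / load_checkpoint below).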

class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the pytorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device type used for training (int - GPU id, str- CPU)
        number_of_classes (int): Number of classes
        optimizer (class): Pytorch class of desired optimizer
        optimizer_arguments (dict): Dictionary of arguments to be optimized
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (int): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputing training logs
52
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
53
54

    Returns:
55
        trained model - working on this!
56
57

    """

    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer,
                 optimizer_arguments={},
                 loss_function=torch.nn.MSELoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs',
                 checkpoint_directory='checkpoints',
                 save_model_directory='saved_models',
                 crop_flag=False
                 ):

        self.model = model
        self.device = device
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
            self.MSE = torch.nn.MSELoss().cuda(device)
        else:
            self.loss_function = loss_function
            self.MSE = torch.nn.MSELoss()

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)
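        # With the default step_size=5 and gamma=0.5, an initial learning rate
        # of 1e-3 becomes 5e-4 after epoch 5 and 2.5e-4 after epoch 10.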

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        self.early_stop = False

        # The MNI152 T1 2mm brain mask is used to zero out non-brain voxels in
        # the network output before the loss is computed. When crop_flag is set,
        # the mask is cropped from 91x109x91 to 72x90x77 using the roi bounds
        # (9-81, 10-100, 0-77), matching the cropped input volumes.
        if crop_flag:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'),
                    ((9, 81), (10, 100), (0, 77))).data)
        else:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)

        self.save_model_directory = save_model_directory
        self.final_model_output_file = experiment_name + ".pth.tar"

        self.best_score_early_stop = None
        self.counter_early_stop = 0
        self.previous_loss = None
        self.previous_MSE = None
        self.valid_epoch = None

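        # Early stopping tracks the validation loss and halts training once it
        # fails to improve for `patience` consecutive epochs (semantics assumed
        # from the project's utils.early_stopping implementation).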
        if use_last_checkpoint:
            self.load_checkpoint()
            self.EarlyStopping = EarlyStopping(patience=5, min_delta=0,
                                               best_score=self.best_score_early_stop,
                                               counter=self.counter_early_stop)
        else:
            self.EarlyStopping = EarlyStopping(patience=5, min_delta=0)

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
153
            validation_loader (class):  Combined dataset and sampler, providing an iterable over the validationing dataset (torch.utils.data.DataLoader)
154
155

        Returns:
156
            trained model
157
158
159
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs+1):

            if self.early_stop:
                print("ATTENTION!: Training stopped due to previous early stop flag!")
                break

            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []
                MSEs = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)

                    MNI152_T1_2mm_brain_mask = torch.unsqueeze(
                        torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)
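                    # The mask now has shape (1, 1, D, H, W), so it broadcasts
                    # across the batch dimension in the masking step below.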

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
                            self.device, non_blocking=True)

                    y_hat = model(X)  # Forward pass

                    # Masking: zero out predictions outside the brain mask
                    y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)

                    loss = self.loss_function(y_hat, y)  # Loss computation
                    # loss = self.loss_function(y_hat+1e-4, y+1e-4, torch.tensor(1.0).cuda(self.device, non_blocking=True))

                    # We also calculate a separate MSE for cost function comparison!
                    MSE = self.MSE(y_hat, y)
                    MSEs.append(MSE.item())

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:

                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory

                    del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask, MSE
                    torch.cuda.empty_cache()
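                    # Note: torch.cuda.empty_cache() releases cached GPU memory
                    # back to the driver; calling it every batch lowers the
                    # memory footprint at some speed cost.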

                    if phase == 'validation':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():

                    if phase == 'train':
                        self.LogWriter.loss_per_epoch(losses, phase, epoch)
                        self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
                    elif phase == 'validation':
                        self.LogWriter.loss_per_epoch(
                            losses, phase, epoch, previous_loss=self.previous_loss)
                        self.previous_loss = np.mean(losses)
                        self.LogWriter.MSE_per_epoch(
                            MSEs, phase, epoch, previous_loss=self.previous_MSE)
                        self.previous_MSE = np.mean(MSEs)

                    if phase == 'validation':
                        early_stop, best_score_early_stop, counter_early_stop = self.EarlyStopping(
                            np.mean(losses))

                        self.early_stop = early_stop
                        self.best_score_early_stop = best_score_early_stop
                        self.counter_early_stop = counter_early_stop

                        checkpoint_name = os.path.join(
                            self.experiment_directory_path, self.checkpoint_directory,
                            'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)

                        if self.counter_early_stop == 0:
                            self.valid_epoch = epoch

                        self.save_checkpoint(state={'epoch': epoch + 1,
                                                    'start_iteration': iteration + 1,
                                                    'arch': self.model_name,
                                                    'state_dict': model.state_dict(),
                                                    'optimizer': optimizer.state_dict(),
                                                    'scheduler': learning_rate_scheduler.state_dict(),
                                                    'best_score_early_stop': self.best_score_early_stop,
                                                    'counter_early_stop': self.counter_early_stop,
                                                    'previous_loss': self.previous_loss,
                                                    'previous_MSE': self.previous_MSE,
                                                    'early_stop': self.early_stop,
                                                    'valid_epoch': self.valid_epoch
                                                    },
                                             filename=checkpoint_name
                                             )
                if phase == 'train':
                    learning_rate_scheduler.step()
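                    # StepLR counts its step_size in epochs here, so the
                    # scheduler is stepped once per epoch, at the end of the
                    # training phase.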

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early Stop Condition

            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                self.load_checkpoint(epoch=self.valid_epoch)
                break

        if self.early_stop:
            
            self.LogWriter.close()

            print('----------------------------------------')
            print('NO TRAINING DONE TO PREVENT OVERFITTING!')
            print('=====================')
            end_time = datetime.now()
            print('Completed At: {}'.format(end_time))
            print('Training Duration: {}'.format(end_time - start_time))
            print('****************************************************************')
        else:
            model_output_path = os.path.join(
                self.save_model_directory, self.final_model_output_file)

            create_folder(self.save_model_directory)

            model.save(model_output_path)

            self.LogWriter.close()

            print('----------------------------------------')
            print('TRAINING IS COMPLETE!')
            print('=====================')
            end_time = datetime.now()
            print('Completed At: {}'.format(end_time))
            print('Training Duration: {}'.format(end_time - start_time))
            print('Final Model Saved in: {}'.format(model_output_path))
            print('****************************************************************')

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Path where the checkpoint will be saved
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Epoch of the checkpoint to load; if None, the most recently used checkpoint is loaded
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # Find the most recently accessed checkpoint file in the path

            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
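                # Note: os.path.getatime uses the last access time, which may
                # not reflect the newest checkpoint on filesystems mounted with
                # noatime; os.path.getmtime would be an alternative key.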
                self._checkpoint_reader(checkpoint_file_path)

            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables

        Args:
            checkpoint_file_path (str): path to checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_score_early_stop = checkpoint['best_score_early_stop']
        self.counter_early_stop = checkpoint['counter_early_stop']
        self.previous_loss = checkpoint['previous_loss']
        self.previous_MSE = checkpoint['previous_MSE']
        self.early_stop = checkpoint['early_stop']
        self.valid_epoch = checkpoint['valid_epoch']

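        # Move any optimizer state tensors onto the training device, since
        # torch.load restores them to the device they were originally saved from.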
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])
        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))