"""Brain Mapper U-Net Solver

Description:

    This module contains the PyTorch implementation of the core U-net solver, used for training the network.

Usage:

    To use this module, import it and instantiate it as you wish:

        from solver import Solver
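
    A minimal instantiation sketch (the model object, experiment name, and
    optimizer below are placeholders, not fixed by this module; any
    BrainMapper model and torch optimizer can be passed in):

        solver = Solver(model=model,
                        device=0,
                        number_of_classes=1,
                        experiment_name='test_run',
                        optimizer=torch.optim.Adam)
        solver.train(train_loader, validation_loader)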
"""

import os
import glob
from datetime import datetime

import numpy as np
import torch
from torch.optim import lr_scheduler

from fsl.data.image import Image
from fsl.utils.image.roi import roi

from utils.common_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping

checkpoint_extension = 'path.tar'  # Extension appended to saved checkpoint file names


class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the pytorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device used for training (int - GPU id, str - CPU)
        number_of_classes (int): Number of classes
        optimizer (class): Pytorch class of desired optimizer
        optimizer_arguments (dict): Dictionary of keyword arguments passed to the optimizer
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (float): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputing training logs
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed up training

    Returns:
        trained model

    """

    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer,
                 optimizer_arguments={},
                 loss_function=torch.nn.MSELoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs',
                 checkpoint_directory='checkpoints',
                 save_model_directory='saved_models',
                 crop_flag=False
                 ):

        self.model = model
        self.device = device
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
            self.MSE = torch.nn.MSELoss().cuda(device)
        else:
            self.loss_function = loss_function
            self.MSE = torch.nn.MSELoss()

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # Learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        self.early_stop = False

        # Load the MNI152 T1 2mm brain mask, cropped to match the input volumes when crop_flag is set.
        if not crop_flag:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
        else:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'),
                    ((9, 81), (10, 100), (0, 77))).data)

        self.save_model_directory = save_model_directory
        self.final_model_output_file = experiment_name + ".pth.tar"

        self.best_score_early_stop = None
        self.counter_early_stop = 0

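        # When resuming, restore the early-stopping state recovered from the last checkpoint.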
        if use_last_checkpoint:
            self.load_checkpoint()
            self.EarlyStopping = EarlyStopping(patience=2, min_delta=0, best_score=self.best_score_early_stop, counter=self.counter_early_stop)
        else:
            self.EarlyStopping = EarlyStopping(patience=2, min_delta=0)
            self.previous_checkpoint = None
            self.previous_loss = None
            self.previous_MSE = None

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            validation_loader (class): Combined dataset and sampler, providing an iterable over the validation dataset (torch.utils.data.DataLoader)

        Returns:
            validation loss at the last saved checkpoint (None if no checkpoint was saved)
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration
        validation_loss = None  # Guards the return value if no checkpoint is ever saved

        for epoch in range(self.start_epoch, self.number_epochs+1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []
                MSEs = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

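                # Loop over mini-batches supplied by the current phase's loader.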
                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)

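                    # Broadcast the brain mask to shape (1, 1, D, H, W) so it matches the batched volumes.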
                    MNI152_T1_2mm_brain_mask = torch.unsqueeze(
                        torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
                            self.device, non_blocking=True)

                    y_hat = model(X)  # Forward pass

                    y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)  # Mask out non-brain voxels

                    loss = self.loss_function(y_hat, y)  # Loss computation
                    # loss = self.loss_function(y_hat+1e-4, y+1e-4, torch.tensor(1.0).cuda(self.device, non_blocking=True))

                    # We also calculate a separate MSE for cost function comparison!
                    MSE = self.MSE(y_hat, y)
                    MSEs.append(MSE.item())

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:

                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory

                    del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask, MSE
                    torch.cuda.empty_cache()

                    if phase == 'validation':

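                        # Minimal textual progress indicator for the validation phase.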
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

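                # Epoch-level logging, early stopping and checkpointing need no gradients.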
                with torch.no_grad():

                    if phase == 'train':
                        self.LogWriter.loss_per_epoch(losses, phase, epoch)
                        self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
                    elif phase == 'validation':
                        self.LogWriter.loss_per_epoch(
                            losses, phase, epoch, previous_loss=self.previous_loss)
                        self.previous_loss = np.mean(losses)
                        self.LogWriter.MSE_per_epoch(
                            MSEs, phase, epoch, previous_loss=self.previous_MSE)
                        self.previous_MSE = np.mean(MSEs)

                    if phase == 'validation':
                        early_stop, save_checkpoint, best_score_early_stop, counter_early_stop = self.EarlyStopping(
                            np.mean(losses))
                        self.early_stop = early_stop
                        self.best_score_early_stop = best_score_early_stop
                        self.counter_early_stop = counter_early_stop
                        if save_checkpoint:
                            validation_loss = np.mean(losses)
                            checkpoint_name = os.path.join(
                                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
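                            # Persist everything needed to resume training from this exact point.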
                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict(),
                                                        'best_score_early_stop': self.best_score_early_stop,
                                                        'counter_early_stop': self.counter_early_stop,
                                                        'previous_checkpoint': self.previous_checkpoint,
                                                        'previous_loss': self.previous_loss,
                                                        'previous_MSE': self.previous_MSE,
                                                        },
                                                 filename=checkpoint_name
                                                 )

                            # Keep only the most recent checkpoint on disk.
                            if self.previous_checkpoint is not None:
                                os.remove(self.previous_checkpoint)
                            self.previous_checkpoint = checkpoint_name

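                # Decay the learning rate once per training epoch.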
                if phase == 'train':
                    learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early stop condition: reload the last (best) checkpoint and halt training.
            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                self.load_checkpoint()
                break

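        # Save the final model weights (the reloaded checkpoint weights if training stopped early).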
        model_output_path = os.path.join(
            self.save_model_directory, self.final_model_output_file)

        create_folder(self.save_model_directory)

        model.save(model_output_path)

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('Final Model Saved in: {}'.format(model_output_path))
        print('****************************************************************')

        if self.start_epoch >= self.number_epochs+1:
            validation_loss = None

        return validation_loss

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Path to the output checkpoint file
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Epoch of the checkpoint to load; if None, the most recent checkpoint is loaded
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # Pick the most recently accessed checkpoint file in the directory.

            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
                self._checkpoint_reader(checkpoint_file_path)

            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables

        Args:
            checkpoint_file_path (str): path to checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

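        # Restore the training state stored in the checkpoint dictionary.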
        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_score_early_stop = checkpoint['best_score_early_stop']
        self.counter_early_stop = checkpoint['counter_early_stop']
        self.previous_checkpoint = checkpoint['previous_checkpoint']
        self.previous_loss = checkpoint['previous_loss']
        self.previous_MSE = checkpoint['previous_MSE']

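        # Move the optimizer state tensors onto the training device.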
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])
        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))