"""Brain Mapper U-Net Solver

Description:

5
    This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
6

7
Usage:
8

9
10
11
    To use this module, import it and instantiate is as you wish:

        from solver import Solver
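
    A minimal instantiation sketch (the model object, the loaders, the optimizer
    choice and the argument values below are illustrative assumptions, not
    requirements):

        import torch
        from solver import Solver

        solver = Solver(model=model,                # a BrainMapper network instance
                        device=0,                   # GPU id, or 'cpu'
                        number_of_classes=1,
                        experiment_name='test_experiment',
                        optimizer=torch.optim.Adam,
                        optimizer_arguments={'lr': 1e-3})
        solver.train(train_loader, validation_loader)
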
"""

import os
import glob
from datetime import datetime

import numpy as np
import torch
from torch.optim import lr_scheduler

from fsl.data.image import Image
from fsl.utils.image.roi import roi

from utils.common_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping

checkpoint_extension = 'path.tar'

class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the PyTorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device used for training (int - GPU id, str - CPU)
        number_of_classes (int): Number of classes
        optimizer (class): PyTorch class of the desired optimizer
        optimizer_arguments (dict): Dictionary of keyword arguments passed to the optimizer
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing the loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (float): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputting training logs
        checkpoint_directory (str): Directory for saving training checkpoints
        save_model_directory (str): Directory for saving the final trained model
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed up training

    Returns:
        trained model
    """
    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer,
                 optimizer_arguments={},
                 loss_function=torch.nn.MSELoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs',
                 checkpoint_directory='checkpoints',
                 save_model_directory='saved_models',
                 crop_flag=False
                 ):

        self.model = model
        self.device = device
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
            self.MSE = torch.nn.MSELoss().cuda(device)
        else:
            self.loss_function = loss_function
            self.MSE = torch.nn.MSELoss()

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # We use a learning rate scheduler, which decays the LR of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)
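        # Illustrative decay schedule (the initial LR is an assumed example):
        # with an initial LR of 1e-3 and the default step_size=5, gamma=0.5,
        # epochs 1-5 train at 1e-3, epochs 6-10 at 5e-4, epochs 11-15 at
        # 2.5e-4, and so on.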

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)
        self.early_stop = False

        if crop_flag:
            # Crop the mask from 91x109x91 to 72x90x77, matching the cropped training volumes.
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'), ((9, 81), (10, 100), (0, 77))).data)
        else:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)

        self.save_model_directory = save_model_directory
        self.final_model_output_file = experiment_name + ".pth.tar"

        self.best_score_early_stop = None
        self.counter_early_stop = 0
        self.previous_loss = None
        self.previous_MSE = None
        self.valid_epoch = None

        if use_last_checkpoint:
            self.load_checkpoint()
            self.EarlyStopping = EarlyStopping(patience=2, min_delta=0, best_score=self.best_score_early_stop, counter=self.counter_early_stop)
        else:
            self.EarlyStopping = EarlyStopping(patience=2, min_delta=0)
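
        # Note: assuming the conventional semantics of utils.early_stopping,
        # patience=2 stops training after two consecutive validation epochs
        # in which the loss fails to improve by more than min_delta.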
    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            validation_loader (class): Combined dataset and sampler, providing an iterable over the validation dataset (torch.utils.data.DataLoader)

        Returns:
            trained model
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs+1):

            if self.early_stop:
                print("ATTENTION!: Training stopped due to previous early stop flag!")
                break

            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []
                MSEs = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)
                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)
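                    # Shape note (assuming the loaders yield (batch, X, Y, Z)
                    # volumes): X and y are now (batch, 1, X, Y, Z).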

                    MNI152_T1_2mm_brain_mask = torch.unsqueeze(
                        torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)
                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
                            self.device, non_blocking=True)
                    y_hat = model(X)   # Forward pass & Masking
                    y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)
                    loss = self.loss_function(y_hat, y)  # Loss computation
                    # loss = self.loss_function(y_hat+1e-4, y+1e-4, torch.tensor(1.0).cuda(self.device, non_blocking=True))

                    # We also calculate a separate MSE for cost function comparison!
                    MSE = self.MSE(y_hat, y)
                    MSEs.append(MSE.item())

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:

                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)
                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory
                    del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask, MSE
                    torch.cuda.empty_cache()

                    if phase == 'validation':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():

                    if phase == 'train':
                        self.LogWriter.loss_per_epoch(losses, phase, epoch)
                        self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
                    elif phase == 'validation':
                        self.LogWriter.loss_per_epoch(
                            losses, phase, epoch, previous_loss=self.previous_loss)
                        self.previous_loss = np.mean(losses)
                        self.LogWriter.MSE_per_epoch(
                            MSEs, phase, epoch, previous_loss=self.previous_MSE)
                        self.previous_MSE = np.mean(MSEs)
                    if phase == 'validation':
                        early_stop, best_score_early_stop, counter_early_stop = self.EarlyStopping(np.mean(losses))
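                        # The early-stopping state is kept on the solver so it
                        # can be saved in the checkpoint and restored on resume.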

                        self.early_stop = early_stop
                        self.best_score_early_stop = best_score_early_stop
                        self.counter_early_stop = counter_early_stop
                        checkpoint_name = os.path.join(
                            self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)

                        if self.counter_early_stop == 0:
                            self.valid_epoch = epoch

                        self.save_checkpoint(state={'epoch': epoch + 1,
                                                    'start_iteration': iteration + 1,
                                                    'arch': self.model_name,
                                                    'state_dict': model.state_dict(),
                                                    'optimizer': optimizer.state_dict(),
                                                    'scheduler': learning_rate_scheduler.state_dict(),
                                                    'best_score_early_stop': self.best_score_early_stop,
                                                    'counter_early_stop': self.counter_early_stop,
                                                    'previous_loss': self.previous_loss,
                                                    'previous_MSE': self.previous_MSE,
                                                    'early_stop': self.early_stop
                                                    },
                                             filename=checkpoint_name
                                             )
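                        # The checkpoint stores everything needed to resume:
                        # model/optimizer/scheduler state plus the early-stopping
                        # and loss-tracking variables restored by load_checkpoint().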
                if phase == 'train':
                    learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early Stop Condition

            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                self.load_checkpoint(epoch=self.valid_epoch)
                break

        model_output_path = os.path.join(
            self.save_model_directory, self.final_model_output_file)

        create_folder(self.save_model_directory)

        model.save(model_output_path)

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('Final Model Saved in: {}'.format(model_output_path))
        print('****************************************************************')

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Path of the file where the checkpoint will be saved
        """

        torch.save(state, filename)
    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Epoch number of the checkpoint to load; if None, the most recently accessed checkpoint is used
        """
        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # Select the most recently accessed checkpoint file in the directory

            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
                self._checkpoint_reader(checkpoint_file_path)

            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))
    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables

        Args:
            checkpoint_file_path (str): path to checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_score_early_stop = checkpoint['best_score_early_stop']
        self.counter_early_stop = checkpoint['counter_early_stop']
        self.previous_loss = checkpoint['previous_loss']
        self.previous_MSE = checkpoint['previous_MSE']
        self.early_stop = checkpoint['early_stop']

        # torch.load restores tensors to the device they were saved from, so we
        # move any tensors in the optimizer state onto the current training device.
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])
        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))