"""Brain Mapper U-Net Solver

Description:

5
    This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
6

7
Usage:
8

9
10
11
    To use this module, import it and instantiate is as you wish:

        from solver import Solver
12
13
14
15
16
"""

import os
import numpy as np
import torch
import glob

from fsl.data.image import Image
from fsl.utils.image.roi import roi
from datetime import datetime
from utils.common_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
from torch.optim import lr_scheduler

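# File extension appended to every checkpoint name; save_checkpoint below writes
# files as 'checkpoint_epoch_<epoch>.' + checkpoint_extension.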
checkpoint_extension = 'path.tar'

class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the pytorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device type used for training (int - GPU id, str- CPU)
        number_of_classes (int): Number of classes
        optimizer (class): Pytorch class of desired optimizer
        optimizer_arguments (dict): Dictionary of arguments to be optimized
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (int): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputing training logs
52
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
53
54

    Returns:
55
        trained model - working on this!
56
57

    """

    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer,
                 optimizer_arguments={},
                 loss_function=torch.nn.MSELoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs',
                 checkpoint_directory='checkpoints',
                 save_model_directory='saved_models',
                 crop_flag=False):

        self.model = model
        self.device = device
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
            self.MSE = torch.nn.MSELoss().cuda(device)
        else:
            self.loss_function = loss_function
            self.MSE = torch.nn.MSELoss()

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)
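        # Example (with the defaults step_size=5, gamma=0.5): an initial LR of 1e-3
        # becomes 5e-4 after epoch 5, 2.5e-4 after epoch 10, and so on.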

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        self.early_stop = False
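        # Load the MNI152 T1 2mm brain mask used to mask the network output; when
        # crop_flag is set, the mask is cropped with the same ROI as the volumes
        # (91x109x91 -> 72x90x77).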
        if crop_flag:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'),
                    ((9, 81), (10, 100), (0, 77))).data)
        else:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)

        self.save_model_directory = save_model_directory
        self.final_model_output_file = experiment_name + ".pth.tar"

        self.best_score_early_stop = None
        self.counter_early_stop = 0
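        # EarlyStopping (project utility) is assumed to signal a stop after `patience`
        # validation epochs without an improvement of at least `min_delta`; on resume,
        # its best score and counter are restored from the loaded checkpoint.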
        if use_last_checkpoint:
            self.load_checkpoint()
            self.EarlyStopping = EarlyStopping(patience=2, min_delta=0,
                                               best_score=self.best_score_early_stop,
                                               counter=self.counter_early_stop)
        else:
            self.EarlyStopping = EarlyStopping(patience=2, min_delta=0)

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            validation_loader (class): Combined dataset and sampler, providing an iterable over the validation dataset (torch.utils.data.DataLoader)

        Returns:
            float: Mean validation loss from the last epoch at which a checkpoint was saved (None if training had already completed)
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        previous_checkpoint = None
        previous_loss = None
        previous_MSE = None
        validation_loss = None  # ensures a defined return value even if no checkpoint is saved

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs+1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []
                MSEs = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)

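                    # The (D, H, W) mask gains two leading singleton dimensions so it
                    # broadcasts over the batch and channel dimensions of y_hat.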
                    MNI152_T1_2mm_brain_mask = torch.unsqueeze(
                        torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
                            self.device, non_blocking=True)

                    y_hat = model(X)  # Forward pass
                    y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)  # Masking

                    loss = self.loss_function(y_hat, y)  # Loss computation
                    # loss = self.loss_function(y_hat+1e-4, y+1e-4, torch.tensor(1.0).cuda(self.device, non_blocking=True))

                    # We also calculate a separate MSE for cost function comparison!
                    MSE = self.MSE(y_hat, y)
                    MSEs.append(MSE.item())

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)
                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory

                    del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask, MSE
                    torch.cuda.empty_cache()

                    if phase == 'validation':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():

                    if phase == 'train':
                        self.LogWriter.loss_per_epoch(losses, phase, epoch)
                        self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
                    elif phase == 'validation':
                        self.LogWriter.loss_per_epoch(
                            losses, phase, epoch, previous_loss=previous_loss)
                        previous_loss = np.mean(losses)
                        self.LogWriter.MSE_per_epoch(
                            MSEs, phase, epoch, previous_loss=previous_MSE)
                        previous_MSE = np.mean(MSEs)
                    if phase == 'validation':
                        early_stop, save_checkpoint, best_score_early_stop, counter_early_stop = self.EarlyStopping(
                            np.mean(losses))
                        self.early_stop = early_stop
                        self.best_score_early_stop = best_score_early_stop
                        self.counter_early_stop = counter_early_stop
                        if save_checkpoint:
                            validation_loss = np.mean(losses)
                            checkpoint_name = os.path.join(
                                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
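                            # epoch and iteration are stored incremented by one, so a
                            # resumed run continues from the next epoch and iteration.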
                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict(),
                                                        'best_score_early_stop': self.best_score_early_stop,
                                                        'counter_early_stop': self.counter_early_stop
                                                        },
                                                 filename=checkpoint_name
                                                 )

                            # Keep only the most recent checkpoint on disk
                            if previous_checkpoint is not None:
                                os.remove(previous_checkpoint)
                            previous_checkpoint = checkpoint_name

                if phase == 'train':
                    learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early Stop Condition

            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                self.load_checkpoint()
                break

        model_output_path = os.path.join(
            self.save_model_directory, self.final_model_output_file)

        create_folder(self.save_model_directory)

        model.save(model_output_path)

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('Final Model Saved in: {}'.format(model_output_path))
        print('****************************************************************')

        if self.start_epoch >= self.number_epochs+1:
            validation_loss = None

        return validation_loss

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training

        Args:
            state (dict): Dictionary of all the relevant model components
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Current epoch value
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # Select the most recently accessed checkpoint file in the directory

            if files_in_universal_path:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
                self._checkpoint_reader(checkpoint_file_path)

            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables

        Args:
            checkpoint_file_path (str): path to checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_score_early_stop = checkpoint['best_score_early_stop']
        self.counter_early_stop = checkpoint['counter_early_stop']

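        # load_state_dict leaves the optimizer state tensors on the device they were
        # saved from; move them to the current training device.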
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])
        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))