"""Brain Mapper U-Net Solver

Description:

5
    This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
6

7
Usage:
8

9
10
11
    To use this module, import it and instantiate is as you wish:

        from solver import Solver
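
    A minimal sketch of how the solver might be driven (the model, loader and
    argument values below are illustrative placeholders, not fixed names):

        solver = Solver(model=model,
                        device=0,
                        number_of_classes=1,
                        experiment_name='test_run',
                        optimizer=torch.optim.Adam)
        validation_loss = solver.train(train_loader, validation_loader)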
"""

import os
import numpy as np
import torch
import glob

from fsl.data.image import Image
from fsl.utils.image.roi import roi
from datetime import datetime
from utils.common_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
from torch.optim import lr_scheduler

checkpoint_extension = 'path.tar'
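# Checkpoint files are named 'checkpoint_epoch_<N>.' + checkpoint_extension and are
# stored in the experiment's checkpoint directory (see save_checkpoint / load_checkpoint).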


class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the PyTorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device used for training (int: GPU id, str: CPU)
        number_of_classes (int): Number of classes
        optimizer (class): Pytorch class of desired optimizer
        optimizer_arguments (dict): Dictionary of additional arguments to be passed to the optimizer
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (float): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputting training logs
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed up training

    Returns:
        trained model

    """

    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer,
                 optimizer_arguments={},
                 loss_function=torch.nn.MSELoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs',
                 checkpoint_directory='checkpoints',
                 save_model_directory='saved_models',
                 crop_flag=False
                 ):

        self.model = model
        self.device = device
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
            self.MSE = torch.nn.MSELoss().cuda(device)
        else:
            self.loss_function = loss_function
            self.MSE = torch.nn.MSELoss()

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)
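        # With the default step_size=5 and gamma=0.5, an initial learning rate of e.g.
        # 1e-3 becomes 5e-4 after epoch 5, 2.5e-4 after epoch 10, and so on.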

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

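        # Ask for an early stop once the validation loss has not improved for `patience`
        # consecutive epochs (min_delta=0: any decrease counts as an improvement);
        # see utils.early_stopping for the exact behaviour.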
        self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
        self.early_stop = False

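        # The MNI152 T1 2mm brain mask is later used to zero out non-brain voxels in the
        # network output. If crop_flag is set, the mask is cropped to the same 72x90x77
        # region used for the (cropped) input volumes.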
        if crop_flag:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'),
                    ((9, 81), (10, 100), (0, 77))).data)
        else:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)

        self.save_model_directory = save_model_directory
        self.final_model_output_file = experiment_name + ".pth.tar"

        if use_last_checkpoint:
            self.load_checkpoint()

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            validation_loader (class): Combined dataset and sampler, providing an iterable over the validation dataset (torch.utils.data.DataLoader)

        Returns:
            float: Mean validation loss corresponding to the last saved checkpoint (None if no training epochs were run)
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        previous_checkpoint = None
        previous_loss = None
        previous_MSE = None
        validation_loss = None  # Defensive default, in case no checkpoint is ever saved

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs+1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []
                MSEs = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra channel dimension (N, C, D, H, W), as required by the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)

                    MNI152_T1_2mm_brain_mask = torch.unsqueeze(
                        torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)
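                    # The mask now has shape (1, 1, D, H, W), so it broadcasts over the batch
                    # and channel dimensions when multiplied with the prediction below.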

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
                            self.device, non_blocking=True)

                    y_hat = model(X)  # Forward pass

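                    # Zero out predicted values that fall outside the brain mask before computing the loss.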
                    y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)

                    loss = self.loss_function(y_hat, y)  # Loss computation
                    # loss = self.loss_function(y_hat+1e-4, y+1e-4, torch.tensor(1.0).cuda(self.device, non_blocking=True))

                    # We also calculate a separate MSE for cost function comparison!
                    MSE = self.MSE(y_hat, y)
                    MSEs.append(MSE.item())

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:

                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory

                    del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask, MSE
                    torch.cuda.empty_cache()

                    if phase == 'validation':

                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():

                    if phase == 'train':
                        self.LogWriter.loss_per_epoch(losses, phase, epoch)
                        self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
                    elif phase == 'validation':
                        self.LogWriter.loss_per_epoch(
                            losses, phase, epoch, previous_loss=previous_loss)
                        previous_loss = np.mean(losses)
                        self.LogWriter.MSE_per_epoch(
                            MSEs, phase, epoch, previous_loss=previous_MSE)
                        previous_MSE = np.mean(MSEs)

                    if phase == 'validation':
                        early_stop, save_checkpoint = self.EarlyStopping(
                            np.mean(losses))
                        self.early_stop = early_stop
                        if save_checkpoint:
                            validation_loss = np.mean(losses)
                            checkpoint_name = os.path.join(
                                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict()
                                                        },
                                                 filename=checkpoint_name
                                                 )

                            # Keep only the most recent checkpoint on disk.
                            if previous_checkpoint is not None:
                                os.remove(previous_checkpoint)
                            previous_checkpoint = checkpoint_name

                if phase == 'train':
                    learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early Stop Condition

            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                self.load_checkpoint()
                break

        model_output_path = os.path.join(
            self.save_model_directory, self.final_model_output_file)

        create_folder(self.save_model_directory)

        model.save(model_output_path)

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('Final Model Saved in: {}'.format(model_output_path))
        print('****************************************************************')

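        # If training resumed from a checkpoint that had already completed all epochs,
        # the epoch loop above never ran, so there is no validation loss to report.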
        if self.start_epoch >= self.number_epochs+1:
            validation_loss = None

        return validation_loss

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training.

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Path of the output checkpoint file
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training.

        Args:
            epoch (int): Epoch whose checkpoint should be loaded; if None, the most recent checkpoint is used
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # We look through all the checkpoint files in the path and treat the most recently accessed one as the latest

            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
                self._checkpoint_reader(checkpoint_file_path)

            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables.

        Args:
            checkpoint_file_path (str): path to checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

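        # Make sure the optimizer's internal state tensors (e.g. momentum buffers)
        # end up on the training device after loading the checkpoint.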
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])
        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))