"""Brain Mapper U-Net Solver

Description:

    This module contains the PyTorch implementation of the core U-net solver used for training the network.

Usage:

    To use this module, import it and instantiate it as you wish:

        from solver import Solver
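
    A minimal usage sketch (the model object, the CUDA device id and the two
    torch.utils.data.DataLoader instances are assumptions, created elsewhere):

        solver = Solver(model=model,
                        device=0,
                        number_of_classes=1,
                        experiment_name='example_experiment',
                        optimizer=torch.optim.Adam,
                        optimizer_arguments={'lr': 1e-4})
        validation_loss = solver.train(train_loader, validation_loader)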
"""

import os
import numpy as np
import torch
import glob

from fsl.data.image import Image
from datetime import datetime
from utils.losses import MSELoss
from utils.common_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
from torch.optim import lr_scheduler

checkpoint_extension = 'path.tar'


class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the PyTorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device used for training (int - GPU id, str - CPU)
        number_of_classes (int): Number of classes
        optimizer (class): PyTorch class of the desired optimizer
        optimizer_arguments (dict): Dictionary of additional arguments passed to the optimizer
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (float): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputting training logs
        checkpoint_directory (str): Directory name for saving training checkpoints
        save_model_directory (str): Directory name for saving the final trained model
        final_model_output_file (str): File name of the final trained model

    Returns:
        trained model

    """

    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer,
                 optimizer_arguments={},
                 loss_function=MSELoss(),
                #  loss_function=torch.nn.L1Loss(),
                #  loss_function=torch.nn.CosineEmbeddingLoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs',
                 checkpoint_directory='checkpoints',
                 save_model_directory='saved_models',
                 final_model_output_file='finetuned_alldata.pth.tar'
                 ):

        self.model = model
        self.device = device
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
            self.MSE = MSELoss().cuda(device)
        else:
            self.loss_function = loss_function
            self.MSE = MSELoss()

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
        self.early_stop = False

        self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
            Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)

        self.save_model_directory = save_model_directory
        self.final_model_output_file = final_model_output_file

        if use_last_checkpoint:
            self.load_checkpoint()

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            validation_loader (class): Combined dataset and sampler, providing an iterable over the validation dataset (torch.utils.data.DataLoader)

        Returns:
            float: validation loss recorded when the last checkpoint was saved (None if no checkpoint was saved)
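
        Example (a minimal sketch; both loaders are assumed to be
        torch.utils.data.DataLoader instances yielding (input, target) batches):

            validation_loss = solver.train(train_loader, validation_loader)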
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        previous_checkpoint = None
        previous_loss = None
        previous_MSE = None
        validation_loss = None  # ensures a defined return value even if no checkpoint is saved

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs+1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []
                MSEs = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)

                    MNI152_T1_2mm_brain_mask = torch.unsqueeze(
                        torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
                            self.device, non_blocking=True)

                    y_hat = model(X)   # Forward pass & Masking

                    y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)

                    loss = self.loss_function(y_hat, y)  # Loss computation
                    # loss = self.loss_function(y_hat+1e-4, y+1e-4, torch.tensor(1.0).cuda(self.device, non_blocking=True))

                    # We also calculate a separate MSE for cost function comparison!
                    MSE = self.MSE(y_hat, y)
                    MSEs.append(MSE.item())

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:

                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory

                    del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask, MSE
                    torch.cuda.empty_cache()

                    if phase == 'validation':

                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():

                    if phase == 'train':
                        self.LogWriter.loss_per_epoch(losses, phase, epoch)
                        self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
                    elif phase == 'validation':
                        self.LogWriter.loss_per_epoch(
                            losses, phase, epoch, previous_loss=previous_loss)
                        previous_loss = np.mean(losses)
                        self.LogWriter.MSE_per_epoch(
                            MSEs, phase, epoch, previous_loss=previous_MSE)
                        previous_MSE = np.mean(MSEs)

                    if phase == 'validation':
                        early_stop, save_checkpoint = self.EarlyStopping(
                            np.mean(losses))
                        self.early_stop = early_stop
                        if save_checkpoint:
                            validation_loss = np.mean(losses)
                            checkpoint_name = os.path.join(
                                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict()
                                                        },
                                                 filename=checkpoint_name
                                                 )

                            # Keep only the most recent checkpoint on disk
                            if previous_checkpoint is not None:
                                os.remove(previous_checkpoint)
                            previous_checkpoint = checkpoint_name

                if phase == 'train':
                    learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early Stop Condition

            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                self.load_checkpoint()
                break

        model_output_path = os.path.join(
            self.save_model_directory, self.final_model_output_file)

        create_folder(self.save_model_directory)

        model.save(model_output_path)

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('Final Model Saved in: {}'.format(model_output_path))
        print('****************************************************************')

        if self.start_epoch >= self.number_epochs+1:
            validation_loss = None

        return validation_loss

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Path of the output checkpoint file
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Epoch number of the checkpoint to load (if None, the most recent checkpoint is loaded)
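
        Example (a minimal sketch, assuming checkpoints already exist in the
        experiment's checkpoint directory):

            solver.load_checkpoint()         # resume from the most recent checkpoint
            solver.load_checkpoint(epoch=5)  # load the checkpoint saved at epoch 5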
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # We will sort through all the files in path to see which one is most recent

            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
                self._checkpoint_reader(checkpoint_file_path)

            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables

        Args:
            checkpoint_file_path (str): path to checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])
        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))