"""Brain Mapper U-Net Solver

Description:

5
    This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
6

7
Usage:
8

9
10
11
    To use this module, import it and instantiate is as you wish:

        from solver import Solver
12
13
14
15
16
"""

import os
import glob
from datetime import datetime

import numpy as np
import torch
from torch.optim import lr_scheduler
from fsl.data.image import Image
from fsl.utils.image.roi import roi

from utils.losses import MSELoss
from utils.common_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping

checkpoint_extension = 'path.tar'
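# Checkpoints are written as 'checkpoint_epoch_<N>.path.tar'; the final trained
# model is saved separately as '<experiment_name>.pth.tar'.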

class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the PyTorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (class): BrainMapper model class
        experiment_name (str): Name of the experiment
        device (int/str): Device used for training (int - GPU id, str - CPU)
        number_of_classes (int): Number of classes
        optimizer (class): PyTorch class of the desired optimizer
        optimizer_arguments (dict): Dictionary of additional arguments passed to the optimizer (e.g. learning rate, weight decay)
        loss_function (func): Function describing the desired loss function
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period (in iterations) for writing the loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (float): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputting training logs
        checkpoint_directory (str): Directory name for saving training checkpoints
        save_model_directory (str): Directory name for saving the final trained model
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed up training

    Returns:
        A Solver object, ready for training via the train() method

    """

    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer,
                 optimizer_arguments={},
                 loss_function=MSELoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs',
                 checkpoint_directory='checkpoints',
                 save_model_directory='saved_models',
                 crop_flag=False
                 ):

        self.model = model
        self.device = device
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
            self.MSE = torch.nn.MSELoss().cuda(device)
        else:
            self.loss_function = loss_function
            self.MSE = torch.nn.MSELoss()

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # Learning rate scheduler that decays the learning rate of each
        # parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)
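        # For example, with the default step_size=5 and gamma=0.5, an initial
        # learning rate of 1e-3 becomes 5e-4 after epoch 5 and 2.5e-4 after epoch 10.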

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
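        # EarlyStopping is queried with the mean validation loss after each epoch;
        # it returns an early-stop flag and a flag indicating whether the current
        # model should be checkpointed (see train() below).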
        self.early_stop = False

        if not crop_flag:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
        else:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'), ((9, 81), (10, 100), (0, 77))).data)
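        # The mask is later multiplied element-wise with the network output to
        # zero out predictions outside the brain; when crop_flag is set, it is
        # cropped to the same 72x90x77 region of interest as the input volumes.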

        self.save_model_directory = save_model_directory
        self.final_model_output_file = experiment_name + ".pth.tar"

        if use_last_checkpoint:
            self.load_checkpoint()

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains the model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            validation_loader (class): Combined dataset and sampler, providing an iterable over the validation dataset (torch.utils.data.DataLoader)

        Returns:
            float: Mean validation loss recorded at the last saved checkpoint (None if no new epochs were trained)

        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        previous_checkpoint = None
        previous_loss = None
        previous_MSE = None
        validation_loss = None  # Returned value; updated whenever a checkpoint is saved

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs+1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []
                MSEs = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)

                    MNI152_T1_2mm_brain_mask = torch.unsqueeze(
                        torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)
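                    # After the two unsqueeze calls the mask has shape (1, 1, D, H, W),
                    # so it broadcasts over the batch and channel dimensions of y_hat.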

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
                            self.device, non_blocking=True)

                    y_hat = model(X)   # Forward pass & Masking

                    y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)

                    loss = self.loss_function(y_hat, y)  # Loss computation
                    # loss = self.loss_function(y_hat+1e-4, y+1e-4, torch.tensor(1.0).cuda(self.device, non_blocking=True))

                    # We also calculate a separate MSE for cost function comparison!
                    MSE = self.MSE(y_hat, y)
                    MSEs.append(MSE.item())

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory
                    del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask, MSE
                    torch.cuda.empty_cache()

                    if phase == 'validation':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():
                    if phase == 'train':
                        self.LogWriter.loss_per_epoch(losses, phase, epoch)
                        self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
                    elif phase == 'validation':
                        self.LogWriter.loss_per_epoch(
                            losses, phase, epoch, previous_loss=previous_loss)
                        previous_loss = np.mean(losses)
                        self.LogWriter.MSE_per_epoch(
                            MSEs, phase, epoch, previous_loss=previous_MSE)
                        previous_MSE = np.mean(MSEs)

                    if phase == 'validation':
                        early_stop, save_checkpoint = self.EarlyStopping(
                            np.mean(losses))
                        self.early_stop = early_stop
                        if save_checkpoint:
                            validation_loss = np.mean(losses)
                            checkpoint_name = os.path.join(
                                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict()
                                                        },
                                                 filename=checkpoint_name
                                                 )

                            if previous_checkpoint is not None:
                                os.remove(previous_checkpoint)
                            previous_checkpoint = checkpoint_name

                if phase == 'train':
                    learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early Stop Condition
            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                self.load_checkpoint()
                break

        model_output_path = os.path.join(
            self.save_model_directory, self.final_model_output_file)

        create_folder(self.save_model_directory)

        model.save(model_output_path)

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('Final Model Saved in: {}'.format(model_output_path))
        print('****************************************************************')

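        # If training resumed at or beyond the final epoch, the loop above did not
        # run, so there is no new validation loss to report.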
        if self.start_epoch >= self.number_epochs+1:
            validation_loss = None

        return validation_loss

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training.

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Path of the checkpoint file
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training.

        Args:
            epoch (int): Epoch number of the checkpoint to load; if None, the most recent checkpoint is loaded
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # Search the checkpoint directory for the most recently used checkpoint file

            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
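                # os.path.getatime returns the last access time, so the most
                # recently used checkpoint file is selected.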
                self._checkpoint_reader(checkpoint_file_path)

            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and loads the relevant variables.

        Args:
            checkpoint_file_path (str): Path to the checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
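
        # Optimizer state tensors (e.g. momentum buffers) are restored on the
        # device they were saved from, so move them to the current device before
        # resuming training.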
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])
        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))