"""Brain Mapper U-Net Solver

Description:
-------------
This folder contains the Pytorch implementation of the core U-net solver, used for training the network.

Usage
-------------
To use this module, import it and instantiate is as you wish:

    from solver import Solver
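
A minimal instantiation might look like the following (my_model is a
placeholder for a BrainMapper network instance, and the argument values
shown are illustrative; the remaining arguments keep their defaults):

    solver = Solver(model=my_model,
                    device=0,
                    number_of_classes=1,
                    experiment_name='test_run')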

"""

import os
import numpy as np
import torch
from utils.losses import MSELoss
from utils.data_utils import create_folder
from torch.optim import lr_scheduler

checkpoint_directory = 'checkpoints'
checkpoint_extension = 'path.tar'

class Solver():
    """Solver class for the BrainMapper U-net.

    This class contains the PyTorch implementation of the U-net solver required for the BrainMapper project.

    Args:
        model (torch.nn.Module): BrainMapper model to be trained
        device (int/str): Device used for training (int - GPU id, str - CPU)
        number_of_classes (int): Number of classes
        experiment_name (str): Name of the experiment
        optimizer (class): PyTorch optimizer class (default: torch.optim.Adam)
        optimizer_arguments (dict): Dictionary of keyword arguments passed to the optimizer
        loss_function (torch.nn.Module): Loss function used for training (default: MSELoss())
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (float): Multiplicative factor of learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputting training logs

    
    Returns:
        A trained model (via the train method, which is currently a stub)

    Raises:
        None
    """
    
    def __init__(self,
                model,
                device,
                number_of_classes,
                experiment_name,
                optimizer = torch.optim.Adam,
                optimizer_arguments = {},
                loss_function = MSELoss(),
                model_name = 'BrainMapper',
                labels = None,
                number_epochs = 10,
                loss_log_period = 5,
                learning_rate_scheduler_step_size = 5,
                learning_rate_scheduler_gamma = 0.5,
                use_last_checkpoint = True,
                experiment_directory = 'experiments',
                logs_directory = 'logs'
                ):

        self.model = model
        self.device = device
        self.number_of_classes = number_of_classes
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)

        # Keep the loss function on the GPU when one is available.
        if torch.cuda.is_available():
            self.loss_function = loss_function.cuda(device)
        else:
            self.loss_function = loss_function

        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period
        
        # We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size = learning_rate_scheduler_step_size,
                                                           gamma = learning_rate_scheduler_gamma)
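        # Illustrative example (values are assumptions, not from the project):
        # with an initial LR of 1e-3, step_size = 5 and gamma = 0.5, epochs 1-5
        # run at 1e-3, epochs 6-10 at 5e-4, and so on.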
        
        self.use_last_checkpoint = use_last_checkpoint
        self.logs_directory = logs_directory

        experiment_directory_path = os.path.join(experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path
        
        create_folder(experiment_directory_path)
        create_folder(os.path.join(experiment_directory_path, checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1
        self.best_mean_score = 0
        self.best_mean_epoch = 0

        if use_last_checkpoint:
            self.load_checkpoint()

    def train(self):
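        # Sketch of the intended loop (an assumption, not the final
        # implementation): iterate over epochs and batches, back-propagate the
        # loss, then step the optimizer and the LR scheduler, e.g.:
        #
        #   for epoch in range(self.start_epoch, self.number_epochs + 1):
        #       for X, y in train_loader:  # train_loader is hypothetical here
        #           self.optimizer.zero_grad()
        #           loss = self.loss_function(self.model(X), y)
        #           loss.backward()
        #           self.optimizer.step()
        #       self.learning_rate_scheduler.step()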
        pass

    def save_model(self):
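        # Sketch (assumption): persisting the trained network is typically a
        # single call, e.g. torch.save(self.model, model_output_path), where
        # model_output_path is a hypothetical destination path.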
        pass

    def save_checkpoint(self):
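        # Sketch (assumption): a resumable checkpoint usually bundles the
        # training state, e.g.:
        #
        #   torch.save({'epoch': epoch,
        #               'state_dict': self.model.state_dict(),
        #               'optimizer': self.optimizer.state_dict(),
        #               'scheduler': self.learning_rate_scheduler.state_dict()},
        #              checkpoint_path)  # checkpoint_path is hypothetical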
        pass

    def load_checkpoint(self):
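        # Sketch (assumption): locate the most recent '*.path.tar' file under
        # <experiment_directory_path>/checkpoints and delegate the actual
        # restore to _load_checkpoint_file().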
        pass

    def _load_checkpoint_file(self):
        # The leading underscore marks this method as private by convention;
        # it is intended to be called only from within this module.
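        # Sketch (assumption): the mirror image of save_checkpoint(), e.g.:
        #
        #   checkpoint = torch.load(file_path)  # file_path is hypothetical
        #   self.start_epoch = checkpoint['epoch']
        #   self.model.load_state_dict(checkpoint['state_dict'])
        #   self.optimizer.load_state_dict(checkpoint['optimizer'])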
        pass