# solver.py
"""Brain Mapper U-Net Solver

Description:
-------------
This module contains the PyTorch implementation of the core U-Net solver used to train the network.

Usage
-------------
To use this module, import it and instantiate the solver as you wish:

    from solver import Solver

"""

import os
import numpy as np
import torch

# Name of the checkpoint directory (not yet referenced elsewhere in this file).
checkpoint_directory: str = 'checkpoints'
# File extension for checkpoint files.
# NOTE(review): 'path.tar' may be a typo for the common PyTorch 'pth.tar'
# convention — confirm before relying on it.
checkpoint_extension: str = 'path.tar'

class Solver():
    """Solver class for the BrainMapper U-Net.

    This class contains the PyTorch implementation of the U-Net solver required
    for the BrainMapper project. At present it only records its configuration
    on the instance; the training and checkpointing methods are unimplemented
    stubs that raise ``NotImplementedError``.

    Args:
        model (class): BrainMapper model class
        device (int/str): Device used for training (int - GPU id, str - CPU)
        number_of_classes (int): Number of classes
        experiment_name (str): Name of the experiment
        optimizer (class): PyTorch optimizer class (stored, not instantiated);
            defaults to ``torch.optim.Adam``
        optimizer_arguments (dict): Keyword arguments intended for the
            optimizer constructor. ``None`` (the default) means no extra
            arguments. A copy is stored, so later changes to the caller's
            dict are not seen by the solver.
        loss_function (func): Function describing the desired loss function.
            Defaults to ``None`` — TODO: no default loss has been chosen yet.
        model_name (str): Name of the model
        labels (arr): Vector/Array of labels (if applicable)
        number_epochs (int): Number of training epochs
        loss_log_period (int): Period for writing the loss value
        learning_rate_scheduler_step_size (int): Period of learning rate decay
        learning_rate_scheduler_gamma (float): Multiplicative factor of
            learning rate decay
        use_last_checkpoint (bool): Flag for loading the previous checkpoint
        experiment_directory (str): Experiment output directory name
        logs_directory (str): Directory for outputting training logs

    Raises:
        NotImplementedError: from ``train``, ``save_model``,
            ``save_checkpoint``, ``load_checkpoint`` and
            ``_load_checkpoint_file`` until they are implemented.
    """

    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer=torch.optim.Adam,
                 optimizer_arguments=None,
                 loss_function=None,  # TODO: choose a default loss function
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs'
                 ):

        self.model = model
        self.device = device
        self.number_of_classes = number_of_classes
        self.experiment_name = experiment_name

        # Optimizer class only; it is not instantiated here.
        self.optimizer = optimizer
        # Default is None rather than `{}`: a literal default dict is created
        # once and shared by every Solver. Copying also isolates the solver
        # from later mutation of the caller's dict.
        self.optimizer_arguments = (
            dict(optimizer_arguments) if optimizer_arguments is not None else {}
        )

        self.loss_function = loss_function
        self.model_name = model_name
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period
        self.learning_rate_scheduler_step_size = learning_rate_scheduler_step_size
        self.learning_rate_scheduler_gamma = learning_rate_scheduler_gamma
        self.use_last_checkpoint = use_last_checkpoint
        self.experiment_directory = experiment_directory
        self.logs_directory = logs_directory

    def train(self):
        """Train the model. Not implemented yet."""
        raise NotImplementedError("Solver.train is not implemented yet")

    def save_model(self):
        """Save the model. Not implemented yet."""
        raise NotImplementedError("Solver.save_model is not implemented yet")

    def save_checkpoint(self):
        """Save a training checkpoint. Not implemented yet."""
        raise NotImplementedError("Solver.save_checkpoint is not implemented yet")

    def load_checkpoint(self):
        """Load a training checkpoint. Not implemented yet."""
        raise NotImplementedError("Solver.load_checkpoint is not implemented yet")

    def _load_checkpoint_file(self):
        """Load a checkpoint file. Not implemented yet.

        The leading underscore marks this as internal by convention only;
        Python does not prevent it from being called outside this module.
        """
        raise NotImplementedError("Solver._load_checkpoint_file is not implemented yet")