Commit 3c9ed801 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

defined class, wrote constructor and def functions

parent be48e402
"""Brain Mapper U-Net Solver
Description:
-------------
This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
Usage
-------------
To use this module, import it and instantiate is as you wish:
from solver import Solver
"""
import os
import numpy as np
import torch
checkpoint_directory = 'checkpoints'
checkpoint_extension = 'path.tar'
class Solver():
"""Solver class for the BrainMapper U-net.
This class contains the pytorch implementation of the U-net solver required for the BrainMapper project.
Args:
model (class): BrainMapper model class
experiment_name (str): Name of the experiment
device (int/str): Device type used for training (int - GPU id, str- CPU)
number_of_classes (int): Number of classes
optimizer (class): Pytorch class of desired optimizer
optimizer_arguments (dict): Dictionary of arguments to be optimized
loss_function (func): Function describing the desired loss function
model_name (str): Name of the model
labels (arr): Vector/Array of labels (if applicable)
number_epochs (int): Number of training epochs
loss_log_period (int): Period for writing loss value
learning_rate_scheduler_step_size (int): Period of learning rate decay
learning_rate_scheduler_gamma (int): Multiplicative factor of learning rate decay
use_last_checkpoint (bool): Flag for loading the previous checkpoint
experiment_directory (str): Experiment output directory name
logs_directory (str): Directory for outputing training logs
Returns:
trained model(?) - working on this!
Raises:
None
"""
def __init__(self,
model,
device,
number_of_classes,
experiment_name,
optimizer = torch.optim.Adam,
optimizer_arguments = {},
loss_function = loss_function, # Need to define
model_name = 'BrainMapper',
labels = None,
number_epochs = 10,
loss_log_period = 5,
learning_rate_scheduler_step_size = 5,
learning_rate_scheduler_gamma = 0.5,
use_last_checkpoint = True,
experiment_directory = 'experiments',
logs_directory = 'logs'
):
self.model = model
self.device = device
pass
def train():
pass
def save_model():
pass
def save_checkpoint():
pass
def load_checkpoint():
pass
def _load_checkpoint_file():
# Name is private = can't be called outisde of this module
pass
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment