Commit 32db8f36 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

built data loaders, started writing solver call

parent 3c9ed801
......@@ -19,6 +19,7 @@ In order to run the network, in the terminal, the user needs to pass it relevant
import torch
from utils.data_utils import get_datasets
import BrainMapperUNet as BrainMapperUNet
import torch.utils.data as data
# Set the default floating point tensor type to FloatTensor
......@@ -56,8 +57,70 @@ def load_data(data_parameters):
return train_data, test_data
def train():
pass
def train(data_parameters, training_parameters):
"""Name
Desc
Currently, the data loaded is set to have multiple sub-processes.
A high enough number of workers assures that CPU computations are efficiently managed, i.e. that the bottleneck is indeed the neural network's forward and backward operations on the GPU (and not data generation)
Loader memory is also pinned, to speed up data transfer from CPU to GPU by using the page-locked memory.
Train data is also re-shuffled at each training epoch.
Args:
data_parameters(dict):
training_parameters(dict):{
paraters
}
network_parameters (dict): Contains information relevant parameters = {
parameters
}
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_classification': 1
'input_channels': 1
'output_channels': 64
'convolution_stride': 1
'dropout': 0.2
'pool_kernel_size': 2
'pool_stride': 2
'up_mode': 'upconv'
'number_of_classes': 1
}
Returns:
None
Raises:
None
"""
train_data, test_data = load_data(data_parameters)
train_loader = data.DataLoader(
dataset= train_data,
batch_size= training_parameters['train_batch_size'],
shuffle= True,
num_workers= 4,
pin_memory= True
)
test_loader = data.DataLoader(
dataset= test_data,
batch_size= training_parameters['test_batch_size'],
shuffle= False,
num_workers= 4,
pin_memory= True
)
if training_parameters['use_pre_trained']:
BrainMapperModel = torch.load(training_parameters['pre_trained_path'])
else:
BrainMapperModel = BrainMapperUNet(network_parameters)
solver = Solver(
# TODO - need to write the solver !
)
def evaluate_path():
pass
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment