Commit c4d7d0ad authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

wrote the data loader function

parent c87661ce
......@@ -14,11 +14,47 @@ Usage
In order to run the network, in the terminal, the user needs to pass it relevant arguments:
def load_data():
import torch
from utils.data_utils import get_datasets
import BrainMapperUNet as BrainMapperUNet
# Set the default floating point tensor type to FloatTensor
def load_data(data_parameters):
"""Dataset Loader
This function loads the training and testing datasets.
TODO: Will need to define if all the training data is loaded as bulk or individually!
data_parameters (dict): Dictionary containing relevant information for the datafiles.
data_parameters = {
data_directory: 'path/to/directory'
train_data_file: 'training_data'
train_output_targets: 'training_targets'
test_data_file: 'testing_data'
test_target_file: 'testing_targets'
train_data (dataset object): Pytorch map-style dataset object, mapping indices to training data samples.
test_data (dataset object): Pytorch map-style dataset object, mapping indices to testing data samples.
print("Data is loading...")
train_data, test_data = get_datasets(data_parameters)
print("Data has loaded!")
print("Training dataset size is {}".format(len(train_data)))
print("Testing dataset size is {}".format(len(test_data)))
return train_data, test_data
def train():
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment