Commit 21bbaae0 authored by Andrei-Claudiu Roibu 🖥

rewrote the get_datasets function for new DataMapper

parent e6579311
@@ -103,8 +103,8 @@ class DataMapper(data.Dataset):
     Args:
         filename (str): Path to file containing the relevant volume indicator numbers
         data_directory (str): Directory where the various subjects are stored.
-        train_data_file (str): Internal path for each subject to the relevant normalized summed dMRI tracts
-        train_output_targets (str): Internal path for each subject to the relevant rsfMRI data
+        data_file (str): Internal path for each subject to the relevant normalized summed dMRI tracts
+        output_targets (str): Internal path for each subject to the relevant rsfMRI data

     Returns:
         X_volume (torch.tensor): Tensor representation of the input data

@@ -113,14 +113,14 @@ class DataMapper(data.Dataset):
     """

-    def __init__(self, filename, data_directory, train_data_file, train_output_targets):
+    def __init__(self, filename, data_directory, data_file, output_targets):
         # Initialize everything, and only store the text data file in memory.
         # Memory usage is limited by storing only the text string information, not the actual volumes.
         # TODO: Currently, the timepoint in the fMRI data (y_volume) is hardcoded, only loading in the RSN. This needs to be updated in later iterations.
         self.filename = filename
         self.data_directory = data_directory
-        self.train_data_file = train_data_file
-        self.train_output_targets = train_output_targets
+        self.data_file = data_file
+        self.output_targets = output_targets
         self.sample_pairs = []
         self._get_datasets()
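
For orientation, here is a minimal usage sketch of the renamed constructor; the list file name and per-subject paths are hypothetical, not taken from the repository:

# Hypothetical usage of the new DataMapper signature (illustration only).
from torch.utils import data

dataset = DataMapper(
    filename='train.txt',                    # text file listing one subject ID per line
    data_directory='/path/to/subjects',      # root directory containing one folder per subject
    data_file='dmri_summed_tracts.nii.gz',   # hypothetical per-subject input volume
    output_targets='rsfmri.nii.gz',          # hypothetical per-subject target volume
)
loader = data.DataLoader(dataset, batch_size=1, shuffle=True)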

@@ -141,7 +141,6 @@ class DataMapper(data.Dataset):
         Helper function which reads all the various strings and generates the required paths.
         """
-        # We read the file strings, and then come up with the full paths
         with open(self.filename) as files:
             lines = files.read().split('\n')

@@ -150,8 +149,8 @@ class DataMapper(data.Dataset):
             if line == '':
                 pass
             else:
-                X_path = os.path.join(self.data_directory, line, self.train_data_file)
-                y_path = os.path.join(self.data_directory, line, self.train_output_targets)
+                X_path = os.path.join(self.data_directory, line, self.data_file)
+                y_path = os.path.join(self.data_directory, line, self.output_targets)
                 self.sample_pairs.append((X_path, y_path))
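
As an illustration of the loop above, each non-empty line of the list file yields one (X_path, y_path) pair; the subject IDs and file names below are hypothetical:

# Hypothetical example: if the list file contains the lines
#   100307
#   100408
# and the mapper was built with data_directory='/data/subjects',
# data_file='tracts.npz', output_targets='rsfmri.npz',
# then self.sample_pairs ends up as
#   [('/data/subjects/100307/tracts.npz', '/data/subjects/100307/rsfmri.npz'),
#    ('/data/subjects/100408/tracts.npz', '/data/subjects/100408/rsfmri.npz')]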

@@ -213,27 +212,28 @@ def get_datasets(data_parameters):
         data_parameters = {
             data_directory: 'path/to/directory'
             train_data_file: 'training_data'
             train_output_targets: 'training_targets'
-            test_data_file: 'testing_data'
-            test_target_file: 'testing_targets'
+            train_list: 'train.txt'
+            validation_list: 'validation.txt'
+            validation_data_file: 'testing_data'
+            validation_target_file: 'testing_targets'
         }

     Returns:
-        tuple: the relevant train and test datasets
+        tuple: the relevant train and validation datasets
     """

-    training_data = h5py.File(os.path.join(
-        data_parameters['data_directory'], data_parameters['training_data']), 'r')
-    testing_data = h5py.File(os.path.join(
-        data_parameters['data_directory'], data_parameters['testing_data']), 'r')
+    train_filename = data_parameters['train_list']
+    data_directory = data_parameters['data_directory']
+    train_data_file = data_parameters['train_data_file']
+    train_output_targets = data_parameters['train_output_targets']

-    training_labels = h5py.File(os.path.join(
-        data_parameters['data_directory'], data_parameters['training_targets']), 'r')
-    testing_labels = h5py.File(os.path.join(
-        data_parameters['data_directory'], data_parameters['testing_targets']), 'r')
+    validation_filename = data_parameters['validation_list']
+    validation_data_file = data_parameters['validation_data_file']
+    validation_output_targets = data_parameters['validation_target_file']

     return (
-        DataMapper(training_data['data'][()], training_labels['label'][()]),
-        DataMapper(testing_data['data'][()], testing_labels['label'][()])
+        DataMapper(train_filename, data_directory, train_data_file, train_output_targets),
+        DataMapper(validation_filename, data_directory, validation_data_file, validation_output_targets)
     )
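
A sketch of how the rewritten function might be called; the dictionary values are hypothetical placeholders matching the docstring above:

# Hypothetical call site for the rewritten get_datasets (illustration only).
data_parameters = {
    'data_directory': '/path/to/directory',
    'train_data_file': 'training_data',
    'train_output_targets': 'training_targets',
    'train_list': 'train.txt',
    'validation_list': 'validation.txt',
    'validation_data_file': 'testing_data',
    'validation_target_file': 'testing_targets',
}
train_dataset, validation_dataset = get_datasets(data_parameters)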

@@ -508,6 +508,44 @@ class DataMapperHDF5(data.Dataset):

     def __len__(self):
         return len(self.y)

+def get_datasetsHDF5(data_parameters):
+    """Data Loader Function.
+
+    NOTE: This function is not deprecated; it retains the previous HDF5-based loading behaviour.
+    This function loads the various data files and returns the relevant mapped datasets.
+
+    Args:
+        data_parameters (dict): Dictionary containing relevant information for the datafiles.
+        data_parameters = {
+            data_directory: 'path/to/directory'
+            training_data: 'training_data'
+            training_targets: 'training_targets'
+            testing_data: 'testing_data'
+            testing_targets: 'testing_targets'
+        }
+
+    Returns:
+        tuple: the relevant train and test datasets
+    """
+
+    training_data = h5py.File(os.path.join(
+        data_parameters['data_directory'], data_parameters['training_data']), 'r')
+    testing_data = h5py.File(os.path.join(
+        data_parameters['data_directory'], data_parameters['testing_data']), 'r')
+
+    training_labels = h5py.File(os.path.join(
+        data_parameters['data_directory'], data_parameters['training_targets']), 'r')
+    testing_labels = h5py.File(os.path.join(
+        data_parameters['data_directory'], data_parameters['testing_targets']), 'r')
+
+    return (
+        DataMapperHDF5(training_data['data'][()], training_labels['label'][()]),
+        DataMapperHDF5(testing_data['data'][()], testing_labels['label'][()])
+    )

 if __name__ == "__main__":
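
For comparison, a sketch of the flat HDF5 layout that get_datasetsHDF5 reads: each data file must hold a 'data' dataset and each target file a 'label' dataset. The array shape here is an assumption for illustration only:

# Hypothetical construction of the HDF5 files get_datasetsHDF5 expects.
import h5py
import numpy as np

volumes = np.zeros((10, 91, 109, 91), dtype=np.float32)  # assumed shape, illustration only
with h5py.File('training_data', 'w') as f:
    f.create_dataset('data', data=volumes)    # read back as training_data['data'][()]
with h5py.File('training_targets', 'w') as f:
    f.create_dataset('label', data=volumes)   # read back as training_labels['label'][()]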