Commit 72461c75 authored by Andrei-Claudiu Roibu
Browse files

added ability to use a txt file as input

parent 4d2e9395
......@@ -352,14 +352,24 @@ if __name__ == '__main__':
misc_parameters = settings['MISC']
evaluation_parameters = settings['EVALUATION']
data_split_flag = data_parameters['data_split_flag']
if data_split_flag == True:
# Here we shuffle the data!
data_test_train_validation_split(
data_parameters['data_directory'], data_parameters['train_percentage'], data_parameters['validation_percentage'], data_parameters['subject_number'])
# Here we shuffle the data!
if data_parameters['data_split_flag'] == True:
if data_parameters['use_data_file'] == True:
data_test_train_validation_split(data_parameters['data_folder_name'],
data_parameters['train_percentage'],
data_parameters['validation_percentage'],
data_parameters['subject_number'],
data_file= data_parameters['data_file'])
else:
data_test_train_validation_split(data_parameters['data_folder_name'],
data_parameters['train_percentage'],
data_parameters['validation_percentage'],
data_parameters['subject_number'],
data_directory= data_parameters['data_directory'])
update_shuffling_flag('settings.ini')
# TODO: This might also be a very good point to add cross-validation later
if arguments.mode == 'train':
train(data_parameters, training_parameters,
......
[DATA]
data_folder_name = "datasets"
data_directory = "../well/win-biobank/projects/imaging/data/data3/subjectsAll/"
data_split_flag = False
train_percentage = 90
......@@ -11,6 +12,8 @@ train_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
train_output_targets = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
validation_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
validation_target_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
data_file = "../well/win-biobank/projects/imaging/data/data3/subjectsAll/subj_22k.txt"
use_data_file = False
[TRAINING]
training_batch_size = 2
......@@ -28,6 +31,7 @@ learning_rate_scheduler_step_size = 3
learning_rate_scheduler_gamma = 1e-1
use_last_checkpoint = False
final_model_output_file = "finetuned_alldata.pth.tar"
cross_validation = False
[NETWORK]
kernel_heigth = 5
......
......@@ -36,7 +36,7 @@ def directory_reader(folder_location, subject_number=None, write_txt=False):
write_txt (bool): Flag indicating if a .txt file should be created.
suject_number (int): Number of subjects to be considered for a job. Useful when wanting to train on datasizes smaller than total datapoints available in a datafolder.
Returns:
A list of strings containing the available sub-directories. This is also printed out as a .txt file
subDirectoryList (list): A list of strings containing the available sub-directories. This is also printed out as a .txt file
"""
out_file = open("files.txt", 'w')
......@@ -64,21 +64,48 @@ def directory_reader(folder_location, subject_number=None, write_txt=False):
return subDirectoryList
def data_test_train_validation_split(folder_location, train_percentage, validation_percentage, subject_number):
def data_file_reader(data_file_path):
    """Data File reader

    Reads a text file that lists one sub-directory per line and returns the
    entries as a list of strings.

    Blank lines are skipped wherever they occur, so the file may or may not
    end with a trailing newline. (The previous implementation called
    ``list.remove('')``, which raised ``ValueError`` for files without a
    trailing newline and removed only the first blank entry.)

    Args:
        data_file_path (str): Path to the file containing the data

    Returns:
        subDirectoryList (list): A list of strings containing the available sub-directories
    """

    with open(data_file_path) as files:
        # Keep every non-blank line exactly as written; drop empty and
        # whitespace-only lines so they never become bogus entries.
        subDirectoryList = [line for line in files.read().split('\n') if line.strip()]

    return subDirectoryList
def data_test_train_validation_split(data_folder_name, train_percentage, validation_percentage, subject_number, data_directory = None, data_file = None):
"""Produces lists of train, test and validation data
This function looks at the list of all available directories and returns three lists of dsub-directories.
These lists are the lists required for training, testing and validation.
Args:
folder_location (str): A string containing the address of the required directory.
data_folder_name (str): The name of the folder where the string data is being output
data_directory (str): A string containing the address of the required directory.
train_percentage (int): Percentage of data to be used for training
validation_percentage (int): Percentage of data to be used for validation
suject_number (int): Number of subjects to be considered for a job. Useful when wanting to train on datasizes smaller than total datapoints available in a datafolder.
data_file (str): Name of *.txt file containing a list of the required data
Raises:
ValueError: 'Invalid data input! Either a data_file.txt containing all data, or a data_directory string needs to be passed'
"""
subDirectoryList = directory_reader(folder_location, subject_number)
if data_file is None:
subDirectoryList = directory_reader(data_directory, subject_number)
elif data_directory is None:
subDirectoryList = data_file_reader(data_file)
else:
raise ValueError(
'Invalid data input! Either a data_file.txt containing all data, or a data_directory string needs to be passed')
random.shuffle(subDirectoryList)
......@@ -87,11 +114,50 @@ def data_test_train_validation_split(folder_location, train_percentage, validati
train, validation, test = np.split(subDirectoryList, [int(train_percentage/100 * len(
subDirectoryList)), int((train_percentage+validation_percentage)/100 * len(subDirectoryList))])
np.savetxt('train.txt', train, fmt='%s')
np.savetxt('test.txt', test, fmt='%s')
np.savetxt('validation.txt', validation, fmt='%s')
create_folder(data_folder_name)
np.savetxt(os.path.join(data_folder_name, 'train.txt'), train, fmt='%s')
np.savetxt(os.path.join(data_folder_name, 'test.txt'), test, fmt='%s')
np.savetxt(os.path.join(data_folder_name, 'validation.txt'), validation, fmt='%s')
def data_test_train_validation_split_Kfold_cross_validation(data_folder_name, K_fold, subject_number,
                                                            data_directory=None, data_file=None,
                                                            train_percentage=90, validation_percentage=5):
    """Produces lists of train, test and validation data

    This function looks at the list of all available directories and returns three lists of sub-directories.
    These lists are the lists required for training, testing and validation.
    The lists are written to ``train.txt``, ``test.txt`` and ``validation.txt`` inside ``data_folder_name``.

    NOTE(review): ``K_fold`` is accepted but not used yet; the function currently performs a single
    shuffled percentage split, not a fold-wise split.

    Args:
        data_folder_name (str): The name of the folder where the three .txt files are written
        K_fold (int): Number of folds requested. Currently unused (see note above).
        subject_number (int): Number of subjects to be considered for a job. Only forwarded to directory_reader, i.e. it has no effect when data_file is used.
        data_directory (str): A string containing the address of the required directory.
        data_file (str): Name of *.txt file containing a list of the required data
        train_percentage (int): Percentage of data to be used for training. The body previously read this
            as an undefined name; the default mirrors the value used in this file's ``__main__`` check.
        validation_percentage (int): Percentage of data to be used for validation. Same remark as above.

    Raises:
        ValueError: 'Invalid data input! Either a data_file.txt containing all data, or a data_directory string needs to be passed'
    """

    # Exactly one of the two sources must be given; reject "both" and "neither".
    if (data_file is None) == (data_directory is None):
        raise ValueError(
            'Invalid data input! Either a data_file.txt containing all data, or a data_directory string needs to be passed')

    if data_file is None:
        subDirectoryList = directory_reader(data_directory, subject_number)
    else:
        subDirectoryList = data_file_reader(data_file)

    random.shuffle(subDirectoryList)

    subDirectoryList = np.array(subDirectoryList)

    # Cut points for np.split: [0, train_end) -> train, [train_end, validation_end) -> validation, rest -> test.
    number_of_subjects = len(subDirectoryList)
    train_end = int(train_percentage / 100 * number_of_subjects)
    validation_end = int((train_percentage + validation_percentage) / 100 * number_of_subjects)

    train, validation, test = np.split(subDirectoryList, [train_end, validation_end])

    create_folder(data_folder_name)

    np.savetxt(os.path.join(data_folder_name, 'train.txt'), train, fmt='%s')
    np.savetxt(os.path.join(data_folder_name, 'test.txt'), test, fmt='%s')
    np.savetxt(os.path.join(data_folder_name, 'validation.txt'), validation, fmt='%s')
def update_shuffling_flag(file_name):
""" Update shuffling flag
......@@ -571,3 +637,29 @@ def get_datasetsHDF5(data_parameters):
training_labels['label'][()]),
DataMapperHDF5(testing_data['data'][()], testing_labels['label'][()])
)
if __name__ == '__main__':
    # Ad-hoc manual check of the two list readers and the split routine.
    # Module-level names are kept unchanged because functions in this file
    # may look them up as globals.

    # 1) Read a subject list from a text file.
    data_file_path = 'train.txt'
    subDirectoryList = data_file_reader(data_file_path)
    print(subDirectoryList, type(subDirectoryList), sep='\n')

    # 2) Read a subject list by scanning a directory.
    folder_location = "../well/win-biobank/projects/imaging/data/data3/subjectsAll/"
    subDirectoryList2 = directory_reader(folder_location)
    print(subDirectoryList2, type(subDirectoryList2), sep='\n')

    # 3) Run the train/validation/test split from a directory source.
    data_folder_name, train_percentage, validation_percentage, subject_number = "datasets", 90, 5, None
    data_directory = "../well/win-biobank/projects/imaging/data/data3/subjectsAll/"

    data_test_train_validation_split(
        data_folder_name,
        train_percentage,
        validation_percentage,
        subject_number,
        data_directory=data_directory,
    )
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment