Commit cb210299 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added function for test/train/val split

parent 5713426b
......@@ -20,9 +20,9 @@ import nibabel as nib
import torch
import torch.utils.data as data
import nibabel as nb
import random
def directory_reader(folder_location):
def directory_reader(folder_location, write_txt=False):
"""Produces a list of data-tags which are accessible
This function looks in a large data directory, and returns a list of sub-directories which are accessible.
......@@ -30,7 +30,7 @@ def directory_reader(folder_location):
Args:
folder_location (str): A string containing the address of the required directory.
write_txt (bool): Flag indicating if a .txt file should be created.
Returns:
A list of strings containing the available sub-directories. When write_txt is True, this list is also written out as a .txt file
"""
......@@ -43,15 +43,45 @@ def directory_reader(folder_location):
filename = folder_location+directory
if os.access(filename, os.R_OK):
string = directory+'\n'
out_file.write(string)
if write_txt == True:
out_file.write(string)
subDirectoryList.append(directory)
return subDirectoryList
def data_test_train_validation_split(folder_location, train_percentage, validation_percentage):
    """Produces lists of train, test and validation data

    This function takes the list of sub-directories returned by directory_reader, shuffles it,
    and cuts it into three consecutive parts: training, validation and testing.
    The first train_percentage percent of the shuffled list is used for training, the next
    validation_percentage percent for validation, and whatever remains for testing.
    Each part is saved, one entry per line, to 'train.txt', 'validation.txt' and 'test.txt'
    (relative paths, so they are created in the current working directory).

    Args:
        folder_location (str): A string containing the address of the required directory.
        train_percentage (int): Percentage of data to be used for training
        validation_percentage (int): Percentage of data to be used for validation
    """
    subDirectoryList = directory_reader(folder_location)
    random.shuffle(subDirectoryList)
    subDirectoryList = np.array(subDirectoryList)

    # Cut points are computed as fractions of the list length and truncated to whole indices.
    number_of_entries = len(subDirectoryList)
    train_end = int(train_percentage / 100 * number_of_entries)
    validation_end = int((train_percentage + validation_percentage) / 100 * number_of_entries)

    train, validation, test = np.split(subDirectoryList, [train_end, validation_end])

    # The entries are strings, so fmt='%s' is required: the default savetxt
    # format ('%.18e') only accepts numeric data and raises on string arrays.
    np.savetxt('train.txt', train, fmt='%s')
    np.savetxt('test.txt', test, fmt='%s')
    np.savetxt('validation.txt', validation, fmt='%s')
def update_shuffling_flag(file_name):
    """Placeholder that is not implemented yet.

    The function currently ignores its argument and performs no work.

    Args:
        file_name: Currently unused.

    Returns:
        None
    """
    return None
def tract_sum_generator(folder_path):
"""Sums the tracts of different dMRI files
THIS FUNCTION IS NOT DEPRECATED: SummedTractMaps generated remotely
When performing subject-specific probabilistic diffusion tractography using standard-space protocols, 27 tracts are created.
This function loops through all the tracts, sums them and returns the summed tract map.
This function also outputs the summed tract map as a Nifti (.nii.gz) file.
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment