Commit d15c6593 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

bug fixes + updates for HPC deployment

parent b567cb30
......@@ -409,6 +409,7 @@ if __name__ == '__main__':
data_test_train_validation_split(data_parameters['data_folder_name'],
data_parameters['test_percentage'],
data_parameters['subject_number'],
data_directory=data_parameters['data_directory'],
data_file=data_parameters['data_file'],
K_fold=data_parameters['k_fold']
)
......
......@@ -168,7 +168,6 @@ class Solver():
if phase == 'train':
model.train()
learning_rate_scheduler.step()
else:
model.eval()
......@@ -234,6 +233,9 @@ class Solver():
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
)
if phase == 'train':
learning_rate_scheduler.step()
print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))
# Early Stop Condition
......
......@@ -45,11 +45,12 @@ def directory_reader(folder_location, subject_number=None, write_txt=False):
number_of_subjects = 0
if subject_number is None:
subject_number = len(os.listdir(os.listdir(os.path.join(os.path.expanduser("~"), folder_location))))
subject_number = len(os.listdir(os.path.join(
os.path.expanduser("~"), folder_location)))
for directory in os.listdir(folder_location):
if number_of_subjects < subject_number:
if os.path.isdir(os.path.join(folder_location, directory)):
if os.path.isdir(os.path.join(folder_location, directory)) and os.path.exists(os.path.join(folder_location, directory, "dMRI/autoptx_preproc/")) and os.path.exists(os.path.join(folder_location, directory, "fMRI/rfMRI_25.dr/")):
filename = folder_location+directory
if os.access(filename, os.R_OK):
string = directory
......@@ -64,11 +65,13 @@ def directory_reader(folder_location, subject_number=None, write_txt=False):
return subDirectoryList
def data_file_reader(data_file_path):
def data_file_reader(data_file_path, folder_location, subject_number=None):
"""Data File reader
Args:
data_file_path (str): Path to the file containing the data
folder_location (str): A string containing the address of the required directory.
subject_number (int): Number of subjects to be considered for a job. Useful when wanting to train on datasizes smaller than total datapoints available in a datafolder.
Returns:
subDirectoryList (list): A list of strings containing the available sub-directories
......@@ -78,6 +81,14 @@ def data_file_reader(data_file_path):
subDirectoryList = files.read().split('\n')
subDirectoryList.remove('')
for directory in subDirectoryList:
if os.path.exists(os.path.join(folder_location, directory, "dMRI/autoptx_preproc/")) == False:
if os.path.exists(os.path.join(folder_location, directory, "fMRI/rfMRI_25.dr/")) == False:
subDirectoryList.remove(directory)
if subject_number is not None:
subDirectoryList = subDirectoryList[:subject_number]
return subDirectoryList
......@@ -95,17 +106,12 @@ def data_test_train_validation_split(data_folder_name, test_percentage, subject_
K_fold (int): Number of folds for splitting the training data
data_file (str): Name of *.txt file containing a list of the required data
Raises:
ValueError: 'Invalid data input! Either a data_file.txt containing all data, or a data_directory string needs to be passed'
"""
if data_file is None:
subDirectoryList = directory_reader(data_directory, subject_number)
elif data_directory is None:
subDirectoryList = data_file_reader(data_file)
if data_file is not None:
subDirectoryList = data_file_reader(data_file, data_directory, subject_number)
else:
raise ValueError(
'Invalid data input! Either a data_file.txt containing all data, or a data_directory string needs to be passed')
subDirectoryList = directory_reader(data_directory, subject_number)
subDirectoryList = np.array(subDirectoryList)
create_folder(data_folder_name)
......@@ -210,9 +216,9 @@ class DataMapper(data.Dataset):
pass
else:
X_path = os.path.join(
self.data_directory, line, self.data_file)
os.path.expanduser("~"), self.data_directory, line, self.data_file)
y_path = os.path.join(
self.data_directory, line, self.output_targets)
os.path.expanduser("~"), self.data_directory, line, self.output_targets)
self.sample_pairs.append((X_path, y_path))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment