Commit d15c6593 authored by Andrei-Claudiu Roibu 🖥
Browse files

bug fixes + updates for HPC deployment

parent b567cb30
...@@ -409,6 +409,7 @@ if __name__ == '__main__': ...@@ -409,6 +409,7 @@ if __name__ == '__main__':
data_test_train_validation_split(data_parameters['data_folder_name'], data_test_train_validation_split(data_parameters['data_folder_name'],
data_parameters['test_percentage'], data_parameters['test_percentage'],
data_parameters['subject_number'], data_parameters['subject_number'],
data_directory=data_parameters['data_directory'],
data_file=data_parameters['data_file'], data_file=data_parameters['data_file'],
K_fold=data_parameters['k_fold'] K_fold=data_parameters['k_fold']
) )
......
...@@ -168,7 +168,6 @@ class Solver(): ...@@ -168,7 +168,6 @@ class Solver():
if phase == 'train': if phase == 'train':
model.train() model.train()
learning_rate_scheduler.step()
else: else:
model.eval() model.eval()
...@@ -234,6 +233,9 @@ class Solver(): ...@@ -234,6 +233,9 @@ class Solver():
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension) 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
) )
if phase == 'train':
learning_rate_scheduler.step()
print("Epoch {}/{} DONE!".format(epoch, self.number_epochs)) print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))
# Early Stop Condition # Early Stop Condition
......
...@@ -45,11 +45,12 @@ def directory_reader(folder_location, subject_number=None, write_txt=False): ...@@ -45,11 +45,12 @@ def directory_reader(folder_location, subject_number=None, write_txt=False):
number_of_subjects = 0 number_of_subjects = 0
if subject_number is None: if subject_number is None:
subject_number = len(os.listdir(os.listdir(os.path.join(os.path.expanduser("~"), folder_location)))) subject_number = len(os.listdir(os.path.join(
os.path.expanduser("~"), folder_location)))
for directory in os.listdir(folder_location): for directory in os.listdir(folder_location):
if number_of_subjects < subject_number: if number_of_subjects < subject_number:
if os.path.isdir(os.path.join(folder_location, directory)): if os.path.isdir(os.path.join(folder_location, directory)) and os.path.exists(os.path.join(folder_location, directory, "dMRI/autoptx_preproc/")) and os.path.exists(os.path.join(folder_location, directory, "fMRI/rfMRI_25.dr/")):
filename = folder_location+directory filename = folder_location+directory
if os.access(filename, os.R_OK): if os.access(filename, os.R_OK):
string = directory string = directory
...@@ -64,11 +65,13 @@ def directory_reader(folder_location, subject_number=None, write_txt=False): ...@@ -64,11 +65,13 @@ def directory_reader(folder_location, subject_number=None, write_txt=False):
return subDirectoryList return subDirectoryList
def data_file_reader(data_file_path): def data_file_reader(data_file_path, folder_location, subject_number=None):
"""Data File reader """Data File reader
Args: Args:
data_file_path (str): Path to the file containing the data data_file_path (str): Path to the file containing the data
folder_location (str): A string containing the address of the required directory.
subject_number (int): Number of subjects to be considered for a job. Useful when wanting to train on datasizes smaller than total datapoints available in a datafolder.
Returns: Returns:
subDirectoryList (list): A list of strings containing the available sub-directories subDirectoryList (list): A list of strings containing the available sub-directories
...@@ -78,6 +81,14 @@ def data_file_reader(data_file_path): ...@@ -78,6 +81,14 @@ def data_file_reader(data_file_path):
subDirectoryList = files.read().split('\n') subDirectoryList = files.read().split('\n')
subDirectoryList.remove('') subDirectoryList.remove('')
for directory in subDirectoryList:
if os.path.exists(os.path.join(folder_location, directory, "dMRI/autoptx_preproc/")) == False:
if os.path.exists(os.path.join(folder_location, directory, "fMRI/rfMRI_25.dr/")) == False:
subDirectoryList.remove(directory)
if subject_number is not None:
subDirectoryList = subDirectoryList[:subject_number]
return subDirectoryList return subDirectoryList
...@@ -95,17 +106,12 @@ def data_test_train_validation_split(data_folder_name, test_percentage, subject_ ...@@ -95,17 +106,12 @@ def data_test_train_validation_split(data_folder_name, test_percentage, subject_
K_fold (int): Number of folds for splitting the training data K_fold (int): Number of folds for splitting the training data
data_file (str): Name of *.txt file containing a list of the required data data_file (str): Name of *.txt file containing a list of the required data
Raises:
ValueError: 'Invalid data input! Either a data_file.txt containing all data, or a data_directory string needs to be passed'
""" """
if data_file is None: if data_file is not None:
subDirectoryList = directory_reader(data_directory, subject_number) subDirectoryList = data_file_reader(data_file, data_directory, subject_number)
elif data_directory is None:
subDirectoryList = data_file_reader(data_file)
else: else:
raise ValueError( subDirectoryList = directory_reader(data_directory, subject_number)
'Invalid data input! Either a data_file.txt containing all data, or a data_directory string needs to be passed')
subDirectoryList = np.array(subDirectoryList) subDirectoryList = np.array(subDirectoryList)
create_folder(data_folder_name) create_folder(data_folder_name)
...@@ -210,9 +216,9 @@ class DataMapper(data.Dataset): ...@@ -210,9 +216,9 @@ class DataMapper(data.Dataset):
pass pass
else: else:
X_path = os.path.join( X_path = os.path.join(
self.data_directory, line, self.data_file) os.path.expanduser("~"), self.data_directory, line, self.data_file)
y_path = os.path.join( y_path = os.path.join(
self.data_directory, line, self.output_targets) os.path.expanduser("~"), self.data_directory, line, self.output_targets)
self.sample_pairs.append((X_path, y_path)) self.sample_pairs.append((X_path, y_path))
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment