Commit 9d940084 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

eliminated val loss return, added new run argument modes

parent 4810e6e6
...@@ -232,12 +232,14 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -232,12 +232,14 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
crop_flag = data_parameters['crop_flag'] crop_flag = data_parameters['crop_flag']
) )
validation_loss = solver.train(train_loader, validation_loader) # _ = solver.train(train_loader, validation_loader)
solver.train(train_loader, validation_loader)
del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver, optimizer del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver, optimizer
torch.cuda.empty_cache() torch.cuda.empty_cache()
return validation_loss # return None
if training_parameters['adam_w_flag'] == True: if training_parameters['adam_w_flag'] == True:
...@@ -250,13 +252,13 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -250,13 +252,13 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
# loss_function=torch.nn.CosineEmbeddingLoss() # loss_function=torch.nn.CosineEmbeddingLoss()
if network_parameters['cross_domain_flag'] == False: if network_parameters['cross_domain_flag'] == False:
_ = _train_runner(data_parameters, _train_runner(data_parameters,
training_parameters, training_parameters,
network_parameters, network_parameters,
misc_parameters, misc_parameters,
optimizer=optimizer, optimizer=optimizer,
loss_function=loss_function loss_function=loss_function
) )
elif network_parameters['cross_domain_flag'] == True: elif network_parameters['cross_domain_flag'] == True:
if network_parameters['cross_domain_x2x_flag'] == True: if network_parameters['cross_domain_x2x_flag'] == True:
...@@ -267,13 +269,13 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -267,13 +269,13 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
# loss_function = torch.nn.L1Loss() # loss_function = torch.nn.L1Loss()
_ = _train_runner(data_parameters, _train_runner(data_parameters,
training_parameters, training_parameters,
network_parameters, network_parameters,
misc_parameters, misc_parameters,
optimizer=optimizer, optimizer=optimizer,
loss_function=loss_function loss_function=loss_function
) )
if network_parameters['cross_domain_y2y_flag'] == True: if network_parameters['cross_domain_y2y_flag'] == True:
...@@ -283,23 +285,23 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -283,23 +285,23 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
# loss_function = torch.nn.L1Loss() # loss_function = torch.nn.L1Loss()
_ = _train_runner(data_parameters, _train_runner(data_parameters,
training_parameters, training_parameters,
network_parameters, network_parameters,
misc_parameters, misc_parameters,
optimizer=optimizer, optimizer=optimizer,
loss_function=loss_function loss_function=loss_function
) )
if network_parameters['cross_domain_x2y_flag'] == True: if network_parameters['cross_domain_x2y_flag'] == True:
_ = _train_runner(data_parameters, _train_runner(data_parameters,
training_parameters, training_parameters,
network_parameters, network_parameters,
misc_parameters, misc_parameters,
optimizer=optimizer, optimizer=optimizer,
loss_function=loss_function loss_function=loss_function
) )
def evaluate_mapping(mapping_evaluation_parameters): def evaluate_mapping(mapping_evaluation_parameters):
...@@ -428,16 +430,43 @@ if __name__ == '__main__': ...@@ -428,16 +430,43 @@ if __name__ == '__main__':
settings_evaluation = Settings(evaluation_settings_file_name) settings_evaluation = Settings(evaluation_settings_file_name)
mapping_evaluation_parameters = settings_evaluation['MAPPING'] mapping_evaluation_parameters = settings_evaluation['MAPPING']
evaluate_mapping(mapping_evaluation_parameters) evaluate_mapping(mapping_evaluation_parameters)
elif arguments.mode == 'clear-experiments':
shutil.rmtree(os.path.join( elif arguments.mode == 'clear-checkpoints':
misc_parameters['experiments_directory'], training_parameters['experiment_name'])) if network_parameters['cross_domain_flag'] == True:
shutil.rmtree(os.path.join( if network_parameters['cross_domain_x2x_flag'] == True:
misc_parameters['logs_directory'], training_parameters['experiment_name'])) training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_x2x"
print('Cleared the current experiments and logs directory successfully!') if network_parameters['cross_domain_y2y_flag'] == True:
elif arguments.mode == 'clear-everything': training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_y2y"
delete_files(misc_parameters['experiments_directory'])
delete_files(misc_parameters['logs_directory']) shutil.rmtree(os.path.join(misc_parameters['experiments_directory'], training_parameters['experiment_name']))
print('Cleared the current experiments and logs directory successfully!') print('Cleared the current experiment checkpoints successfully!')
elif arguments.mode == 'clear-logs':
if network_parameters['cross_domain_flag'] == True:
if network_parameters['cross_domain_x2x_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_x2x"
if network_parameters['cross_domain_y2y_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_y2y"
shutil.rmtree(os.path.join(misc_parameters['logs_directory'], training_parameters['experiment_name']))
print('Cleared the current experiment logs directory successfully!')
elif arguments.mode == 'clear-experiment':
if network_parameters['cross_domain_flag'] == True:
if network_parameters['cross_domain_x2x_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_x2x"
if network_parameters['cross_domain_y2y_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_y2y"
shutil.rmtree(os.path.join(misc_parameters['experiments_directory'], training_parameters['experiment_name']))
shutil.rmtree(os.path.join(misc_parameters['logs_directory'], training_parameters['experiment_name']))
print('Cleared the current experiment checkpoints and logs directory successfully!')
# elif arguments.mode == 'clear-everything':
# delete_files(misc_parameters['experiments_directory'])
# delete_files(misc_parameters['logs_directory'])
# print('Cleared the all the checkpoints and logs directory successfully!')
elif arguments.mode == 'train-and-evaluate-mapping': elif arguments.mode == 'train-and-evaluate-mapping':
settings_evaluation = Settings(evaluation_settings_file_name) settings_evaluation = Settings(evaluation_settings_file_name)
mapping_evaluation_parameters = settings_evaluation['MAPPING'] mapping_evaluation_parameters = settings_evaluation['MAPPING']
...@@ -445,6 +474,7 @@ if __name__ == '__main__': ...@@ -445,6 +474,7 @@ if __name__ == '__main__':
network_parameters, misc_parameters) network_parameters, misc_parameters)
logging.basicConfig(filename='evaluate-mapping-error.log') logging.basicConfig(filename='evaluate-mapping-error.log')
evaluate_mapping(mapping_evaluation_parameters) evaluate_mapping(mapping_evaluation_parameters)
else: else:
raise ValueError( raise ValueError(
'Invalid mode value! Only supports: train, evaluate-score, evaluate-mapping, train-and-evaluate-mapping, clear-experiments and clear-everything') 'Invalid mode value! Only supports: train, evaluate-score, evaluate-mapping, train-and-evaluate-mapping, clear-checkpoints, clear-logs, clear-experiment and clear-everything (req uncomment for safety!)')
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment