Commit 9d940084 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

eliminated val loss return, added new run argument modes

parent 4810e6e6
......@@ -232,12 +232,14 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
crop_flag = data_parameters['crop_flag']
)
validation_loss = solver.train(train_loader, validation_loader)
# _ = solver.train(train_loader, validation_loader)
solver.train(train_loader, validation_loader)
del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver, optimizer
torch.cuda.empty_cache()
return validation_loss
# return None
if training_parameters['adam_w_flag'] == True:
......@@ -250,13 +252,13 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
# loss_function=torch.nn.CosineEmbeddingLoss()
if network_parameters['cross_domain_flag'] == False:
_ = _train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
_train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
elif network_parameters['cross_domain_flag'] == True:
if network_parameters['cross_domain_x2x_flag'] == True:
......@@ -267,13 +269,13 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
# loss_function = torch.nn.L1Loss()
_ = _train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
_train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
if network_parameters['cross_domain_y2y_flag'] == True:
......@@ -283,23 +285,23 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
# loss_function = torch.nn.L1Loss()
_ = _train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
_train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
if network_parameters['cross_domain_x2y_flag'] == True:
_ = _train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
_train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
def evaluate_mapping(mapping_evaluation_parameters):
......@@ -428,16 +430,43 @@ if __name__ == '__main__':
settings_evaluation = Settings(evaluation_settings_file_name)
mapping_evaluation_parameters = settings_evaluation['MAPPING']
evaluate_mapping(mapping_evaluation_parameters)
elif arguments.mode == 'clear-experiments':
shutil.rmtree(os.path.join(
misc_parameters['experiments_directory'], training_parameters['experiment_name']))
shutil.rmtree(os.path.join(
misc_parameters['logs_directory'], training_parameters['experiment_name']))
print('Cleared the current experiments and logs directory successfully!')
elif arguments.mode == 'clear-everything':
delete_files(misc_parameters['experiments_directory'])
delete_files(misc_parameters['logs_directory'])
print('Cleared the current experiments and logs directory successfully!')
elif arguments.mode == 'clear-checkpoints':
if network_parameters['cross_domain_flag'] == True:
if network_parameters['cross_domain_x2x_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_x2x"
if network_parameters['cross_domain_y2y_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_y2y"
shutil.rmtree(os.path.join(misc_parameters['experiments_directory'], training_parameters['experiment_name']))
print('Cleared the current experiment checkpoints successfully!')
elif arguments.mode == 'clear-logs':
if network_parameters['cross_domain_flag'] == True:
if network_parameters['cross_domain_x2x_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_x2x"
if network_parameters['cross_domain_y2y_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_y2y"
shutil.rmtree(os.path.join(misc_parameters['logs_directory'], training_parameters['experiment_name']))
print('Cleared the current experiment logs directory successfully!')
elif arguments.mode == 'clear-experiment':
if network_parameters['cross_domain_flag'] == True:
if network_parameters['cross_domain_x2x_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_x2x"
if network_parameters['cross_domain_y2y_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_y2y"
shutil.rmtree(os.path.join(misc_parameters['experiments_directory'], training_parameters['experiment_name']))
shutil.rmtree(os.path.join(misc_parameters['logs_directory'], training_parameters['experiment_name']))
print('Cleared the current experiment checkpoints and logs directory successfully!')
# elif arguments.mode == 'clear-everything':
# delete_files(misc_parameters['experiments_directory'])
# delete_files(misc_parameters['logs_directory'])
# print('Cleared the all the checkpoints and logs directory successfully!')
elif arguments.mode == 'train-and-evaluate-mapping':
settings_evaluation = Settings(evaluation_settings_file_name)
mapping_evaluation_parameters = settings_evaluation['MAPPING']
......@@ -445,6 +474,7 @@ if __name__ == '__main__':
network_parameters, misc_parameters)
logging.basicConfig(filename='evaluate-mapping-error.log')
evaluate_mapping(mapping_evaluation_parameters)
else:
raise ValueError(
'Invalid mode value! Only supports: train, evaluate-score, evaluate-mapping, train-and-evaluate-mapping, clear-experiments and clear-everything')
'Invalid mode value! Only supports: train, evaluate-score, evaluate-mapping, train-and-evaluate-mapping, clear-checkpoints, clear-logs, clear-experiment and clear-everything (req uncomment for safety!)')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment