Commit e615a4c6 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added flags for network transf blocks & weight init + ask for network name for custom settings

parent a897de36
......@@ -134,7 +134,6 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
dataset=train_data,
batch_size=training_parameters['training_batch_size'],
shuffle=True,
num_workers=4,
pin_memory=True
)
......@@ -142,7 +141,6 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
dataset=validation_data,
batch_size=training_parameters['validation_batch_size'],
shuffle=False,
num_workers=4,
pin_memory=True
)
......@@ -152,7 +150,9 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
else:
BrainMapperModel = BrainMapperAE3D(network_parameters)
BrainMapperModel.reset_parameters()
custom_weight_reset_flag = network_parameters['custom_weight_reset_flag']
BrainMapperModel.reset_parameters(custom_weight_reset_flag)
optimizer = torch.optim.Adam
# optimizer = torch.optim.AdamW
......@@ -280,12 +280,15 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--mode', '-m', required=True,
help='run mode, valid values are train or evaluate')
parser.add_argument('--settings_path', '-sp', required=False,
help='optional argument, set path to settings_evaluation.ini')
parser.add_argument('--model_name', '-n', required=True,
help='model name, required for identifying the settings file modelName.ini & modelName_eval.ini')
arguments = parser.parse_args()
settings = Settings('settings.ini')
settings_file_name = arguments.model_name + '.ini'
evaluation_settings_file_name = arguments.model_name + '_eval.ini'
settings = Settings(settings_file_name)
data_parameters = settings['DATA']
training_parameters = settings['TRAINING']
network_parameters = settings['NETWORK']
......@@ -300,10 +303,7 @@ if __name__ == '__main__':
elif arguments.mode == 'evaluate-mapping':
logging.basicConfig(filename='evaluate-mapping-error.log')
if arguments.settings_path is not None:
settings_evaluation = Settings(arguments.settings_path)
else:
settings_evaluation = Settings('settings_evaluation.ini')
settings_evaluation = Settings(evaluation_settings_file_name)
mapping_evaluation_parameters = settings_evaluation['MAPPING']
evaluate_mapping(mapping_evaluation_parameters)
elif arguments.mode == 'clear-experiments':
......@@ -317,10 +317,7 @@ if __name__ == '__main__':
delete_files(misc_parameters['logs_directory'])
print('Cleared the current experiments and logs directory successfully!')
elif arguments.mode == 'train-and-evaluate-mapping':
if arguments.settings_path is not None:
settings_evaluation = Settings(arguments.settings_path)
else:
settings_evaluation = Settings('settings_evaluation.ini')
settings_evaluation = Settings(evaluation_settings_file_name)
mapping_evaluation_parameters = settings_evaluation['MAPPING']
train(data_parameters, training_parameters,
network_parameters, misc_parameters)
......
......@@ -379,7 +379,7 @@ class Solver():
self.model.load_state_dict(checkpoint['state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
for state in self.optimizer.state. ():
for state in self.optimizer.state.values():
for key, value in state.items():
if torch.is_tensor(value):
state[key] = value.to(self.device)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment