Commit 35db07b6 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added AdamW optimiser

parent f5eb75d8
......@@ -148,12 +148,13 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
if training_parameters['use_pre_trained']:
BrainMapperModel = torch.load(
training_parameters['pre_trained_path'])
else:
else:
BrainMapperModel = BrainMapperAE3D(network_parameters)
BrainMapperModel.reset_parameters()
optimizer = torch.optim.Adam
# optimizer = torch.optim.AdamW
solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'],
......@@ -253,28 +254,27 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva
# TODO - NEED TO UPDATE THE DATA FUNCTIONS!
logWriter = LogWriter(number_of_classes=network_parameters['number_of_classes'],
logs_directory=misc_parameters['logs_directory'],
experiment_name=training_parameters['experiment_name']
)
prediction_output_path = os.path.join(misc_parameters['experiments_directory'],
training_parameters['experiment_name'],
evaluation_parameters['saved_predictions_directory']
)
_ = evaluations.evaluate_dice_score(trained_model_path=evaluation_parameters['trained_model_path'],
number_of_classes=network_parameters['number_of_classes'],
data_directory=evaluation_parameters['data_directory'],
targets_directory=evaluation_parameters['targets_directory'],
data_list=evaluation_parameters['data_list'],
orientation=evaluation_parameters['orientation'],
prediction_output_path=prediction_output_path,
device=misc_parameters['device'],
LogWriter=logWriter
)
logWriter.close()
evaluations.evaluate_correlation(trained_model_path=evaluation_parameters['trained_model_path'],
data_directory=evaluation_parameters['data_directory'],
mapping_data_file=mapping_evaluation_parameters['mapping_data_file'],
target_data_file=evaluation_parameters['targets_directory'],
data_list=evaluation_parameters['data_list'],
prediction_output_path=prediction_output_path,
brain_mask_path=mapping_evaluation_parameters['brain_mask_path'],
rsfmri_mean_mask_path=mapping_evaluation_parameters[
'rsfmri_mean_mask_path'],
dmri_mean_mask_path=mapping_evaluation_parameters[
'dmri_mean_mask_path'],
mean_reduction=mapping_evaluation_parameters['mean_reduction'],
scaling_factors=mapping_evaluation_parameters['scaling_factors'],
regression_factors=mapping_evaluation_parameters['regression_factors'],
device=misc_parameters['device'],
)
def evaluate_mapping(mapping_evaluation_parameters):
......@@ -310,18 +310,18 @@ def evaluate_mapping(mapping_evaluation_parameters):
regression_factors = mapping_evaluation_parameters['regression_factors']
evaluations.evaluate_mapping(trained_model_path,
data_directory,
mapping_data_file,
data_list,
prediction_output_path,
brain_mask_path,
dmri_mean_mask_path,
rsfmri_mean_mask_path,
mean_reduction,
scaling_factors,
regression_factors,
device=device,
exit_on_error=exit_on_error)
data_directory,
mapping_data_file,
data_list,
prediction_output_path,
brain_mask_path,
dmri_mean_mask_path,
rsfmri_mean_mask_path,
mean_reduction,
scaling_factors,
regression_factors,
device=device,
exit_on_error=exit_on_error)
def delete_files(folder):
......@@ -368,27 +368,27 @@ if __name__ == '__main__':
if data_parameters['use_data_file'] == True:
data_preparation(data_parameters['data_folder_name'],
data_parameters['test_percentage'],
data_parameters['subject_number'],
data_directory=data_parameters['data_directory'],
train_inputs=data_parameters['train_data_file'],
train_targets=data_parameters['train_output_targets'],
rsfMRI_mean_mask_path=data_parameters['rsfmri_mean_mask_path'],
dMRI_mean_mask_path=data_parameters['dmri_mean_mask_path'],
data_file=data_parameters['data_file'],
K_fold=data_parameters['k_fold']
)
data_parameters['test_percentage'],
data_parameters['subject_number'],
data_directory=data_parameters['data_directory'],
train_inputs=data_parameters['train_data_file'],
train_targets=data_parameters['train_output_targets'],
rsfMRI_mean_mask_path=data_parameters['rsfmri_mean_mask_path'],
dMRI_mean_mask_path=data_parameters['dmri_mean_mask_path'],
data_file=data_parameters['data_file'],
K_fold=data_parameters['k_fold']
)
else:
data_preparation(data_parameters['data_folder_name'],
data_parameters['test_percentage'],
data_parameters['subject_number'],
data_directory=data_parameters['data_directory'],
train_inputs=data_parameters['train_data_file'],
train_targets=data_parameters['train_output_targets'],
rsfMRI_mean_mask_path=data_parameters['rsfmri_mean_mask_path'],
dMRI_mean_mask_path=data_parameters['dmri_mean_mask_path'],
K_fold=data_parameters['k_fold']
)
data_parameters['test_percentage'],
data_parameters['subject_number'],
data_directory=data_parameters['data_directory'],
train_inputs=data_parameters['train_data_file'],
train_targets=data_parameters['train_output_targets'],
rsfMRI_mean_mask_path=data_parameters['rsfmri_mean_mask_path'],
dMRI_mean_mask_path=data_parameters['dmri_mean_mask_path'],
K_fold=data_parameters['k_fold']
)
update_shuffling_flag('settings.ini')
print('Data is shuffling... Complete!')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment