Commit 39f00dc3 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added Autoencoder3D import + commented out L1 loss for x2x

parent 2413e84e
......@@ -39,7 +39,7 @@ import torch.utils.data as data
import numpy as np
from solver import Solver
from BrainMapperAE import BrainMapperAE3D
from BrainMapperAE import BrainMapperAE3D, AutoEncoder3D
from utils.data_utils import get_datasets
from utils.settings import Settings
import utils.data_evaluation_utils as evaluations
......@@ -195,7 +195,8 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
if training_parameters['use_pre_trained']:
BrainMapperModel = torch.load(training_parameters['pre_trained_path'])
else:
BrainMapperModel = BrainMapperAE3D(network_parameters)
# BrainMapperModel = BrainMapperAE3D(network_parameters)
BrainMapperModel = AutoEncoder3D(network_parameters) # temprorary change for testing encoder-decoder effective receptive field
custom_weight_reset_flag = network_parameters['custom_weight_reset_flag']
......@@ -265,7 +266,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
data_parameters['target_data_train'] = data_parameters['input_data_train']
data_parameters['target_data_validation'] = data_parameters['input_data_validation']
loss_function = torch.nn.L1Loss()
# loss_function = torch.nn.L1Loss()
_ = _train_runner(data_parameters,
training_parameters,
......@@ -281,7 +282,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
data_parameters['input_data_train'] = data_parameters['target_data_train']
data_parameters['input_data_validation'] = data_parameters['target_data_validation']
loss_function = torch.nn.L1Loss()
# loss_function = torch.nn.L1Loss()
_ = _train_runner(data_parameters,
training_parameters,
......@@ -366,6 +367,7 @@ def evaluate_mapping(mapping_evaluation_parameters):
device,
exit_on_error)
def delete_files(folder):
""" Clear Folder Contents
......@@ -422,8 +424,6 @@ if __name__ == '__main__':
train(data_parameters, training_parameters,
network_parameters, misc_parameters)
# NOTE: THE EVAL FUNCTIONS HAVE NOT YET BEEN DEBUGGED (16/04/20)
elif arguments.mode == 'evaluate-mapping':
logging.basicConfig(filename='evaluate-mapping-error.log')
settings_evaluation = Settings(evaluation_settings_file_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment