Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
39f00dc3
Commit
39f00dc3
authored
Aug 19, 2020
by
Andrei Roibu
Browse files
added Autoencoder3D import + commented out L1 loss for x2x
parent
2413e84e
Changes
1
Hide whitespace changes
Inline
Side-by-side
run.py
View file @
39f00dc3
...
...
@@ -39,7 +39,7 @@ import torch.utils.data as data
import
numpy
as
np
from
solver
import
Solver
from
BrainMapperAE
import
BrainMapperAE3D
from
BrainMapperAE
import
BrainMapperAE3D
,
AutoEncoder3D
from
utils.data_utils
import
get_datasets
from
utils.settings
import
Settings
import
utils.data_evaluation_utils
as
evaluations
...
...
@@ -195,7 +195,8 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
if
training_parameters
[
'use_pre_trained'
]:
BrainMapperModel
=
torch
.
load
(
training_parameters
[
'pre_trained_path'
])
else
:
BrainMapperModel
=
BrainMapperAE3D
(
network_parameters
)
# BrainMapperModel = BrainMapperAE3D(network_parameters)
BrainMapperModel
=
AutoEncoder3D
(
network_parameters
)
# temprorary change for testing encoder-decoder effective receptive field
custom_weight_reset_flag
=
network_parameters
[
'custom_weight_reset_flag'
]
...
...
@@ -265,7 +266,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
data_parameters
[
'target_data_train'
]
=
data_parameters
[
'input_data_train'
]
data_parameters
[
'target_data_validation'
]
=
data_parameters
[
'input_data_validation'
]
loss_function
=
torch
.
nn
.
L1Loss
()
#
loss_function = torch.nn.L1Loss()
_
=
_train_runner
(
data_parameters
,
training_parameters
,
...
...
@@ -281,7 +282,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
data_parameters
[
'input_data_train'
]
=
data_parameters
[
'target_data_train'
]
data_parameters
[
'input_data_validation'
]
=
data_parameters
[
'target_data_validation'
]
loss_function
=
torch
.
nn
.
L1Loss
()
#
loss_function = torch.nn.L1Loss()
_
=
_train_runner
(
data_parameters
,
training_parameters
,
...
...
@@ -366,6 +367,7 @@ def evaluate_mapping(mapping_evaluation_parameters):
device
,
exit_on_error
)
def
delete_files
(
folder
):
""" Clear Folder Contents
...
...
@@ -422,8 +424,6 @@ if __name__ == '__main__':
train
(
data_parameters
,
training_parameters
,
network_parameters
,
misc_parameters
)
# NOTE: THE EVAL FUNCTIONS HAVE NOT YET BEEN DEBUGGED (16/04/20)
elif
arguments
.
mode
==
'evaluate-mapping'
:
logging
.
basicConfig
(
filename
=
'evaluate-mapping-error.log'
)
settings_evaluation
=
Settings
(
evaluation_settings_file_name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment