Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
d3b668b0
Commit
d3b668b0
authored
Apr 09, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
refactored train variable names to validation
parent
98ccda23
Changes
2
Hide whitespace changes
Inline
Side-by-side
run.py
View file @
d3b668b0
...
...
@@ -52,7 +52,6 @@ def load_data(data_parameters):
"""Dataset Loader
This function loads the training and validation datasets.
TODO: Will need to define if all the training data is loaded as bulk or individually!
Args:
data_parameters (dict): Dictionary containing relevant information for the datafiles.
...
...
@@ -86,7 +85,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
training_parameters(dict): Dictionary containing relevant hyperparameters for training the network.
training_parameters = {
'training_batch_size': 5
'
test
_batch_size: 5
'
validation
_batch_size: 5
'use_pre_trained': False
'pre_trained_path': 'pre_trained/path'
'experiment_name': 'experiment_name'
...
...
@@ -114,7 +113,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
}
"""
train_data
,
test
_data
=
load_data
(
data_parameters
)
train_data
,
validation
_data
=
load_data
(
data_parameters
)
train_loader
=
data
.
DataLoader
(
dataset
=
train_data
,
...
...
@@ -124,9 +123,9 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
pin_memory
=
True
)
test
_loader
=
data
.
DataLoader
(
dataset
=
test
_data
,
batch_size
=
training_parameters
[
'
test
_batch_size'
],
validation
_loader
=
data
.
DataLoader
(
dataset
=
validation
_data
,
batch_size
=
training_parameters
[
'
validation
_batch_size'
],
shuffle
=
False
,
num_workers
=
4
,
pin_memory
=
True
...
...
@@ -157,7 +156,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
logs_directory
=
misc_parameters
[
'logs_directory'
]
)
solver
.
train
(
train_loader
,
test
_loader
)
solver
.
train
(
train_loader
,
validation
_loader
)
model_output_path
=
os
.
path
.
join
(
misc_parameters
[
'save_model_directory'
],
training_parameters
[
'final_model_output_file'
])
...
...
settings.ini
View file @
d3b668b0
...
...
@@ -13,7 +13,7 @@ validation_target_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
[TRAINING]
training_batch_size
=
5
test
_batch_size
=
5
validation
_batch_size
=
5
use_pre_trained
=
False
pre_trained_path
=
"saved_models/preTrained.pth.tar"
experiment_name
=
"experiment_name"
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment