Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Saad Jbabdi
CellCounting
Commits
eab2606e
Commit
eab2606e
authored
Sep 20, 2018
by
Saad Jbabdi
Browse files
Updates
parent
3e237780
Changes
1
Hide whitespace changes
Inline
Side-by-side
CellCounting/models/train_model.py
View file @
eab2606e
...
...
@@ -43,8 +43,11 @@ def prepare_data(celldb, args):
y_test
=
np_utils
.
to_categorical
(
y_test
,
n_classes
)
# Save mean/std
np
.
save
(
os
.
path
.
join
(
args
.
out
,
'image_normalise'
,
'img_avg.npy'
),
img_avg
)
np
.
save
(
os
.
path
.
join
(
args
.
out
,
'image_normalise'
,
'img_std.npy'
),
img_std
)
normdir
=
os
.
path
.
join
(
args
.
out
,
'image_normalise'
)
if
not
os
.
path
.
exists
(
normdir
):
os
.
makedirs
(
normdir
)
np
.
save
(
os
.
path
.
join
(
normdir
,
'img_avg.npy'
),
img_avg
)
np
.
save
(
os
.
path
.
join
(
normdir
,
'img_std.npy'
),
img_std
)
return
X_train
,
y_train
,
X_test
,
y_test
...
...
@@ -147,6 +150,8 @@ def train_model(model, celldb, args):
# SAVE MODEL AND FITTING HISTORY
def
save_results
(
model
,
info
,
args
):
if
not
os
.
path
.
exists
(
args
.
out
):
os
.
makedirs
(
args
.
out
)
outfile
=
os
.
path
.
join
(
args
.
out
,
'model.h5'
)
model
.
save
(
outfile
)
...
...
@@ -163,7 +168,7 @@ def save_results(model, info, args):
# - List of DBs
# - Basename folder for output
# - options for the fitting
# - GPU/Augmentation/ModelType?/
# - GPU/Augmentation/ModelType?/
train-test split/etc.
# Output
# - model.h5
# - model history
...
...
@@ -181,6 +186,8 @@ def main():
help
=
'number of training epochs (default=100)'
)
p
.
add_argument
(
'--batch_size'
,
default
=
32
,
type
=
int
,
metavar
=
'<int>'
,
help
=
'batch size (default=32)'
)
p
.
add_argument
(
'--split'
,
default
=
0.1
,
type
=
float
,
metavar
=
'<float>'
,
help
=
'train/test split (default=0.1)'
)
p
.
add_argument
(
'--model'
,
default
=
'convnet'
,
type
=
str
,
metavar
=
'<str>'
,
help
=
'choose model amongst [convet,...] (default=convnet)'
)
p
.
add_argument
(
'--augment'
,
default
=
False
,
type
=
bool
,
metavar
=
'<bool>'
,
...
...
@@ -197,16 +204,17 @@ def main():
args
=
p
.
parse_args
()
# Do the work
print
(
'Preparing image database'
)
print
(
'
*
Preparing image database'
)
celldb
=
db
.
CellDB
()
celldb
.
load
(
args
.
data
)
celldb
.
load
_from_files
(
args
.
data
)
celldb
.
equalise_classes
()
print
(
'Preparing and training model'
)
model
=
create_model
(
args
.
model
)
print
(
'* Preparing and training model'
)
shape
=
celldb
.
images
.
shape
[
1
:]
model
=
create_model
(
shape
,
args
.
model
)
info
=
train_model
(
model
,
celldb
,
args
)
print
(
'Saving results'
)
print
(
'
*
Saving results'
)
save_results
(
info
,
args
)
print
(
'Done'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment