Commit eab2606e authored by Saad Jbabdi's avatar Saad Jbabdi
Browse files

Updates

parent 3e237780
......@@ -43,8 +43,11 @@ def prepare_data(celldb, args):
y_test = np_utils.to_categorical(y_test, n_classes)
# Save mean/std
np.save(os.path.join(args.out,'image_normalise','img_avg.npy'),img_avg)
np.save(os.path.join(args.out,'image_normalise','img_std.npy'),img_std)
normdir = os.path.join(args.out,'image_normalise')
if not os.path.exists(normdir):
os.makedirs(normdir)
np.save(os.path.join(normdir,'img_avg.npy'),img_avg)
np.save(os.path.join(normdir,'img_std.npy'),img_std)
return X_train, y_train, X_test, y_test
......@@ -147,6 +150,8 @@ def train_model(model, celldb, args):
# SAVE MODEL AND FITTING HISTORY
def save_results(model, info, args):
if not os.path.exists(args.out):
os.makedirs(args.out)
outfile = os.path.join(args.out,'model.h5')
model.save(outfile)
......@@ -163,7 +168,7 @@ def save_results(model, info, args):
# - List of DBs
# - Basename folder for output
# - options for the fitting
# - GPU/Augmentation/ModelType?/
# - GPU/Augmentation/ModelType?/train-test split/etc.
# Output
# - model.h5
# - model history
......@@ -181,6 +186,8 @@ def main():
help='number of training epochs (default=100)')
p.add_argument('--batch_size', default=32, type=int, metavar='<int>',
help='batch size (default=32)')
p.add_argument('--split', default=0.1, type=float, metavar='<float>',
help='train/test split (default=0.1)')
p.add_argument('--model', default='convnet', type=str, metavar='<str>',
help='choose model amongst [convet,...] (default=convnet)')
p.add_argument('--augment', default=False, type=bool, metavar='<bool>',
......@@ -197,16 +204,17 @@ def main():
args = p.parse_args()
# Do the work
print('Preparing image database')
print('* Preparing image database')
celldb = db.CellDB()
celldb.load(args.data)
celldb.load_from_files(args.data)
celldb.equalise_classes()
print('Preparing and training model')
model = create_model(args.model)
print('* Preparing and training model')
shape = celldb.images.shape[1:]
model = create_model(shape,args.model)
info = train_model(model, celldb, args)
print('Saving results')
print('* Saving results')
save_results(info, args)
print('Done')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment