Commit 9a8106ff authored by Saad Jbabdi's avatar Saad Jbabdi
Browse files

updates

parent b445532d
......@@ -16,7 +16,9 @@ import itertools as it
import argparse
# Image stuff
from skimage import io as skio
from sklearn.feature_extraction import image
from sklearn.feature_extraction import image
from skimage.transform import rescale, resize, downscale_local_mean
# DL stuff
from keras.models import load_model
......@@ -34,7 +36,7 @@ def preprocess(imnp, model, args):
patches = image.extract_patches(imnp, patch_shape = (h, w, ncols), extraction_step = stride)
patches = patches.reshape(-1,h,w,ncols)
# Normalise intensities
patches = patches.astype(np.float32) #/ 255.
patches = patches.astype(np.float32) / 255.0
img_avg = np.load(os.path.join(args.model_folder,'image_normalise','img_avg.npy'))
img_std = np.load(os.path.join(args.model_folder,'image_normalise','img_std.npy'))
patches = (patches - img_avg)/img_std
......@@ -107,7 +109,7 @@ def forward_prediction(patches, model, args):
def save_results(density, cells, image, args):
def save_results(density, cells, image, patch_shape, args):
basename = args.out
if basename is None:
......@@ -116,10 +118,14 @@ def save_results(density, cells, image, args):
outfile = basename+'_density_cells.npz'
print('* Saving predictions: {}'.format(outfile))
np.savez_compressed(outfile, density=density, cells=cells)
# Estimate number of cells in image
ncells = density.mean()*image.size/patch_shape[1]/patch_shape[2]
print('* Detected {} Cells'.format(ncells))
np.savez_compressed(outfile, density=density, cells=cells, ncells=ncells)
#visualise results
outfile=basename+'_density.png'
print('* Saving snapshot: {}'.format(outfile))
plt.imshow(image)
......@@ -186,7 +192,7 @@ def main():
# Here: depending on overall size of the image
# Split it into bits, run forward stuff on each bit
# Then downsample each bit and stick it all together
model_out = forward_prediction(patches, model, args)
else:
model_out = forward_prediction(patches, model, args)
......@@ -198,7 +204,7 @@ def main():
print('** Total run time = {:.2f} seconds'.format(end_time - start_time))
print('* Save results')
save_results(density, cells, imnp, args)
save_results(density, cells, imnp, patches.shape, args)
if __name__ == '__main__':
......
......@@ -136,6 +136,7 @@ def train_model(model, celldb, args):
epochs = epochs, verbose=1,
shuffle=True, validation_data = (X_test, y_test))
else:
print('* NOT using data augmentation')
info = model.fit(X_train,y_train, batch_size = batch_size,
epochs = epochs, verbose=1,
shuffle=True, validation_data=(X_test,y_test))
......@@ -190,7 +191,9 @@ def main():
help='train/test split (default=0.1)')
p.add_argument('--model', default='convnet', type=str, metavar='<str>',
help='choose model amongst [convet,...] (default=convnet)')
p.add_argument('--augment', default=False, type=bool, metavar='<bool>',
p.add_argument('--load_model', default=None, type=str, metavar='<str>.h5',
help='load pretrained model')
p.add_argument('--augment', default=False, type=lambda s: s.lower() in ['true', 't', 'yes', '1'], metavar='<bool>',
help='use data augmentation (default=False)')
# Required arguments
......@@ -211,7 +214,12 @@ def main():
print('* Preparing and training model')
shape = celldb.images.shape[1:]
model = create_model(shape,args.model)
if args.load_model is not None:
print('** Loading pretrained model')
from keras.models import load_model
model = load_model(args.load_model)
else:
model = create_model(shape,args.model)
info = train_model(model, celldb, args)
print('* Saving results')
......
......@@ -3,17 +3,33 @@
import argparse
import jinja2 as j2
import os.path as op
from numpy import around, load, uint8
thisdir = op.dirname(__file__)
template_file = op.join(thisdir,'template.j2')
def get_info(basenames):
image_list = []
cell_counts = []
for b in basenames:
print(b)
if op.isfile(b+"_density.jpg"):
image_list.append(b+"_density.jpg")
else:
image_list.append(b+"_density.png")
x = load(b+"_density_cells.npz",mmap_mode='r')
ncells = around(x['ncells']).astype(uint8)
cell_counts.append(ncells)
def create_html_from_template(image_list):
return image_list, cell_counts
def create_html_from_template(image_list, cell_counts):
with open(template_file,'rt') as f:
template = f.read()
template = j2.Template(template)
html = template.render(images=image_list)
images_ncells = zip(image_list, cell_counts)
html = template.render(images_ncells=images_ncells)
return html
def main():
......@@ -22,15 +38,19 @@ def main():
parser.add_argument("outfile",
help="Output file name", metavar='<file>.html')
parser.add_argument('images',
type=str, nargs='+', metavar='<str>.tiff',
help='List of images')
parser.add_argument('basenames',
type=str, nargs='+', metavar='<str>',
help='List of basenames (will look for <str>_density_cells.npz and <str>_density.png/jpg')
# Parse arguments
args = parser.parse_args()
# Get list of images and ncells
image_list, cell_counts = get_info(args.basenames)
# create html from template
content = create_html_from_template(args.images)
content = create_html_from_template(image_list,cell_counts)
with open(args.outfile,"w") as f:
f.write(content)
......
......@@ -36,7 +36,7 @@ def main():
# Optional arguments
# Required arguments
p.add_argument('image',
type=str, nargs='*', metavar='<str>.tiff [<str>.tiff ...] ',
type=str, nargs='+', metavar='<str>.tiff [<str>.tiff ...] ',
help='input image or list of images')
p.add_argument('--width',action='store_true',
default=False, dest='boolean_w',
......
......@@ -68,10 +68,11 @@ body {
</div>
<!-- Photo Grid -->
<div class="row">
{% for i in images %}
<div class="column">
<img src="{{ i }}" style="width:100%">
<div class="row" style="text-align:center;border:1px solid black">
{% for img,n in images_ncells %}
<div class="column" style="text-align:center;border:1px solid black">
<p> <font size="1" color="red"> ~{{ n }} cells</font>
<img src="{{ img }}" style="width:100%">
</div>
{% endfor %}
</div>
......
Cell Counting code
Cell Counting Project
-- Install Instructions
git clone https://git.fmrib.ox.ac.uk/saad/CellCounting.git
cd CellCounting
pip install .
-- Howto?
-- ClickCells
Code for creating training data
-- select_zones.py
-- split_zones.py
-- click_cells.py
-- Models
-- sequential_cnn.py
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment