Commit b445532d authored by Saad Jbabdi's avatar Saad Jbabdi
Browse files

updates and bug fixes

parent b24fa1c2
No preview for this file type
......@@ -3,246 +3,203 @@
# Modified: Saad Jbabdi 09/2018
import os
import sys
import time
# ------------------------------ IMPORT MODULES ------------------------------ #
# General
import matplotlib as mpl
mpl.use('PDF')
import matplotlib.pyplot as plt
import os, sys, time
import numpy as np
from skimage import io as skio
from keras.models import load_model
import collections
import itertools as it
import argparse
# Image stuff
from skimage import io as skio
from sklearn.feature_extraction import image
# DL stuff
from keras.models import load_model
from sklearn.utils import check_array
from sklearn.feature_extraction.image import extract_patches
from sklearn.feature_extraction.image import _compute_n_patches
import argparse
# ------------------------------ DATA ------------------------------ #
def preprocess(imnp, model, args):
# get width and height from model
(w, h, ncols) = model.layers[0].input_shape[1:]
# deal with image edges
imnp=np.pad(imnp,((w,w),(h,h),(0,0)),mode='symmetric')
# Patchify
stride = args.stride
if stride is None:
stride = h
patches = image.extract_patches(imnp, patch_shape = (h, w, ncols), extraction_step = stride)
patches = patches.reshape(-1,h,w,ncols)
# Normalise intensities
patches = patches.astype(np.float32) #/ 255.
img_avg = np.load(os.path.join(args.model_folder,'image_normalise','img_avg.npy'))
img_std = np.load(os.path.join(args.model_folder,'image_normalise','img_std.npy'))
patches = (patches - img_avg)/img_std
def ffm(inimage, inmodel, basename=None, stride=None, gpu=True, timer=True, normdir=None):
print('** {} patches produced'.format(patches.shape[0]))
return patches
def postprocess(patches, model_out, imshape, args):
'''
Fast Forward Model, Oiwi Parker Jones, 2018
density, cells = postprocess(patches, model, args)
'''
if timer == True: total_start = time.time(); load_start = time.time()
print('* Reconstructing output image')
p_h, p_w = patches.shape[1:3]
i_h, i_w = imshape[:2]
print('* Loading data')
imnp = skio.imread(inimage)
density = np.zeros((i_h, i_w), dtype=np.float32)
cells = np.zeros((i_h, i_w), dtype=np.int8)
count = np.zeros((i_h, i_w), dtype=np.uint16)
density = np.pad(density,((p_h,p_h),(p_w,p_w)),mode='symmetric')
cells = np.pad(cells,((p_h,p_h),(p_w,p_w)),mode='symmetric')
count = np.pad(count,((p_h,p_h),(p_w,p_w)),mode='symmetric')
# get padded shape
i_h, i_w = density.shape
if timer == True: load_end = time.time()
# compute the patch indices along each dimension
stride = args.stride
if stride is None:
stride = p_h
h_indices = np.arange(0, i_h - p_h + 1, stride)
w_indices = np.arange(0, i_w - p_w + 1, stride)
# Load model
model = load_model(inmodel)
thr = args.prob_ratio
for k, (p, (i, j)) in enumerate(zip(patches, it.product(h_indices, w_indices))):
if(model_out[k,1]/model_out[k,0]>thr):
density[ i:i + p_h, j:j + p_w] += 1
cells[ i:i + p_h, j:j + p_w] = 1
count[i:i + p_h, j:j + p_w] += 1
# get kernel width and height from model
(k_w, k_h) = model.layers[0].input_shape[1:3]
density = density / (count+(count==0)).astype(np.float32)
# remove padding
density = density[p_h:-p_h,p_w:-p_w]
cells = cells[p_h:-p_h,p_w:-p_w]
# deal with image edges
imnp=np.pad(imnp,((k_w,k_w),(k_h,k_h),(0,0)),mode='symmetric')
if stride==None:
print('** Using default stride = {} (i.e. width of kernel)'.format(k_w))
stride = k_w
if gpu==True:
return density, cells
# ------------------------------ MODEL ------------------------------ #
def forward_prediction(patches, model, args):
if args.gpu:
from keras import backend as K
K.tensorflow_backend._get_available_gpus()
print('* Running forward pass on GPU (CUDA_VISIBLE_DEVICES)')
else:
os.environ["CUDA_VISIBLE_DEVICES"] = ""
print('* Running forward pass on CPU')
# Hopefully this should run on the GPU if available
print(patches.shape)
# Split imnp into set of sub-image patches
print('* Splitting image into {} x {} patches (stride={})'.format(k_w, k_h, stride))
patches = extract_patches_2d_strides(imnp, (k_w, k_h), extraction_step=stride)
print('** {} patches produced'.format(patches.shape[0]))
model_out = model.predict(patches)
# get img_avg and img_std to normalise images (from training set)
if normdir==None:
print("!! normdir==None")
print('** Using model without normalisation!!!')
img_avg = np.zeros((k_w,k_h,3),dtype='float32')
img_std = np.ones((k_w,k_h,3),dtype='float32')
else:
print('** Loading img_avg.npy and img_std.npy for normalisation from: {}'.format(normdir))
img_avg = np.load(os.path.join(normdir,'img_avg.npy'))
img_std = np.load(os.path.join(normdir,'img_std.npy'))
return model_out
if timer == True: forward_start = time.time()
# Do forward pass on each normalised image patch
patches = patches.astype(np.float32) #convert type from uint8 (256 bits) to 32 bit float before normalising
patches = (patches - img_avg) / img_std
patch_out = model.predict(patches)
#patch_out_argmax = patch_out.argmax(axis=1) #get predicted class; 1 = 'cell' (so predictions can be averaged)
# Reconstruct predicted output image, for saving
print('* Reconstructing output image')
recon, _ = reconstruct_from_patches_2d_strides(patches, patch_out, imnp.shape, extraction_step=stride)
def save_results(density, cells, image, args):
# remove padding
recon = recon[k_w:-k_w,k_h:-k_h]
basename = args.out
if basename is None:
basename, _ = os.path.splitext(args.infile)
#save pred image recon
if basename==None:
basename, _ = os.path.splitext(inimage)
outfile = basename+'_stride'+str(stride)+'_density_highres.npz'
outfile = basename+'_density_cells.npz'
#print('* Saving predictions: {}'.format(outfile))
#np.savez_compressed(outfile, recon=recon)
print('* Saving predictions: {}'.format(outfile))
np.savez_compressed(outfile, density=density, cells=cells)
#visualise results
outfile=basename+'_density.png'
print('* Saving snapshot: {}'.format(outfile))
import matplotlib.pyplot as plt
plt.imshow(imnp[k_w:-k_w,k_h:-k_h,:])
plt.imshow(recon,alpha=.2,interpolation='bilinear')
outfile=basename+'_density.png'
print('* Saving snapshot: {}'.format(outfile))
plt.imshow(image)
plt.imshow(density,alpha=.3,interpolation='bilinear')
plt.savefig(outfile, dpi=1000)
if timer == True:
forward_end = time.time()
total_end = time.time()
print('** Time to load data = %.2f seconds' % (load_end - load_start))
print('** Time to run model = %.2f seconds' % (forward_end - forward_start))
print('** Total run time = %.2f seconds' % (total_end - total_start))
def reconstruct_from_patches_2d_strides(patches, patch_out, image_size, extraction_step=1):
"""
Wrapper around sklearn's reconstruct_from_patches_2d, but with added strides.
"""
p_h, p_w = patches.shape[1:3]
i_h, i_w = image_size[:2]
s_h, s_w = [extraction_step,extraction_step]
img = np.zeros((i_h, i_w), dtype=np.float32)
# progressively back off data type for count, depending on amount of kernel overlap
# data types that use less memory trade off with ability to have smaller strides
count = np.zeros((i_h, i_w), dtype=np.uint16)
# compute the patch indices along each dimension
h_indices = np.arange(0, i_h - p_h + 1, s_h)
w_indices = np.arange(0, i_w - p_w + 1, s_w)
thr=2
for k, (p, (i, j)) in enumerate(zip(patches, it.product(h_indices, w_indices))):
if(patch_out[k,1]/patch_out[k,0]>thr):
img[ i:i + p_h, j:j + p_w] += 1
count[i:i + p_h, j:j + p_w] += 1
img = img / (count+(count==0)).astype(np.float32)
return img, count
def squarebox(h,w,rgb):
b = np.zeros((h,w,3))
b[:,0,:] = rgb
b[:,-1,:] = rgb
b[0,:,:] = rgb
b[-1,:,:] = rgb
return b
# display a nxn grid of images
def cell_plot(X,y,n=10):
import matplotlib.pyplot as plt
im_size = X.shape[1]
pad_size = 5
figure = np.zeros((pad_size + (im_size+pad_size) * n,
pad_size + (im_size+pad_size) * n,
3),
dtype='uint8')
cnt = 0
ii = pad_size
for i in range(n):
jj = pad_size
for j in range(n):
if(cnt<X.shape[0]):
figure[ii:ii+im_size,jj:jj+im_size, :] = X[cnt,...]
if(y[cnt,1]>=y[cnt,0]):
rgb = [0,255,0]
else:
rgb = [255,0,0]
figure[ii, jj:jj+im_size,:] = rgb
figure[ii+im_size-1, jj:jj+im_size,:] = rgb
figure[ii:ii+im_size,jj, :] = rgb
figure[ii:ii+im_size,jj+im_size-1, :] = rgb
jj += (im_size+pad_size)
cnt += 1
ii += (im_size+pad_size)
plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()
outfile=basename+'_cells.png'
print('* Saving snapshot: {}'.format(outfile))
plt.imshow(image)
plt.imshow(cells,alpha=.1,interpolation='bilinear')
plt.savefig(outfile, dpi=1000)
#add extraction_step argument for strides
def extract_patches_2d_strides(image, patch_size, extraction_step=None):
'''
Wrapper around extract_patches based on sklearn's extract_patches_2d but with added strides.
Oiwi, 2018
'''
if extraction_step is None:
extraction_step = patch_size #default to nonoverlapping
i_h, i_w = image.shape[:2]
p_h, p_w = patch_size
if p_h > i_h:
raise ValueError("Height of the patch should be less than the height"
" of the image.")
if p_w > i_w:
raise ValueError("Width of the patch should be less than the width"
" of the image.")
image = check_array(image, allow_nd=True)
image = image.reshape((i_h, i_w, -1))
n_colors = image.shape[-1]
return
patches = extract_patches(image,
patch_shape=(p_h, p_w, n_colors),
extraction_step=extraction_step)
# ------------------------------ MAIN ------------------------------ #
patches = patches.reshape(-1, p_h, p_w, n_colors)
# remove the color dimension if useless
if patches.shape[-1] == 1:
return patches.reshape((n_patches, p_h, p_w))
else:
return patches
# TODO:
# output lower res version
# subdivide into patches when the stride is small and do the patches sequentially
def main():
p = argparse.ArgumentParser(description='[ffm] Fast forward model, predicts cell locations in tif image, Oiwi 2018')
p = argparse.ArgumentParser(description='Cell density mapping')
p.add_argument('-o', '--out', default=None, type=str, metavar='<str>',
help='output basename (default derived from infile)')
p.add_argument('-s', '--stride',default=None, type=int, metavar='<int>',
help='stride (default: width of model input filter)')
p.add_argument('-t', '--timer',default=True, type=bool, metavar='<bool>',
help='time how fast the forward model is')
p.add_argument('-n', '--normdir', default=None, type=str, metavar='<dir>',
help='path to directory that the normalising images live in, img_avg.npy and img_std.npy')
help='stride (default: width of model input images)')
p.add_argument('--gpu', default=False, type=bool, metavar='<bool>',
help='use GPU if True (default), use CPU if False')
p.add_argument('--downsample',default=1, type=int, metavar='<int>',
help='downsampling factor (default=1. Set to x to get an output that is 1/x times the input along each dimension)')
p.add_argument('--prob_ratio',default=1.0, type=float, metavar='<float>',
help='probability ratio. cells detected if p(cell)/p(no cell) > prob_ratio.')
required = p.add_argument_group('Required arguments')
required.add_argument('-m', '--inmodel', required=True, type=str, metavar='<str>.h5',
help='model (e.g. convolutional neural network)')
required.add_argument('-m', '--model_folder', required=True, type=str, metavar='<str>',
help='model folder (must contain model.h5 file and image_normalise sub-folder)')
required.add_argument('infile', type=str, metavar='<str>.tif',
help='input image file')
args = p.parse_args()
ffm(args.infile, basename=args.out, inmodel=args.inmodel, stride=args.stride,
gpu=args.gpu, timer=args.timer, normdir=args.normdir)
start_time = time.time();
print('* Load model')
model = load_model(os.path.join(args.model_folder,'model.h5'))
model.summary()
print('* Load Input Image')
imnp = skio.imread(args.infile)
print('* Preprocess')
print('Image size before preproc {}'.format(imnp.shape))
patches = preprocess(imnp, model, args)
print('Image size after preproc {}'.format(imnp.shape))
print(patches.shape)
print('* Forward prediction*')
if(args.downsample != 1):
print(' --> Downsampling has not been implemented just yet.')
# Here: depending on overall size of the image
# Split it into bits, run forward stuff on each bit
# Then downsample each bit and stick it all together
else:
model_out = forward_prediction(patches, model, args)
print('* Post-process')
density, cells = postprocess(patches, model_out, imnp.shape, args)
end_time = time.time()
print('** Total run time = {:.2f} seconds'.format(end_time - start_time))
print('* Save results')
save_results(density, cells, imnp, args)
if __name__ == '__main__':
main()
......
......@@ -28,8 +28,8 @@ def prepare_data(celldb, args):
X_train, y_train, X_test, y_test = celldb.split_train_test(split=args.split)
# Normalise images
#X_train = X_train.astype(np.float32) / 255.0
#X_test = X_test.astype(np.float32) / 255.0
X_train = X_train.astype(np.float32) / 255.0
X_test = X_test.astype(np.float32) / 255.0
img_avg = X_train.mean(axis=0)
img_std = X_train.std(axis=0)
......@@ -123,7 +123,7 @@ def train_model(model, celldb, args):
batch_size = args.batch_size
if(DataAugment):
print('* Using data augmentations')
print('* Using data augmentation')
datagen = ImageDataGenerator(
rotation_range=90, # randomly rotate images in the range (degrees, 0 to 180)
width_shift_range=.25, # randomly shift images horizontally (fraction of total width)
......@@ -195,7 +195,7 @@ def main():
# Required arguments
required = p.add_argument_group('Required arguments')
required.add_argument('-d', '--data', required=True, type=str, nargs='*', metavar='<str>npz <str>.npz ...',
required.add_argument('-d', '--data', required=True, type=str, nargs='+', metavar='<str>npz',
help='input databases')
required.add_argument('-o', '--out', required=True, type=str, metavar='<str>',
help='output basename')
......@@ -215,7 +215,7 @@ def main():
info = train_model(model, celldb, args)
print('* Saving results')
save_results(info, args)
save_results(model,info, args)
print('Done')
......
#!/usr/bin/env python
import argparse
import jinja2 as j2
import os.path as op
thisdir = op.dirname(__file__)
template_file = op.join(thisdir,'template.j2')
def create_html_from_template(image_list):
with open(template_file,'rt') as f:
template = f.read()
template = j2.Template(template)
html = template.render(images=image_list)
return html
def main():
parser = argparse.ArgumentParser(description='Generate html report of clicked zones')
parser.add_argument("outfile",
help="Output file name", metavar='<file>.html')
parser.add_argument('images',
type=str, nargs='+', metavar='<str>.tiff',
help='List of images')
# Parse arguments
args = parser.parse_args()
# create html from template
content = create_html_from_template(args.images)
with open(args.outfile,"w") as f:
f.write(content)
if __name__ == '__main__':
main()
#!/usr/bin/env python
# Saad Jbabdi 12/2018
# ------------------------------ IMPORT MODULES ------------------------------ #
# General
import numpy as np
import argparse
# Image stuff
from skimage import io as skio
def save_results(density, cells, image, args):
basename = args.out
if basename is None:
basename, _ = os.path.splitext(args.infile)
outfile = basename+'_density_cells.npz'
#visualise results
outfile=basename+'_density.png'
print('* Saving snapshot: {}'.format(outfile))
import matplotlib.pyplot as plt
plt.imshow(image)
plt.imshow(density,alpha=.5,interpolation='bilinear')
plt.savefig(outfile, dpi=1000)
#outfile=basename+'_cells.png'
#print('* Saving snapshot: {}'.format(outfile))
#import matplotlib.pyplot as plt
#plt.imshow(image)
#plt.imshow(cells,alpha=.1,interpolation='bilinear')
#plt.savefig(outfile, dpi=1000)
return
# ------------------------------ MAIN ------------------------------ #
def main():
p = argparse.ArgumentParser(description='Cell density mapping')
p.add_argument('-o', '--out', default=None, type=str, metavar='<str>',
help='output basename (default derived from infile)')
required = p.add_argument_group('Required arguments')
required.add_argument('--density', required=True, type=str, metavar='<str>.npz',
help='input density file in npz format')
required.add_argument('--image', required=True, type=str, metavar='<str>.tiff',
help='input image file in tiff format')
args = p.parse_args()
print('* Load Input Image')
imnp = skio.imread(args.image)
print('* Load Density File')
d=np.load(args.density)
print('* Save into images')
#save_results(d['density'], d['cells'], imnp, args)
#save_results(d['recon'], d['recon'], imnp, args)
save_results(d['density'], d['density'], imnp, args)
if __name__ == '__main__':
main()
<!DOCTYPE html>
<html>
<style>
* {
box-sizing: border-box;
}
body {
margin: 0;
font-family: Arial, Helvetica, sans-serif;
}
.header {
text-align: center;
padding: 32px;
}
.row {
display: -ms-flexbox; /* IE 10 */
display: flex;
-ms-flex-wrap: wrap; /* IE 10 */
flex-wrap: wrap;
padding: 0 4px;
}
/* Create two equal columns that sits next to each other */
.column {
-ms-flex: 50%; /* IE 10 */
flex: 50%;
padding: 0 4px;
}
.column img {
margin-top: 8px;
vertical-align: middle;
}
/* Style the buttons */
.btn {
border: none;
outline: none;