Commit a45e7c91 authored by Saad Jbabdi

initial commit

parent af175b20
%% Cell type:markdown id: tags:
# Example workflow for training and running a model
This notebook uses pre-cooked celldb databases (created with the click_cells tools in CellCounting) to train a CNN that classifies images as cell / no cell.
%% Cell type:code id: tags:
``` python
# Do all the imports
# this bit is useful while coding
%load_ext autoreload
%autoreload 2
# standard imports
import os,glob
import numpy as np
# DL imports
from keras.utils import np_utils
from keras.models import Sequential, Model
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers import Activation, Flatten, Dense, BatchNormalization
# CellCounting imports
from CellCounting.models import model_utils as modut
from CellCounting.utils import db
import CellCounting.databases
```
%%%% Output: stream
The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
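%% Cell type:markdown id: tags:
Side note: these imports match the Keras 2.x API this notebook was run with. On newer Keras/TensorFlow the same functionality lives elsewhere; a hedged sketch of the equivalents (not the environment this notebook was executed in):
%% Cell type:code id: tags:
``` python
# Equivalent imports on tf.keras / Keras >= 2.4 -- a sketch only,
# not what this notebook actually uses:
from tensorflow.keras.utils import to_categorical          # replaces np_utils.to_categorical
from tensorflow.keras.layers import Conv2D, MaxPooling2D   # Convolution2D is an alias of Conv2D
```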
%% Cell type:code id: tags:
``` python
# Load all celldb databases shipped in the CellCounting package
file_list = glob.glob(os.path.join(os.path.dirname(CellCounting.databases.__file__),'celldb*.npz'))
celldb = db.CellDB()
celldb.load_from_files(file_list)
celldb.equalise_classes()
celldb.summary()
```
%%%% Output: stream
load /Users/saad/Git/CellCounting/CellCounting/databases/celldb_003.npz into DB
load /Users/saad/Git/CellCounting/CellCounting/databases/celldb_002.npz into DB
load /Users/saad/Git/CellCounting/CellCounting/databases/celldb_001.npz into DB
-----------------------------------------------
--------- Cell Database Information -----------
-----------------------------------------------
# Images = 16834
# Images with cells = 8417
# Cell Counts:
...... class 0 : 8417 elements
...... class 1 : 6745 elements
...... class 2 : 1140 elements
...... class 3 : 339 elements
...... class 4 : 139 elements
...... class 5 : 39 elements
...... class 6 : 14 elements
...... class 8 : 1 elements
-----------------------------------------------
%% Cell type:code id: tags:
``` python
# Split database into train/test data
# Basic preproc: normalise to [0,1], then demean and divide by the std computed over the training data
X_train, y_train, X_test, y_test = celldb.split_train_test(split=.1)
X_train = X_train.astype(np.float32) / 255.0
X_test = X_test.astype(np.float32) / 255.0
img_avg = X_train.mean(axis=0)
img_std = X_train.std(axis=0)
X_train = (X_train - img_avg) / img_std
X_test = (X_test - img_avg) / img_std
# One-hot encoding
n_classes = len(np.unique(y_train))
y_train = np_utils.to_categorical(y_train, n_classes)
y_test = np_utils.to_categorical(y_test, n_classes)
```
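%% Cell type:markdown id: tags:
A later cell reloads `img_avg` and `img_std` from `.npy` files to normalise new images, so these statistics need to be persisted with the model. A minimal sketch (the output directory name is illustrative):
%% Cell type:code id: tags:
``` python
# Save the train-time normalisation statistics so inference can
# reproduce the exact preprocessing (directory name is illustrative)
out_dir = 'image_normalise'
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, 'img_avg.npy'), img_avg)
np.save(os.path.join(out_dir, 'img_std.npy'), img_std)
```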
%% Cell type:code id: tags:
``` python
# Create a simple CNN for classification into Cell / No cell
shape = X_train.shape[1:]
model = Sequential(name='SimpleConvNet')
# First block of conv+maxpool+batchnorm+nonlinearity
model.add(Convolution2D(20, (5, 5), strides=(1, 1), padding='valid', input_shape=shape, name='conv1'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=2, padding='valid', name='maxpool1'))
model.add(BatchNormalization(name='batchnorm1'))
model.add(Activation('relu',name='relu1'))
# Second block of conv+maxpool+batchnorm+nonlinearity
model.add(Convolution2D(50, (3, 3), strides=(1, 1), padding='valid', name='conv2'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=4, padding='valid', name='maxpool2'))
model.add(BatchNormalization(name='batchnorm2'))
model.add(Activation('relu',name='relu2'))
# Third block of conv+nonlinearity
model.add(Convolution2D(500, (3, 3), strides=(1, 1), padding='valid', name='conv3'))
model.add(Activation('relu',name='relu3'))
# Last bit is dense N->2 for classification
model.add(Flatten(name='flatten'))
model.add(Dense(2, name='dense'))
model.add(Activation('softmax', name='softmax'))
model.summary()
```
%%%% Output: stream
Model: "SimpleConvNet"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv1 (Conv2D) (None, 60, 60, 20) 1520
_________________________________________________________________
maxpool1 (MaxPooling2D) (None, 30, 30, 20) 0
_________________________________________________________________
batchnorm1 (BatchNormalizati (None, 30, 30, 20) 80
_________________________________________________________________
relu1 (Activation) (None, 30, 30, 20) 0
_________________________________________________________________
conv2 (Conv2D) (None, 28, 28, 50) 9050
_________________________________________________________________
maxpool2 (MaxPooling2D) (None, 7, 7, 50) 0
_________________________________________________________________
batchnorm2 (BatchNormalizati (None, 7, 7, 50) 200
_________________________________________________________________
relu2 (Activation) (None, 7, 7, 50) 0
_________________________________________________________________
conv3 (Conv2D) (None, 5, 5, 500) 225500
_________________________________________________________________
relu3 (Activation) (None, 5, 5, 500) 0
_________________________________________________________________
flatten (Flatten) (None, 12500) 0
_________________________________________________________________
dense (Dense) (None, 2) 25002
_________________________________________________________________
softmax (Activation) (None, 2) 0
=================================================================
Total params: 261,352
Trainable params: 261,212
Non-trainable params: 140
_________________________________________________________________
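%% Cell type:markdown id: tags:
As a sanity check on the summary above: a Conv2D layer has (kh·kw·c_in + 1)·c_out parameters, and BatchNormalization has 4 per channel (gamma, beta, moving mean, moving variance; only the first two are trainable, hence the 140 non-trainable parameters):
%% Cell type:code id: tags:
``` python
# Reproduce the parameter counts reported by model.summary()
def conv_params(kh, kw, c_in, c_out):
    return (kh * kw * c_in + 1) * c_out  # weights + biases

assert conv_params(5, 5, 3, 20) == 1520      # conv1 (RGB input: 3 channels)
assert conv_params(3, 3, 20, 50) == 9050     # conv2
assert conv_params(3, 3, 50, 500) == 225500  # conv3
assert 12500 * 2 + 2 == 25002                # dense: 5*5*500 flattened -> 2
assert 4 * (20 + 50) == 280                  # batchnorm params, half non-trainable
```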
%% Cell type:code id: tags:
``` python
# Train model
# (a pre-trained model can instead be loaded -- see the cell below)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
TrainingArgs = {'gpu': False, 'augment': True, 'verbose': True, 'epochs': 20}
do_train = True  # set to False to skip training and use the pre-trained model below
if do_train:
    info = modut.train_model(model=model,
                             data=[X_train, y_train, X_test, y_test],
                             **TrainingArgs)
```
%%%% Output: stream
* Running forward pass on CPU
* Using data augmentation
Epoch 1/20
474/474 [==============================] - 41s 87ms/step - loss: 0.3534 - accuracy: 0.8600 - val_loss: 0.4066 - val_accuracy: 0.8325
Epoch 2/20
474/474 [==============================] - 41s 86ms/step - loss: 0.3399 - accuracy: 0.8671 - val_loss: 0.3545 - val_accuracy: 0.8278
Epoch 3/20
474/474 [==============================] - 41s 87ms/step - loss: 0.3324 - accuracy: 0.8741 - val_loss: 0.3186 - val_accuracy: 0.8901
Epoch 4/20
474/474 [==============================] - 41s 86ms/step - loss: 0.3343 - accuracy: 0.8708 - val_loss: 0.2656 - val_accuracy: 0.8979
Epoch 5/20
474/474 [==============================] - 43s 90ms/step - loss: 0.3302 - accuracy: 0.8730 - val_loss: 0.2649 - val_accuracy: 0.9056
Epoch 6/20
474/474 [==============================] - 44s 93ms/step - loss: 0.3283 - accuracy: 0.8725 - val_loss: 0.2676 - val_accuracy: 0.8973
Epoch 7/20
474/474 [==============================] - 51s 108ms/step - loss: 0.3122 - accuracy: 0.8804 - val_loss: 0.3086 - val_accuracy: 0.8741
Epoch 8/20
474/474 [==============================] - 54s 114ms/step - loss: 0.3154 - accuracy: 0.8796 - val_loss: 0.2695 - val_accuracy: 0.8985
Epoch 9/20
474/474 [==============================] - 60s 127ms/step - loss: 0.3168 - accuracy: 0.8779 - val_loss: 0.2730 - val_accuracy: 0.9097
Epoch 10/20
474/474 [==============================] - 57s 121ms/step - loss: 0.3003 - accuracy: 0.8857 - val_loss: 0.2481 - val_accuracy: 0.9157
Epoch 11/20
474/474 [==============================] - 58s 122ms/step - loss: 0.3088 - accuracy: 0.8814 - val_loss: 0.2739 - val_accuracy: 0.9186
Epoch 12/20
474/474 [==============================] - 59s 125ms/step - loss: 0.3012 - accuracy: 0.8840 - val_loss: 0.2390 - val_accuracy: 0.9109
Epoch 13/20
474/474 [==============================] - 59s 124ms/step - loss: 0.3017 - accuracy: 0.8849 - val_loss: 0.2655 - val_accuracy: 0.9068
Epoch 14/20
474/474 [==============================] - 59s 125ms/step - loss: 0.3050 - accuracy: 0.8845 - val_loss: 0.4332 - val_accuracy: 0.8432
Epoch 15/20
474/474 [==============================] - 59s 124ms/step - loss: 0.3064 - accuracy: 0.8826 - val_loss: 0.3057 - val_accuracy: 0.8913
Epoch 16/20
474/474 [==============================] - 62s 131ms/step - loss: 0.2981 - accuracy: 0.8859 - val_loss: 0.2428 - val_accuracy: 0.9246
Epoch 17/20
474/474 [==============================] - 65s 137ms/step - loss: 0.2966 - accuracy: 0.8891 - val_loss: 0.2339 - val_accuracy: 0.9210
Epoch 18/20
474/474 [==============================] - 62s 131ms/step - loss: 0.2913 - accuracy: 0.8892 - val_loss: 0.2674 - val_accuracy: 0.8913
Epoch 19/20
474/474 [==============================] - 68s 144ms/step - loss: 0.2939 - accuracy: 0.8916 - val_loss: 0.2333 - val_accuracy: 0.9121
Epoch 20/20
474/474 [==============================] - 64s 135ms/step - loss: 0.2960 - accuracy: 0.8853 - val_loss: 0.2616 - val_accuracy: 0.9074
Model took 1089.87 seconds to train
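%% Cell type:markdown id: tags:
Later cells reload the network from `model.h5` and the training curves from `model_hist.csv`. Whether `modut.train_model` writes these itself isn't shown here, so here is a hedged sketch of saving them manually (filenames illustrative, assuming `info` is the Keras `History` object):
%% Cell type:code id: tags:
``` python
import pandas as pd
model.save('model.h5')  # reloadable with keras.models.load_model
# assuming train_model returns the Keras History object:
pd.DataFrame(info.history).to_csv('model_hist.csv', index=False)
```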
%% Cell type:code id: tags:
``` python
# Load pre-trained network and look at performance measures
from keras.models import load_model
filename = '/Users/saad/Desktop/test_090220/model.h5'
model = load_model(filename)
model.summary()
avg = np.load('/Users/saad/Desktop/test_090220/image_normalise/img_avg.npy')
std = np.load('/Users/saad/Desktop/test_090220/image_normalise/img_std.npy')
```
%% Cell type:code id: tags:
``` python
# Check which Keras version is in use
import keras
keras.__version__
```
%% Cell type:code id: tags:
``` python
# Reload the celldb databases from explicit file paths
celldb = db.CellDB()
file_list = ['/Users/saad/Git/CellCounting/celldb_001.npz','/Users/saad/Git/CellCounting/celldb_002.npz','/Users/saad/Git/CellCounting/celldb_003.npz']
celldb.load_from_files(file_list)
```
%% Cell type:code id: tags:
``` python
# Predict class probabilities for all images, applying the train-time normalisation
pred = model.predict((celldb.images/255.0 - avg)/std)
```
%% Cell type:code id: tags:
``` python
# TODO: plot a random grid of images with predictions, circling in red those where
# prediction and target disagree. For now, plot predicted P(cell) vs true cell count:
import matplotlib.pyplot as plt
import numpy as np
plt.plot(celldb.cell_counts,pred[:,1],'.')
plt.show()
```
%% Cell type:code id: tags:
``` python
print(np.sum(pred[:, 1] > .95))         # images confidently predicted to contain cells
print(np.sum(celldb.cell_counts >= 1))  # images that actually contain cells
```
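%% Cell type:markdown id: tags:
The two counts above compare confident "cell" predictions against images that truly contain at least one cell. A fuller breakdown (a sketch, thresholding P(cell) at 0.5):
%% Cell type:code id: tags:
``` python
# Confusion counts for the cell / no-cell decision
y_true = celldb.cell_counts >= 1
y_hat = pred[:, 1] > 0.5
tp = np.sum(y_hat & y_true); fp = np.sum(y_hat & ~y_true)
fn = np.sum(~y_hat & y_true); tn = np.sum(~y_hat & ~y_true)
print(f'TP={tp} FP={fp} FN={fn} TN={tn}  accuracy={(tp + tn) / y_true.size:.3f}')
```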
%% Cell type:code id: tags:
``` python
# Build a sub-model exposing the relu3 feature maps,
# then inspect the activations for one random image
from keras.models import Model
grot = Model(model.input, model.get_layer(name='relu3').output)
i = np.random.randint(0,celldb.images.shape[0])
im = np.squeeze(celldb.images[i,:,:,:])
imn = (im[None,:,:,:]/255-avg)/std
x = grot.predict(imn)
plt.plot(x.flatten())
```
%% Cell type:code id: tags:
``` python
# Load the saved training history and plot the train/validation loss curves
import pandas as pd
filename = '/Users/saad/Desktop/test_090220/model_hist.csv'
df = pd.read_csv(filename)
plt.plot(df['loss'])
plt.plot(df['val_loss'])
```
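%% Cell type:markdown id: tags:
The same curves, labelled (column names as read above):
%% Cell type:code id: tags:
``` python
plt.plot(df['loss'], label='train loss')
plt.plot(df['val_loss'], label='validation loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()
```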
%% Cell type:code id: tags:
``` python
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0,10,100)
y = np.sin(3*x) + np.random.randn(x.size)*.5
sig = 2
# Gaussian radial basis function with width sig
rbf = lambda x,c : np.exp(-(x-c)**2/sig**2)
xi = np.linspace(0,10,15)
# Design matrix: one column per basis function, evaluated at the sample points
desmat = np.asarray([rbf(x, c) for c in xi]).T
# Least-squares fit via the pseudo-inverse
beta = np.linalg.pinv(desmat) @ y
plt.plot(x,y,'.')
plt.plot(x,desmat,'k')
plt.plot(x,desmat@beta)
plt.grid()
plt.xlabel('x')
plt.ylabel('y')
plt.title('RBF fitting')
plt.savefig('/Users/saad/Desktop/RBF.pdf')
```
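%% Cell type:markdown id: tags:
The `pinv` solve above is the ordinary least-squares solution β = argmin‖Φβ − y‖²; a quick consistency check against `np.linalg.lstsq`:
%% Cell type:code id: tags:
``` python
# pinv(X) @ y is the minimum-norm least-squares solution;
# it should agree with lstsq for this full-rank design matrix
beta_ls, *_ = np.linalg.lstsq(desmat, y, rcond=None)
print(np.allclose(beta, beta_ls))
```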