Commit 6b18a3e4 authored by Saad Jbabdi's avatar Saad Jbabdi
Browse files

Debugging CellDB

parent 7c535fcb
......@@ -3,50 +3,54 @@
import numpy as np
import os
import sys
import os, sys, types
class CellDB(object):
def __init__(self,images=[], counts=[]):
def __init__(self,images=None,counts=None):
self.images = images
self.cell_counts = counts
if(counts==[]):
self.cell_yn = []
else:
self.cell_yn = self.cell_counts>0
def clear(self):
self.images = []
self.cell_counts = []
self.cell_yn = []
return
def __add__(self,other):
new_images = np.concatenate((self.images,other.images))
new_cell_counts = np.concatenate((self.cell_counts,other.cell_counts))
return CellDB(new_images,new_cell_counts)
# Merge list of celldb objects
def merge(self,dblist):
shape = dblist[0].images.shape[1:]
images = []
counts = []
for x in dblist:
images.append(x.images)
counts.append(x.cell_counts)
if(self.images == []):
self.images = np.array(images).reshape(-1,*shape)
self.cell_counts = np.array(counts)[:]
else:
self.images = np.concatenate(self.images, np.array(images).reshape(-1,*shape))
self.cell_counts = np.concatenate(self.cell_counts, np.array(counts)[:])
def __iadd__(self,other):
self.images = np.concatenate((self.images,other.images))
self.cell_counts = np.concatenate((self.cell_counts,other.cell_counts))
self.cell_yn = self.cell_counts>0
return self
self.summary()
# Save to numpy array
def save(self,outfile,compressed=True):
if(compressed==True):
np.savez(outfile,images=self.images,counts=self.cell_counts)
else:
np.save(outfile,images=self.images,counts=self.cell_counts)
return
# Load from numpy
def load_from_files(self,file_list):
dblist = []
for k,f in enumerate(file_list):
print('load {} into DB'.format(f))
arr = np.load(f)
if(k==0):
celldb = CellDB(arr['images'], arr['counts'])
else:
celldb += CellDB(arr['images'], arr['counts'])
self.images = celldb.images
self.cell_counts = celldb.cell_counts
return
# Ensure all classes are equally represented
# This is not reversible!
def equalise_classes(self, binarise_counts=True):
if(binarise_counts==True):
labels = self.cell_yn
labels = self.cell_counts>0
else:
labels = self.cell_counts
......@@ -66,14 +70,13 @@ class CellDB(object):
self.images = self.images[idx,...]
self.cell_counts = self.cell_counts[idx]
self.cell_yn = self.cell_counts>0
return
# Obtain training/test data split
def split_train_test(self,split=0.1,binarise_counts=True,shuffle=True):
if(binarise_counts==True):
labels = self.cell_yn
labels = self.cell_counts>0
else:
labels = self.cell_counts
......@@ -84,57 +87,39 @@ class CellDB(object):
for cl in classes:
cIdxs = np.where(labels==cl)[0]
if(shuffle==True):
cIdxs = np.random.shuffle(cIdxs)
np.random.shuffle(cIdxs)
n = int((1.0-split)*len(cIdxs))
idx_train.extend(cIdxs[:n])
idx_test.extend(cIdxs[n:])
X_train = self.imaegs[idx_train,...]
X_train = self.images[idx_train,...]
y_train = labels[idx_train]
X_test = self.images[idx_test,...]
y_test = labels[idx_test]
X_test = self.images[idx_test,...]
y_test = labels[idx_test]
return X_train,y_train,X_test,y_test
# Save to numpy array
def save(self,outfile,compressed=True):
if(compressed==True):
np.savez(outfile,images=self.images,counts=self.cell_counts)
else:
np.save(outfile,images=self.images,counts=self.cell_counts)
return
# Load from numpy
def load(self,infiles,overwrite=True):
dblist = []
if(overwrite==True):
self.clear()
for f in infiles:
print('load {} into DB'.format(f))
celldb_np = np.load(f)
celldb = CellDB(celldb_np['images'], celldb_np['counts'])
print('Append')
dblist.append(celldb)
self.merge(dblist)
return
# Summary of the data
def summary(self):
print('-----------------------------------------------')
print('--------- Cell Database Information -----------')
print('-----------------------------------------------')
print('# Images = {}'.format(len(self.images)))
print('# Images with cells = {}'.format(self.cell_yn.sum()))
# Give total counts per class
print('# Cell Counts:')
classes, counts = np.unique(self.cell_counts,return_counts=True)
for k,cl in enumerate(classes):
print('...... class {} : {} elements'.format(cl,counts[k]))
if(self.images is not None):
print('# Images = {}'.format(len(self.images)))
print('# Images with cells = {}'.format((self.cell_counts>0).sum()))
# Give total counts per class
print('# Cell Counts:')
classes, counts = np.unique(self.cell_counts,return_counts=True)
for k,cl in enumerate(classes):
print('...... class {} : {} elements'.format(cl,counts[k]))
else:
print(' ')
print(' ')
print(' EMPTY DATABASE ')
print(' ')
print(' ')
print('-----------------------------------------------')
return
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment