Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Saad Jbabdi
CellCounting
Commits
da66c46a
Commit
da66c46a
authored
Sep 15, 2018
by
Saad Jbabdi
Browse files
Conform to python standards
parent
f927924d
Changes
42
Hide whitespace changes
Inline
Side-by-side
C
lickCells
/__init__.py
→
C
ellCounting
/__init__.py
View file @
da66c46a
File moved
CellCounting/click_cells/#create_db.py#
0 → 100755
View file @
da66c46a
#!/usr/bin/env python3
import
argparse
import
numpy
as
np
import
pandas
as
pd
import
re
from
CellCounting.Utils.db
import
DataBase
DB_IMAGE_RES
=
64
def
check_imshape
(
shape
):
sx
,
sy
,
_
=
shape
if
sx
!=
sy
:
return
False
if
(
sx
%
DB_IMAGE_RES
!=
0
)
or
(
sy
%
DB_IMAGE_RES
!=
0
):
return
False
return
True
def
append_file_content
(
fname
,
image_list
,
count_list
):
df
=
pd
.
read_table
(
fname
)
udf
=
df
.
groupby
(
'Sub-Image-File'
).
count
()
for
f
in
udf
.
index
:
# Load Numpy array
im
=
np
.
load
(
f
.
strip
())
if
not
check_imshape
(
im
.
shape
):
print
(
"Error: Bad Image dimensions. Must be square and multiple of {}"
.
format
(
DB_IMAGE_RES
))
sizx
,
sizy
,
_
=
im
.
shape
# Split into sub-zones
size_ratio
=
sizx
//
DB_IMAGE_RES
im2
=
im
.
reshape
(
size_ratio
,
sizx
//
size_ratio
,
size_ratio
,
sizy
//
size_ratio
,
3
)
im3
=
im2
.
transpose
(
0
,
2
,
1
,
3
,
4
).
reshape
(
size_ratio
**
2
,
sizx
//
size_ratio
,
sizy
//
size_ratio
,
3
)
image_list
.
append
(
im3
)
res
=
re
.
findall
(
"w_(\d+).(\d+)_h_(\d+).(\d+)"
,
f
)[
0
]
W
=
round
(
float
(
res
[
0
]
+
"."
+
res
[
1
]))
H
=
round
(
float
(
res
[
2
]
+
"."
+
res
[
3
]))
count_cells
=
np
.
zeros
((
size_ratio
,
size_ratio
),
dtype
=
int
)
for
indiv_cells
in
df
.
values
[
df
[
'Sub-Image-File'
]
==
f
]:
if
np
.
isnan
(
indiv_cells
[
1
]):
pass
else
:
w
=
float
(
indiv_cells
[
1
])
-
W
h
=
float
(
indiv_cells
[
2
])
-
H
count_cells
[
int
(
w
//
(
sizx
/
size_ratio
)),
int
(
h
//
(
sizy
/
size_ratio
))]
+=
1
count_list
.
append
(
count_cells
.
flatten
())
return
len
(
udf
.
index
)
def
create_db
(
file_list
,
outfile
):
image_list
=
[]
count_list
=
[]
total
=
0
for
f
in
file_list
:
total
+=
append_file_content
(
f
,
image_list
,
count_list
)
shape
=
image_list
[
0
].
shape
[
1
:]
count_list
=
np
.
array
(
count_list
).
flatten
()
np
.
savez
(
outfile
,
counts
=
count_list
,
images
=
np
.
array
(
image_list
).
reshape
(
-
1
,
*
shape
))
print
(
"Created DB with {} Images from a list of {} with {} containing cells."
.
format
(
len
(
count_list
),
total
,(
count_list
>
0
).
sum
()))
def
main
():
# Parse command line arguments
parser
=
argparse
.
ArgumentParser
(
"Create DB from clicked textfiles"
)
parser
.
add_argument
(
"outfile"
,
help
=
"Output file name"
)
parser
.
add_argument
(
"file"
,
help
=
"Clicky text file"
,
nargs
=
'+'
)
args
=
parser
.
parse_args
()
create_db
(
args
.
file
,
args
.
outfile
)
if
__name__
==
'__main__'
:
main
()
CellCounting/click_cells/.#create_db.py
0 → 120000
View file @
da66c46a
saad
@
jalapeno00
.
fmrib
.
ox
.
ac
.
uk
.
24091
:
1532960442
\ No newline at end of file
Click
C
ells/.DS_Store
→
C
ellCounting/c
lick
_c
ells/.DS_Store
View file @
da66c46a
File moved
Click
C
ells/README
→
C
ellCounting/c
lick
_c
ells/README
View file @
da66c46a
File moved
__init__.py
→
CellCounting/click_cells/
__init__.py
View file @
da66c46a
File moved
Click
C
ells/click_cells.py
→
C
ellCounting/c
lick
_c
ells/click_cells.py
View file @
da66c46a
...
...
@@ -66,6 +66,8 @@ def main():
help
=
"Output file name."
)
parser
.
add_argument
(
"--shuffle"
,
action
=
'store_true'
,
default
=
False
,
dest
=
'shuffle'
,
help
=
"Load sub-images in random order."
)
parser
.
add_argument
(
"--append"
,
action
=
'store_true'
,
default
=
False
,
dest
=
'append'
,
help
=
"Append results to output file."
)
parser
.
add_argument
(
"--empty_zone"
,
action
=
'store_true'
,
default
=
False
,
dest
=
"empty_zone"
,
help
=
"Entire zone is empty"
)
args
=
parser
.
parse_args
()
...
...
@@ -79,16 +81,19 @@ def main():
random
.
shuffle
(
files
)
create_header
=
True
if
op
.
exists
(
outfile
):
print
(
"File {} exists. Overwrite/Append/Exit?[O,A,E]"
.
format
(
outfile
))
response
=
input
()
if
response
.
upper
()
==
"O"
:
os
.
remove
(
outfile
)
elif
response
.
upper
()
==
"E"
:
print
(
"Exiting without doing anything"
)
exit
()
elif
response
.
upper
()
==
"A"
:
if
op
.
exists
(
outfile
):
if
args
.
append
==
True
:
create_header
=
False
else
:
print
(
"File {} exists. Overwrite/Append/Exit?[O,A,E]"
.
format
(
outfile
))
response
=
input
()
if
response
.
upper
()
==
"O"
:
os
.
remove
(
outfile
)
elif
response
.
upper
()
==
"E"
:
print
(
"Exiting without doing anything"
)
exit
()
elif
response
.
upper
()
==
"A"
:
create_header
=
False
if
create_header
==
True
:
with
open
(
outfile
,
'w'
)
as
f
:
...
...
Click
C
ells/create_db.py
→
C
ellCounting/c
lick
_c
ells/create_db.py
View file @
da66c46a
...
...
@@ -66,7 +66,7 @@ def create_db(file_list,outfile):
print
(
"Created DB with {} Images from a list of {}."
.
format
(
len
(
count_list
),
total
))
print
(
"Created DB with {} Images from a list of {}
with {} containing cells
."
.
format
(
len
(
count_list
),
total
,(
count_list
>
0
).
sum
()
))
def
main
():
...
...
@@ -76,13 +76,13 @@ def main():
)
parser
.
add_argument
(
"outfile"
,
help
=
"Output file name"
)
parser
.
add_argument
(
"file
_list
"
,
help
=
"
List of c
licky text file
s
"
,
parser
.
add_argument
(
"file"
,
help
=
"
C
licky text file"
,
nargs
=
'+'
)
args
=
parser
.
parse_args
()
create_db
(
args
.
file
_list
,
args
.
outfile
)
create_db
(
args
.
file
,
args
.
outfile
)
if
__name__
==
'__main__'
:
main
()
CellCounting/click_cells/select_zones.py
0 → 100755
View file @
da66c46a
#!/usr/bin/env python
import
matplotlib
as
mpl
import
sys
if
(
sys
.
platform
==
'darwin'
):
mpl
.
use
(
'wxagg'
)
from
matplotlib.widgets
import
RectangleSelector
,
Slider
from
matplotlib.image
import
AxesImage
import
numpy
as
np
import
os
import
os.path
as
op
import
shutil
import
matplotlib.pyplot
as
plt
import
argparse
import
sys
import
time
# use this for BigTiff
import
tifffile
as
tif
from
skimage
import
io
as
skio
# pilfered from https://stackoverflow.com/a/19306776
def
get_ax_size
(
fig
,
ax
):
bbox
=
ax
.
get_window_extent
().
transformed
(
fig
.
dpi_scale_trans
.
inverted
())
width
,
height
=
bbox
.
width
,
bbox
.
height
width
*=
fig
.
dpi
height
*=
fig
.
dpi
return
width
,
height
class
ZoneSelector
(
object
):
def
__init__
(
self
,
ax
,
outbase
,
img
):
self
.
box
=
[]
self
.
axis
=
ax
self
.
zone_number
=
0
self
.
image
=
img
self
.
outbase
=
outbase
self
.
rs
=
RectangleSelector
(
self
.
axis
,
self
.
select_callback
,
drawtype
=
'box'
,
rectprops
=
dict
(
facecolor
=
'orange'
,
edgecolor
=
'red'
,
alpha
=
0.2
,
fill
=
True
))
def
select_callback
(
self
,
eclick
,
erelease
):
print
(
'down'
,
eclick
)
print
(
'up '
,
erelease
)
x1
,
y1
=
eclick
.
xdata
,
eclick
.
ydata
x2
,
y2
=
erelease
.
xdata
,
erelease
.
ydata
y1
=
1
-
y1
y2
=
1
-
y2
h
,
w
=
self
.
image
.
shape
[:
2
]
zone
=
np
.
array
((
np
.
min
([
x1
,
x2
])
*
w
,
np
.
min
([
y1
,
y2
])
*
h
,
np
.
max
([
x1
,
x2
])
*
w
,
np
.
max
([
y1
,
y2
])
*
h
)).
astype
(
'i4'
)
x1
,
y1
,
x2
,
y2
=
zone
crop
=
self
.
image
[
y1
:
y2
,
x1
:
x2
,
:]
tif
.
imsave
(
self
.
outbase
+
"_zone{:03d}.tiff"
.
format
(
self
.
zone_number
),
crop
)
#crop.save(self.outbase + "_zone{:03d}.tiff".format(self.zone_number), "TIFF")
with
open
(
self
.
outbase
+
"_zone{:03d}.txt"
.
format
(
self
.
zone_number
),
"w"
)
as
fcoord
:
fcoord
.
write
(
'{} {} {} {}'
.
format
(
*
zone
))
self
.
zone_number
+=
1
class
BigImageViewer
(
object
):
def
__init__
(
self
,
fig
,
ax
,
data
):
self
.
fig
=
fig
self
.
ax
=
ax
self
.
data
=
data
self
.
ax_img
=
None
ax
.
set_autoscale_on
(
False
)
self
.
vmin_ax
=
fig
.
add_axes
([
0.2
,
0.1
,
0.6
,
0.03
])
self
.
vmax_ax
=
fig
.
add_axes
([
0.2
,
0.05
,
0.6
,
0.03
])
self
.
vmin_ctrl
=
Slider
(
self
.
vmin_ax
,
'vmin'
,
0
,
255
,
0
)
self
.
vmax_ctrl
=
Slider
(
self
.
vmax_ax
,
'vmax'
,
0
,
255
,
255
)
self
.
vmin_ctrl
.
on_changed
(
self
.
refresh
)
self
.
vmax_ctrl
.
on_changed
(
self
.
refresh
)
self
.
ax
.
callbacks
.
connect
(
'xlim_changed'
,
self
.
refresh
)
self
.
ax
.
callbacks
.
connect
(
'ylim_changed'
,
self
.
refresh
)
self
.
refresh
()
def
get_data
(
self
,
cw
,
ch
,
xmin
,
xmax
,
ymin
,
ymax
):
dh
,
dw
=
self
.
data
.
shape
[:
2
]
dws
=
xmin
*
dw
dwe
=
xmax
*
dw
dhs
=
(
1
-
ymin
)
*
dh
dhe
=
(
1
-
ymax
)
*
dh
dhs
,
dhe
=
sorted
((
dhs
,
dhe
))
dwlen
=
dwe
-
dws
dhlen
=
dhe
-
dhs
ratew
=
max
(
1
,
int
(
float
(
dwlen
)
/
cw
))
rateh
=
max
(
1
,
int
(
float
(
dhlen
)
/
ch
))
dws
,
dwe
=
np
.
clip
((
dws
,
dwe
),
0
,
dw
).
astype
(
np
.
int
)
dhs
,
dhe
=
np
.
clip
((
dhs
,
dhe
),
0
,
dh
).
astype
(
np
.
int
)
return
self
.
data
[
dhs
:
dhe
:
rateh
,
dws
:
dwe
:
ratew
,
:],
(
dws
,
dwe
,
dhs
,
dhe
)
def
refresh
(
self
,
*
a
):
fig
=
self
.
fig
ax
=
self
.
ax
vmin
=
self
.
vmin_ctrl
.
val
vmax
=
self
.
vmax_ctrl
.
val
w
,
h
=
get_ax_size
(
fig
,
ax
)
bounds
=
ax
.
viewLim
.
bounds
xmin
,
ymin
,
xlen
,
ylen
=
bounds
xmax
=
xmin
+
xlen
ymax
=
ymin
+
ylen
data
,
extent
=
self
.
get_data
(
w
,
h
,
xmin
,
xmax
,
ymin
,
ymax
)
data
=
data
.
astype
(
np
.
float32
)
data
=
np
.
clip
(
data
,
vmin
,
vmax
)
data
=
255
*
(
data
-
vmin
)
/
(
vmax
-
vmin
)
data
=
data
.
astype
(
np
.
uint8
)
# fix aspect ratio
wl
,
wh
,
hl
,
hh
=
extent
aspect
=
(
hh
-
hl
)
/
float
(
wh
-
wl
)
if
self
.
ax_img
is
None
:
self
.
ax_img
=
ax
.
imshow
(
data
,
aspect
=
aspect
)
self
.
ax_img
.
set_extent
((
xmin
,
xmax
,
ymin
,
ymax
))
self
.
ax
.
autoscale_view
()
else
:
self
.
ax_img
.
set_data
(
data
)
self
.
ax_img
.
set_extent
(
np
.
clip
((
xmin
,
xmax
,
ymin
,
ymax
),
0
,
1
))
xticks
=
np
.
linspace
(
xmin
,
xmax
,
5
)
yticks
=
np
.
linspace
(
ymin
,
ymax
,
5
)
xlabels
=
np
.
round
(
np
.
linspace
(
wl
,
wh
,
5
)).
astype
(
np
.
int
)
ylabels
=
np
.
round
(
np
.
linspace
(
hl
,
hh
,
5
)).
astype
(
np
.
int
)
self
.
ax
.
set_xticks
(
xticks
)
self
.
ax
.
set_xticklabels
(
xlabels
)
self
.
ax
.
set_yticks
(
yticks
)
self
.
ax
.
set_yticklabels
(
ylabels
)
fig
.
canvas
.
draw_idle
()
def
main
():
parser
=
argparse
.
ArgumentParser
(
"Select zone of interest in an image"
)
parser
.
add_argument
(
'-ow'
,
'--overwrite'
,
action
=
'store_true'
,
help
=
'Overwrite output directory if it exists'
)
required
=
parser
.
add_argument_group
(
'Required arguments'
)
required
.
add_argument
(
"-i"
,
"--input"
,
required
=
True
,
help
=
"Input file name of main image including extension."
)
required
.
add_argument
(
"-o"
,
"--output_folder"
,
required
=
True
,
help
=
"Name of output folder."
)
args
=
parser
.
parse_args
()
image_name
=
args
.
input
outfolder
=
args
.
output_folder
# Deal with existing output
overwrite
=
args
.
overwrite
if
op
.
exists
(
outfolder
):
if
not
overwrite
:
print
(
"Folder '{}' exists. Are you sure you want to delete it? [Y,N]"
.
format
(
outfolder
))
response
=
input
()
overwrite
=
response
.
upper
()
==
"Y"
if
not
overwrite
:
print
(
"Ok but some of the existing zones may be overwritten!"
)
else
:
shutil
.
rmtree
(
outfolder
)
os
.
mkdir
(
outfolder
)
else
:
os
.
mkdir
(
outfolder
)
# Read image
print
(
"Reading input Image '{}' ..."
.
format
(
image_name
))
t0
=
time
.
clock
()
im
=
skio
.
imread
(
image_name
)
print
(
" ... this took {} seconds"
.
format
(
time
.
clock
()
-
t0
))
image_base
=
op
.
basename
(
image_name
)
print
(
"Displaying ({} x {} pixels)..."
.
format
(
*
im
.
shape
[:
2
]))
fig
=
plt
.
figure
()
ax
=
fig
.
add_subplot
(
111
)
fig
.
subplots_adjust
(
bottom
=
0.2
)
plt
.
title
(
'Click and drag to draw zones
\n
Close the window when done.'
)
biv
=
BigImageViewer
(
fig
,
ax
,
im
)
zs
=
ZoneSelector
(
ax
,
op
.
join
(
outfolder
,
image_base
),
im
)
plt
.
show
()
if
__name__
==
'__main__'
:
main
()
Click
C
ells/split_zones.py
→
C
ellCounting/c
lick
_c
ells/split_zones.py
View file @
da66c46a
File moved
Click
C
ells/utils.py
→
C
ellCounting/c
lick
_c
ells/utils.py
View file @
da66c46a
File moved
M
odels/.DS_Store
→
CellCounting/m
odels/.DS_Store
View file @
da66c46a
File moved
CellCounting/models/__init__.py
0 → 100644
View file @
da66c46a
CellCounting/models/forward_density.py
0 → 100755
View file @
da66c46a
#!/usr/bin/env python
# Fast Forward Model, Oiwi 2018
# Modified: Saad Jbabdi 09/2018
import
os
import
sys
import
time
import
numpy
as
np
from
skimage
import
io
as
skio
from
keras.models
import
load_model
import
collections
import
itertools
as
it
from
sklearn.utils
import
check_array
from
sklearn.feature_extraction.image
import
extract_patches
from
sklearn.feature_extraction.image
import
_compute_n_patches
import
argparse
def
ffm
(
inimage
,
inmodel
,
basename
=
None
,
stride
=
None
,
gpu
=
True
,
timer
=
True
,
normdir
=
None
):
'''
Fast Forward Model, Oiwi Parker Jones, 2018
'''
if
timer
==
True
:
total_start
=
time
.
time
();
load_start
=
time
.
time
()
print
(
'* Loading data'
)
imnp
=
skio
.
imread
(
inimage
)
if
timer
==
True
:
load_end
=
time
.
time
()
# Load model
model
=
load_model
(
inmodel
)
# get kernel width and height from model
(
k_w
,
k_h
)
=
model
.
layers
[
0
].
input_shape
[
1
:
3
]
# deal with image edges
imnp
=
np
.
pad
(
imnp
,((
k_w
,
k_w
),(
k_h
,
k_h
),(
0
,
0
)),
mode
=
'symmetric'
)
if
stride
==
None
:
print
(
'** Using default stride = {} (i.e. width of kernel)'
.
format
(
k_w
))
stride
=
k_w
if
gpu
==
True
:
from
keras
import
backend
as
K
K
.
tensorflow_backend
.
_get_available_gpus
()
print
(
'* Running forward pass on GPU (CUDA_VISIBLE_DEVICES)'
)
else
:
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
""
print
(
'* Running forward pass on CPU'
)
# Split imnp into set of sub-image patches
print
(
'* Splitting image into {} x {} patches (stride={})'
.
format
(
k_w
,
k_h
,
stride
))
patches
=
extract_patches_2d_strides
(
imnp
,
(
k_w
,
k_h
),
extraction_step
=
stride
)
print
(
'** {} patches produced'
.
format
(
patches
.
shape
[
0
]))
# get img_avg and img_std to normalise images (from training set)
if
normdir
==
None
:
print
(
"!! normdir==None"
)
print
(
'** Using model without normalisation!!!'
)
img_avg
=
np
.
zeros
((
k_w
,
k_h
,
3
),
dtype
=
'float32'
)
img_std
=
np
.
ones
((
k_w
,
k_h
,
3
),
dtype
=
'float32'
)
else
:
print
(
'** Loading img_avg.npy and img_std.npy for normalisation from: {}'
.
format
(
normdir
))
img_avg
=
np
.
load
(
os
.
path
.
join
(
normdir
,
'img_avg.npy'
))
img_std
=
np
.
load
(
os
.
path
.
join
(
normdir
,
'img_std.npy'
))
if
timer
==
True
:
forward_start
=
time
.
time
()
# Do forward pass on each normalised image patch
patches
=
patches
.
astype
(
np
.
float32
)
#convert type from uint8 (256 bits) to 32 bit float before normalising
patches
=
(
patches
-
img_avg
)
/
img_std
patch_out
=
model
.
predict
(
patches
)
#patch_out_argmax = patch_out.argmax(axis=1) #get predicted class; 1 = 'cell' (so predictions can be averaged)
# Reconstruct predicted output image, for saving
print
(
'* Reconstructing output image'
)
recon
,
_
=
reconstruct_from_patches_2d_strides
(
patches
,
patch_out
,
imnp
.
shape
,
extraction_step
=
stride
)
# remove padding
recon
=
recon
[
k_w
:
-
k_w
,
k_h
:
-
k_h
]
#save pred image recon
if
basename
==
None
:
basename
,
_
=
os
.
path
.
splitext
(
inimage
)
outfile
=
basename
+
'_stride'
+
str
(
stride
)
+
'_density_highres.npz'
#print('* Saving predictions: {}'.format(outfile))
#np.savez_compressed(outfile, recon=recon)
#visualise results
outfile
=
basename
+
'_density.png'
print
(
'* Saving snapshot: {}'
.
format
(
outfile
))
import
matplotlib.pyplot
as
plt
plt
.
imshow
(
imnp
[
k_w
:
-
k_w
,
k_h
:
-
k_h
,:])
plt
.
imshow
(
recon
,
alpha
=
.
2
,
interpolation
=
'bilinear'
)
plt
.
savefig
(
outfile
,
dpi
=
1000
)
if
timer
==
True
:
forward_end
=
time
.
time
()
total_end
=
time
.
time
()
print
(
'** Time to load data = %.2f seconds'
%
(
load_end
-
load_start
))
print
(
'** Time to run model = %.2f seconds'
%
(
forward_end
-
forward_start
))
print
(
'** Total run time = %.2f seconds'
%
(
total_end
-
total_start
))
def
reconstruct_from_patches_2d_strides
(
patches
,
patch_out
,
image_size
,
extraction_step
=
1
):
"""
Wrapper around sklearn's reconstruct_from_patches_2d, but with added strides.
"""
p_h
,
p_w
=
patches
.
shape
[
1
:
3
]
i_h
,
i_w
=
image_size
[:
2
]
s_h
,
s_w
=
[
extraction_step
,
extraction_step
]
img
=
np
.
zeros
((
i_h
,
i_w
),
dtype
=
np
.
float32
)
# progressively back off data type for count, depending on amount of kernel overlap
# data types that use less memory trade off with ability to have smaller strides
count
=
np
.
zeros
((
i_h
,
i_w
),
dtype
=
np
.
uint16
)
# compute the patch indices along each dimension
h_indices
=
np
.
arange
(
0
,
i_h
-
p_h
+
1
,
s_h
)
w_indices
=
np
.
arange
(
0
,
i_w
-
p_w
+
1
,
s_w
)
thr
=
2
for
k
,
(
p
,
(
i
,
j
))
in
enumerate
(
zip
(
patches
,
it
.
product
(
h_indices
,
w_indices
))):
if
(
patch_out
[
k
,
1
]
/
patch_out
[
k
,
0
]
>
thr
):
img
[
i
:
i
+
p_h
,
j
:
j
+
p_w
]
+=
1
count
[
i
:
i
+
p_h
,
j
:
j
+
p_w
]
+=
1
img
=
img
/
(
count
+
(
count
==
0
)).
astype
(
np
.
float32
)
return
img
,
count
def
squarebox
(
h
,
w
,
rgb
):
b
=
np
.
zeros
((
h
,
w
,
3
))
b
[:,
0
,:]
=
rgb
b
[:,
-
1
,:]
=
rgb
b
[
0
,:,:]
=
rgb
b
[
-
1
,:,:]
=
rgb
return
b
# display a nxn grid of images
def
cell_plot
(
X
,
y
,
n
=
10
):
import
matplotlib.pyplot
as
plt
im_size
=
X
.
shape
[
1
]
pad_size
=
5
figure
=
np
.
zeros
((
pad_size
+
(
im_size
+
pad_size
)
*
n
,
pad_size
+
(
im_size
+
pad_size
)
*
n
,
3
),
dtype
=
'uint8'
)
cnt
=
0
ii
=
pad_size
for
i
in
range
(
n
):
jj
=
pad_size
for
j
in
range
(
n
):
if
(
cnt
<
X
.
shape
[
0
]):
figure
[
ii
:
ii
+
im_size
,
jj
:
jj
+
im_size
,
:]
=
X
[
cnt
,...]
if
(
y
[
cnt
,
1
]
>=
y
[
cnt
,
0
]):
rgb
=
[
0
,
255
,
0
]
else
:
rgb
=
[
255
,
0
,
0
]
figure
[
ii
,
jj
:
jj
+
im_size
,:]
=
rgb
figure
[
ii
+
im_size
-
1
,
jj
:
jj
+
im_size
,:]
=
rgb
figure
[
ii
:
ii
+
im_size
,
jj
,
:]
=
rgb
figure
[
ii
:
ii
+
im_size
,
jj
+
im_size
-
1
,
:]
=
rgb
jj
+=
(
im_size
+
pad_size
)
cnt
+=
1
ii
+=
(
im_size
+
pad_size
)
plt
.
figure
(
figsize
=
(
10
,
10
))
plt
.
imshow
(
figure
)
plt
.
show
()
#add extraction_step argument for strides
def
extract_patches_2d_strides
(
image
,
patch_size
,
extraction_step
=
None
):
'''
Wrapper around extract_patches based on sklearn's extract_patches_2d but with added strides.
Oiwi, 2018
'''
if
extraction_step
is
None
:
extraction_step
=
patch_size
#default to nonoverlapping
i_h
,
i_w
=
image
.
shape
[:
2
]
p_h
,
p_w
=
patch_size
if
p_h
>
i_h
:
raise
ValueError
(
"Height of the patch should be less than the height"