Commit bda1c201 authored by Sean Fitzgibbon's avatar Sean Fitzgibbon
Browse files

Improved plotting

parent e0fbbfe3
......@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from operator import index
import os
import os.path as op
import json
from xml.dom import INDEX_SIZE_ERR
import yaml
from scipy.ndimage.interpolation import affine_transform
......@@ -61,26 +63,6 @@ def estimate_field(img):
return 'dark'
class Slide:
contours = None
cells = None
def __init__(self, image, resolution) -> None:
self.data = image
self.resolution = resolution
@property
def shape(self):
return self.data.shape
@classmethod
def from_jp2(cls, fname, resolution):
im = glymur.Jp2k(fname)
return cls(im, resolution)
def register_chart_to_slide(chart, slide, slide_res, outdir, boundary_key=None, config=None, do_plots=None):
......@@ -110,7 +92,8 @@ def register_chart_to_slide(chart, slide, slide_res, outdir, boundary_key=None,
if len(outline)<1: raise ValueError(f'Boundary key {boundary_key} not found in chart')
if len(outline)>1: warnings.warn(f'Repeated entries of boundary key in chart: {boundary_key}')
edge_crds = np.concatenate([x.points[:, :2] for x in outline]) * [1, -1]
edge_crds = [x.points[:, :2]*[1, -1] for x in outline]
edge_crds_cat = np.concatenate(edge_crds)
# load slide, convert to grayscale, and invert if light-field
......@@ -128,83 +111,127 @@ def register_chart_to_slide(chart, slide, slide_res, outdir, boundary_key=None,
img = enhance(img)
# initial scaling based on boundng boxes
# initial alignment based on boundng boxes
init_xfm = init_scale(img, slide_res, edge_crds)
init_xfm, img_props, coord_props = init_scale(img, slide_res, edge_crds_cat)
print(init_xfm)
# print(init_xfm)
tr_x, tr_y = init_xfm.translation
print(
f"Rotation: {init_xfm.rotation}, Translation: {init_xfm.translation}, Scale: {init_xfm.scale}"
f"Initial XFM - Rotation: {init_xfm.rotation:0.5f}, Translation: [{tr_x:0.5f} {tr_y:0.5f}], Scale: {init_xfm.scale:0.5f}"
)
# calculate normal line to boundary points
xfm_edge_coords = apply_xfm(init_xfm, edge_crds)
init_nrmls = normal(xfm_edge_coords)
np.savetxt(f"{outdir}/chart-to-image-init.xfm", init_xfm.params)
# refine alignmnet (to mask edges)
# refine edge_coords (to image)
refined_edge_coords, exclude_mask = refine_edge_coord(
img, slide_res, xfm_edge_coords, init_nrmls
opt_xfm = refine_edge_coord(
img, slide_res, edge_crds_cat, init_xfm
)
# estimate opimised affine transform
opt_xfm = transform.SimilarityTransform()
opt_xfm.estimate(edge_crds[exclude_mask, :], refined_edge_coords[exclude_mask, :])
print(opt_xfm)
# print(opt_xfm)
tr_x, tr_y = opt_xfm.translation
print(
f"Rotation: {opt_xfm.rotation}, Translation: {opt_xfm.translation}, Scale: {opt_xfm.scale}"
f"Optimised XFM - Rotation: {opt_xfm.rotation:0.5f}, Translation: [{tr_x:0.5f} {tr_y:0.5f}], Scale: {opt_xfm.scale:0.5f}"
)
# # save opt transform
np.savetxt(f"{outdir}/chart-to-image.xfm", opt_xfm.params)
# apply opt-xfm to contours and cells and save
# apply optimised xfm to contours and cells and save
contour_xfm = [
(contour.name, apply_xfm(opt_xfm, contour.points[:, :2] * [1, -1]).tolist(), contour.closed) for contour in chart.contours
]
with open(f"{outdir}/contour.json", "w") as fp:
json.dump(contour_xfm, fp, indent=4)
if chart.n_cells > 0:
cells = np.concatenate([cell.point[:2][np.newaxis, :] for cell in chart.cells]) * [1, -1]
cells_xfm = apply_xfm(init_xfm, cells )
with open(f"{outdir}/cells.json", "w") as fp:
json.dump(cells_xfm.tolist(), fp, indent=4)
# do plots
if do_plots:
fig, ax = plt.subplots(figsize=(20, 30))
extent = np.array([0, img.shape[1], img.shape[0], 0]) * slide_res
ax.imshow(img, extent=extent, cmap='gray')
fig, ax = plt.subplots(2, 2, figsize=(40, 40))
ax = np.ravel(ax)
# chart bounding box
plot_contour(ax[0], edge_crds_cat, title="Boundary contour + bounding box", linestyle='none', marker='.', color='b')
plot_box(ax[0], coord_props['bbox'])
ax[0].axis("equal")
# image bounding box
plot_slide(ax[1], img_props['mask'], slide_res, title="Slide mask + bounding box")
plot_box(ax[1], img_props['bbox'])
# boundary plots
for name, coords, closed in contour_xfm:
plot_slide(ax[2], img, slide_res, title='Boundary alignment')
for contour in edge_crds:
plot_contour(ax[2], apply_xfm(init_xfm, contour), color=(1, 0, 0), marker='.')
plot_contour(ax[2], apply_xfm(opt_xfm, contour), color=(0, 1, 0), marker='.')
ax[2].legend(["boundary_init", "boundary_optimised"])
# aligned chart
plot_slide(ax[3], img, slide_res, title="Aligned Chart")
cmap = plt.get_cmap("tab10")
for idx, (name, coords, closed) in enumerate(contour_xfm):
if not closed: continue
plot_contour(ax[3], coords, name, color=cmap(idx))
coords = np.array(coords)
fig.savefig(f'{outdir}/alignment.png', bbox_inches='tight', dpi=300)
cog = np.mean(coords, axis=0)
ax.text(cog[0], cog[1], name, color='r')
ax.plot(coords[:, 0], coords[:, 1], "r-")
def plot_box(ax, bbox, edgecolor='r', facecolor='none', linewidth=1):
rect = patches.Rectangle(
(bbox[1], bbox[0]),
bbox[3] - bbox[1],
bbox[2] - bbox[0],
linewidth=linewidth,
edgecolor=edgecolor,
facecolor=facecolor,
)
ax.add_patch(rect)
ax.set_xlabel("mm")
ax.set_ylabel("mm")
ax.set_title("Aligned Chart")
fig.savefig(f"{outdir}/aligned_chart.png", bbox_inches="tight", dpi=300)
def plot_slide(ax, img, slide_res, title=None):
# fig, ax = plt.subplots(1, 5, figsize=(25, 5))
# ax[0].imshow(plt.imread(f'{OUTDIR}/chart_bounding_box.png'))
# ax[1].imshow(plt.imread(f'{OUTDIR}/image_bounding_box.png'))
# ax[2].imshow(plt.imread(f'{OUTDIR}/normals.png'))
# ax[3].imshow(plt.imread(f'{OUTDIR}/refined_coords.png'))
# ax[4].imshow(plt.imread(f'{OUTDIR}/aligned_chart.png'))
extent = np.array([0, img.shape[1], img.shape[0], 0]) * slide_res
ax.imshow(img, extent=extent, cmap='gray')
# for a in ax:
# a.axis('off')
ax.set_xlabel("mm")
ax.set_ylabel("mm")
# fig.savefig(f'{OUTDIR}/alignment.png', bbox_inches='tight', dpi=150)
if title is not None:
ax.set_title(title)
with open(f"{outdir}/contour.json", "w") as fp:
json.dump(contour_xfm, fp, indent=4)
if chart.n_cells > 0:
cells = np.concatenate([cell.point[:2][np.newaxis, :] for cell in chart.cells]) * [1, -1]
cells_xfm = apply_xfm(init_xfm, cells )
with open(f"{outdir}/cells.json", "w") as fp:
json.dump(cells_xfm.tolist(), fp, indent=4)
def plot_contour(ax, coords, name=None, color='r', title=None, linewidth=1, **kwargs):
coords = np.array(coords)
ax.plot(coords[:, 0], coords[:, 1], color=color, linewidth=linewidth, **kwargs)
if name is not None:
cog = np.mean(coords, axis=0)
ax.text(cog[0], cog[1], name, color=color)
if title is not None:
ax.set_title(title)
ax.invert_yaxis()
def enhance(img0, kernel_size=None, lower_percentile=2, upper_percentile=98, sigma=5):
......@@ -253,41 +280,45 @@ def segment_foreground(img, marker_threshold=(0.02, 0.2), min_component_size=100
return brainmask
def refine_edge_coord(img, img_res, edge_coords, normals, do_plots=True):
def refine_edge_coord(img, img_res, edge_coords, xfm_init):
"""
Refine edge_coord by sampling image along normal (to edge) and looking for big step change.
"""
# calculate normal line to boundary points
edge_coords_init = apply_xfm(xfm_init, edge_coords)
normals = normal(edge_coords_init)
# calculate normal line (to edge_coords)
edge_x, edge_y = edge_coords.T
edge_init_x, edge_init_y = edge_coords_init.T
nrml_x, nrml_y = normals.T
# TODO: move these line extents to the config file
# line_smpl = np.linspace(-0.03, 0.15, 20)
line_smpl = np.linspace(-0.2, 0.2, 20)
line_x = edge_x[:, np.newaxis] + nrml_x[:, np.newaxis] * line_smpl
line_y = edge_y[:, np.newaxis] + nrml_y[:, np.newaxis] * line_smpl
line_x = edge_init_x[:, np.newaxis] + nrml_x[:, np.newaxis] * line_smpl
line_y = edge_init_y[:, np.newaxis] + nrml_y[:, np.newaxis] * line_smpl
brainmask = segment_foreground(img)
if do_plots:
fig, ax = plt.subplots(figsize=(20, 30))
# if do_plots:
# fig, ax = plt.subplots(figsize=(20, 30))
extent = np.array([0, img.shape[1], img.shape[0], 0]) * img_res
ax.imshow(brainmask, extent=extent, cmap='gray')
# extent = np.array([0, img.shape[1], img.shape[0], 0]) * img_res
# ax.imshow(brainmask, extent=extent, cmap='gray')
for line_x0, line_y0 in zip(line_x, line_y):
ax.plot(line_x0, line_y0, "b.")
# for line_x0, line_y0 in zip(line_x, line_y):
# ax.plot(line_x0, line_y0, "b.")
ax.plot(edge_x, edge_y, "r.")
# ax.plot(edge_init_x, edge_init_y, "r.")
ax.set_xlabel("mm")
ax.set_ylabel("mm")
ax.set_title("Brainmask + normals")
# ax.set_xlabel("mm")
# ax.set_ylabel("mm")
# ax.set_title("Brainmask + normals")
plt.show()
# fig.savefig(f"{OUTDIR}/normals.png", bbox_inches="tight", dpi=300)
# plt.show()
# # fig.savefig(f"{OUTDIR}/normals.png", bbox_inches="tight", dpi=300)
# sample image along normal line
......@@ -310,29 +341,13 @@ def refine_edge_coord(img, img_res, edge_coords, normals, do_plots=True):
axis=-1,
)
# TODO: plot edge_coords + refined_edge_coords
if do_plots:
fig, ax = plt.subplots(figsize=(20, 30))
extent = np.array([0, img.shape[1], img.shape[0], 0]) * img_res
ax.imshow(brainmask, extent=extent, cmap='gray')
ax.plot(edge_x, edge_y, "r.-")
ax.plot(refined_edge_coords[~constant_idx, 0], refined_edge_coords[~constant_idx, 1], "g.-")
ax.set_xlabel("mm")
ax.set_ylabel("mm")
# ax.set_title('edge_co')
ax.legend(["edge_coords", "refined_edge_coords"])
plt.show()
# fig.savefig(f"{OUTDIR}/refined_coords.png", bbox_inches="tight", dpi=150)
opt_xfm = transform.SimilarityTransform()
opt_xfm.estimate(edge_coords[~constant_idx, :], refined_edge_coords[~constant_idx, :])
return refined_edge_coords, ~constant_idx
return opt_xfm
def image_props(img, img_resolution, do_plots=False, plot_name=None):
def image_props(img, img_resolution):
brainmask = segment_foreground(img)
......@@ -350,36 +365,6 @@ def image_props(img, img_resolution, do_plots=False, plot_name=None):
(bbox[1] + bbox[3]) / 2,
)
if do_plots:
fig, ax = plt.subplots()
extent = np.array([0, img.shape[1], img.shape[0], 0]) * img_resolution
im = ax.imshow(brainmask, extent=extent)
fig.colorbar(im)
rect = patches.Rectangle(
(bbox[1], bbox[0]),
bbox[3] - bbox[1],
bbox[2] - bbox[0],
linewidth=1,
edgecolor="r",
facecolor="none",
)
ax.add_patch(rect)
ax.plot(centroid[1], centroid[0], "ro")
ax.set_xlabel("mm")
ax.set_ylabel("mm")
ax.set_title("Image bounding box")
if plot_name is None:
plt.show()
else:
# image_bounding_box.png
fig.savefig(plot_name, bbox_inches="tight", dpi=150)
return {
"bbox": bbox,
"bbox_centroid": centroid,
......@@ -389,7 +374,7 @@ def image_props(img, img_resolution, do_plots=False, plot_name=None):
}
def point_props(pnts, do_plots=False, plot_name=None):
def point_props(pnts):
x = pnts[:, 0]
y = pnts[:, 1]
......@@ -401,32 +386,6 @@ def point_props(pnts, do_plots=False, plot_name=None):
(bbox[1] + bbox[3]) / 2,
)
if do_plots:
fig, ax = plt.subplots()
ax.plot(x, y, "b.")
ax.axis("equal")
ax.invert_yaxis()
rect = patches.Rectangle(
(bbox[1], bbox[0]),
bbox[3] - bbox[1],
bbox[2] - bbox[0],
linewidth=1,
edgecolor="r",
facecolor="none",
)
ax.add_patch(rect)
plt.plot(centroid[1], centroid[0], "ro")
ax.set_title("Chart bounding box")
if plot_name is None:
plt.show()
else:
# chart_bounding_box.png
fig.savefig(plot_name, bbox_inches="tight", dpi=150)
return {
"bbox": bbox,
"bbox_centroid": centroid,
......@@ -438,11 +397,11 @@ def point_props(pnts, do_plots=False, plot_name=None):
def init_scale(img, img_resolution, crd, tol=0.05):
img_p = image_props(img, img_resolution, do_plots=True)
crd_p = point_props(crd, do_plots=True)
img_p = image_props(img, img_resolution)
crd_p = point_props(crd)
print(img_p)
print(crd_p)
# print(img_p)
# print(crd_p)
# justify='right'
......@@ -482,7 +441,7 @@ def init_scale(img, img_resolution, crd, tol=0.05):
a = transform.SimilarityTransform()
a.estimate(src, dest)
return a
return a, img_p, crd_p
def apply_xfm(xfm, pnts):
......
......@@ -20,6 +20,7 @@ from dataclasses import dataclass
from typing import Optional, List
import numpy as np
from slider.external import neurolucida
import glymur
def get_slider_dir() -> str:
rpath = op.realpath(op.dirname(op.abspath(__file__)))
......@@ -91,4 +92,22 @@ class Chart:
)
def __repr__(self):
return f"Chart(contours=[{self.n_contours}x ChartContour], cells=[{self.n_cells}x ChartCell])"
\ No newline at end of file
return f"Chart(contours=[{self.n_contours}x ChartContour], cells=[{self.n_cells}x ChartCell])"
@dataclass
class Slide:
data: np.array
resolution: float
mask: Optional[np.array] = None
@property
def shape(self):
return self.data.shape
@classmethod
def from_jp2(cls, fname, resolution):
im = glymur.Jp2k(fname)
return cls(im, resolution)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment