Commit 78dc0c0b authored by inhuszar's avatar inhuszar
Browse files

MIND: avoiding sticky image borders by suppressing SSD

parent 8c09117a
...@@ -190,20 +190,46 @@ class CostMIND(CostMSD): ...@@ -190,20 +190,46 @@ class CostMIND(CostMSD):
# Restore the axis order # Restore the axis order
out.order = img.order out.order = img.order
vc = img.domain.get_voxel_coordinates()
if img.mask is not None:
masked = np.flatnonzero(img.mask == 0)
else:
masked = []
# Calculate patch-distance of the central voxel from each neighbour # Calculate patch-distance of the central voxel from each neighbour
for i, vector in enumerate(shifts): for i, vector in enumerate(shifts):
# Shift the image such that the central voxel # Shift the image such that the central voxel
# overlaps with the specific neighbour # overlaps with the specific neighbour
img_shifted = img.copy() img_shifted = np.empty_like(img.data)
for axis, vi in zip(img.vaxes, vector): for axis, vi in zip(img.vaxes, vector):
if vi: if vi:
img_shifted.data[...] = np.roll(img_shifted.data, vi, axis) img_shifted[...] = np.roll(img.data, vi, axis)
# Calculate voxelwise squared differences over the entire image # Calculate voxelwise squared differences over the entire image
ssd = (img_shifted.data - img.data) ** 2 ssd = (img_shifted - img.data) ** 2
del img_shifted del img_shifted
# Zero the difference on the edges (Neumann)
# This is to prevent "sticky edges"
test_coordinates = vc - vector
bindices = [np.flatnonzero(np.any(test_coordinates < 0, axis=-1))]
for dim in range(img.vdim):
ix = np.flatnonzero(test_coordinates[:, dim] >= img.vshape[dim])
bindices.append(ix)
else:
bindices = np.unique(np.concatenate(bindices))
ssd.flat[bindices] = 0
# Treat mask edges the same as image borders
if self.metaparameters.get("ignore_masked_edges"):
tc = np.delete(test_coordinates, bindices, axis=0)
vcc = np.delete(vc, bindices, axis=0)
mindices = np.ravel_multi_index(tc.T, img.vshape)
mindices = np.ravel_multi_index(
vcc[np.in1d(mindices, masked)].T, img.vshape)
ssd.flat[mindices] = 0
# Convert distances to patch-wise distance # Convert distances to patch-wise distance
# This step ensures robust estimation of the distance # This step ensures robust estimation of the distance
# from the neighbour. Do not smooth across tensor values. # from the neighbour. Do not smooth across tensor values.
...@@ -219,6 +245,7 @@ class CostMIND(CostMSD): ...@@ -219,6 +245,7 @@ class CostMIND(CostMSD):
low, high = np.asarray([1e-3, 1e3]) * np.mean(variance) # robustness low, high = np.asarray([1e-3, 1e3]) * np.mean(variance) # robustness
variance = np.clip(variance, low, high) variance = np.clip(variance, low, high)
tmp = np.exp(-tmp / variance) tmp = np.exp(-tmp / variance)
del variance del variance
out.data[...] = tmp / np.max(tmp, 0, keepdims=True) out.data[...] = tmp / np.max(tmp, 0, keepdims=True)
# out.data[~np.isfinite(out.data)] = 0 # out.data[~np.isfinite(out.data)] = 0
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment