Skip to content
Snippets Groups Projects
Commit 3c0a0d7a authored by Paul McCarthy's avatar Paul McCarthy :mountain_bicyclist:
Browse files

Unit test for face/vertex normal, and fix winding logic in TriangleMesh class.

parent 949c92a8
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,8 @@ import tempfile ...@@ -13,7 +13,8 @@ import tempfile
import numpy as np import numpy as np
import pytest import pytest
import fsl.data.mesh as fslmesh import fsl.utils.transform as transform
import fsl.data.mesh as fslmesh
datadir = op.join(op.dirname(__file__), 'testdata') datadir = op.join(op.dirname(__file__), 'testdata')
...@@ -26,8 +27,8 @@ def test_create_mesh(): ...@@ -26,8 +27,8 @@ def test_create_mesh():
# - create from inmem data # - create from inmem data
testbase = 'test_mesh.vtk' testbase = 'test_mesh.vtk'
testfile = op.join(datadir, testbase) testfile = op.join(datadir, testbase)
verts, lens, indices = fslmesh.loadVTKPolydataFile(testfile) verts, lens, indices = fslmesh.loadVTKPolydataFile(testfile)
mesh1 = fslmesh.TriangleMesh(testfile) mesh1 = fslmesh.TriangleMesh(testfile)
mesh2 = fslmesh.TriangleMesh(verts, indices) mesh2 = fslmesh.TriangleMesh(verts, indices)
...@@ -36,9 +37,9 @@ def test_create_mesh(): ...@@ -36,9 +37,9 @@ def test_create_mesh():
assert mesh2.name == 'TriangleMesh' assert mesh2.name == 'TriangleMesh'
assert mesh1.dataSource == testfile assert mesh1.dataSource == testfile
assert mesh2.dataSource is None assert mesh2.dataSource is None
assert mesh1.vertices.shape == (642, 3) assert mesh1.vertices.shape == (642, 3)
assert mesh2.vertices.shape == (642, 3) assert mesh2.vertices.shape == (642, 3)
assert mesh1.indices.shape == (1280, 3) assert mesh1.indices.shape == (1280, 3)
assert mesh2.indices.shape == (1280, 3) assert mesh2.indices.shape == (1280, 3)
...@@ -68,7 +69,7 @@ def test_mesh_loadVertexData(): ...@@ -68,7 +69,7 @@ def test_mesh_loadVertexData():
assert np.all(mesh.getVertexData('inmemdata') == memdata) assert np.all(mesh.getVertexData('inmemdata') == memdata)
mesh.clearVertexData() mesh.clearVertexData()
assert mesh.getVertexData(datafile).shape == (642,) assert mesh.getVertexData(datafile).shape == (642,)
assert np.all(mesh.loadVertexData('inmemdata', memdata) == memdata) assert np.all(mesh.loadVertexData('inmemdata', memdata) == memdata)
...@@ -103,7 +104,7 @@ def test_getFIRSTPrefix(): ...@@ -103,7 +104,7 @@ def test_getFIRSTPrefix():
for fname, expected in passes: for fname, expected in passes:
assert fslmesh.getFIRSTPrefix(fname) == expected assert fslmesh.getFIRSTPrefix(fname) == expected
def test_findReferenceImage(): def test_findReferenceImage():
...@@ -123,11 +124,75 @@ def test_findReferenceImage(): ...@@ -123,11 +124,75 @@ def test_findReferenceImage():
prefix = fslmesh.getFIRSTPrefix(fname) prefix = fslmesh.getFIRSTPrefix(fname)
imgfname = op.join(testdir, '{}.nii.gz'.format(prefix)) imgfname = op.join(testdir, '{}.nii.gz'.format(prefix))
fname = op.join(testdir, fname) fname = op.join(testdir, fname)
with open(fname, 'wt') as f: f.write(fname) with open(fname, 'wt') as f: f.write(fname)
with open(imgfname, 'wt') as f: f.write(imgfname) with open(imgfname, 'wt') as f: f.write(imgfname)
assert fslmesh.findReferenceImage(fname) == imgfname assert fslmesh.findReferenceImage(fname) == imgfname
finally: finally:
shutil.rmtree(testdir) shutil.rmtree(testdir)
def test_normals():
# vertices of a cube
verts = np.array([
[-1, -1, -1],
[-1, -1, 1],
[-1, 1, -1],
[-1, 1, 1],
[ 1, -1, -1],
[ 1, -1, 1],
[ 1, 1, -1],
[ 1, 1, 1],
])
# triangles
# cw == clockwise, when facing outwards
# from the centre of the mesh
triangles_cw = np.array([
[0, 4, 6], [0, 6, 2],
[1, 3, 5], [3, 7, 5],
[0, 1, 4], [1, 5, 4],
[2, 6, 7], [2, 7, 3],
[0, 2, 1], [1, 2, 3],
[4, 5, 7], [4, 7, 6],
])
# ccw == counter-clockwise
triangles_ccw = np.array(triangles_cw)
triangles_ccw[:, [1, 2]] = triangles_ccw[:, [2, 1]]
# face normals
fnormals = np.array([
[ 0, 0, -1], [ 0, 0, -1],
[ 0, 0, 1], [ 0, 0, 1],
[ 0, -1, 0], [ 0, -1, 0],
[ 0, 1, 0], [ 0, 1, 0],
[-1, 0, 0], [-1, 0, 0],
[ 1, 0, 0], [ 1, 0, 0],
])
# vertex normals
vnormals = np.zeros((8, 3))
for i in range(8):
faces = np.where(triangles_cw == i)[0]
vnormals[i] = fnormals[faces].sum(axis=0)
vnormals = transform.normalise(vnormals)
cw_nofix = fslmesh.TriangleMesh(verts, triangles_cw)
cw_fix = fslmesh.TriangleMesh(verts, triangles_cw, fixWinding=True)
ccw_nofix = fslmesh.TriangleMesh(verts, triangles_ccw)
ccw_fix = fslmesh.TriangleMesh(verts, triangles_ccw, fixWinding=True)
# ccw triangles should give correct
# normals without unwinding
assert np.all(np.isclose(cw_nofix .normals, -fnormals))
assert np.all(np.isclose(cw_nofix .vnormals, -vnormals))
assert np.all(np.isclose(cw_fix .normals, fnormals))
assert np.all(np.isclose(cw_fix .vnormals, vnormals))
assert np.all(np.isclose(ccw_nofix.normals, fnormals))
assert np.all(np.isclose(ccw_nofix.vnormals, vnormals))
assert np.all(np.isclose(ccw_fix .normals, fnormals))
assert np.all(np.isclose(ccw_fix .vnormals, vnormals))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment