From 3c0a0d7ade365f46e3a3d765ac235806981f317c Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Wed, 8 Nov 2017 10:24:03 +0000
Subject: [PATCH] Unit test for face/vertex normal, and fix winding logic in
 TriangleMesh class.

---
 tests/test_mesh.py | 83 +++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 74 insertions(+), 9 deletions(-)

diff --git a/tests/test_mesh.py b/tests/test_mesh.py
index c2d234fb1..289f48a7e 100644
--- a/tests/test_mesh.py
+++ b/tests/test_mesh.py
@@ -13,7 +13,8 @@ import            tempfile
 import numpy   as np
 import            pytest
 
-import fsl.data.mesh as fslmesh
+import fsl.utils.transform as transform
+import fsl.data.mesh       as fslmesh
 
 
 datadir = op.join(op.dirname(__file__), 'testdata')
@@ -26,8 +27,8 @@ def test_create_mesh():
     #  - create from inmem data
     testbase = 'test_mesh.vtk'
     testfile = op.join(datadir, testbase)
-    
-    verts, lens, indices = fslmesh.loadVTKPolydataFile(testfile) 
+
+    verts, lens, indices = fslmesh.loadVTKPolydataFile(testfile)
 
     mesh1 = fslmesh.TriangleMesh(testfile)
     mesh2 = fslmesh.TriangleMesh(verts, indices)
@@ -36,9 +37,9 @@ def test_create_mesh():
     assert mesh2.name       == 'TriangleMesh'
     assert mesh1.dataSource == testfile
     assert mesh2.dataSource is None
-    
+
     assert mesh1.vertices.shape == (642,  3)
-    assert mesh2.vertices.shape == (642,  3) 
+    assert mesh2.vertices.shape == (642,  3)
     assert mesh1.indices.shape  == (1280, 3)
     assert mesh2.indices.shape  == (1280, 3)
 
@@ -68,7 +69,7 @@ def test_mesh_loadVertexData():
     assert np.all(mesh.getVertexData('inmemdata') == memdata)
 
     mesh.clearVertexData()
-    
+
     assert mesh.getVertexData(datafile).shape == (642,)
     assert np.all(mesh.loadVertexData('inmemdata', memdata) == memdata)
 
@@ -103,7 +104,7 @@ def test_getFIRSTPrefix():
 
     for fname, expected in passes:
         assert fslmesh.getFIRSTPrefix(fname) == expected
-            
+
 
 
 def test_findReferenceImage():
@@ -123,11 +124,75 @@ def test_findReferenceImage():
             prefix   = fslmesh.getFIRSTPrefix(fname)
             imgfname = op.join(testdir, '{}.nii.gz'.format(prefix))
             fname    = op.join(testdir, fname)
-                
+
             with open(fname,    'wt') as f: f.write(fname)
             with open(imgfname, 'wt') as f: f.write(imgfname)
 
             assert fslmesh.findReferenceImage(fname) == imgfname
-    
+
     finally:
         shutil.rmtree(testdir)
+
+
+def test_normals():
+
+    # vertices of a cube
+    verts = np.array([
+        [-1, -1, -1],
+        [-1, -1,  1],
+        [-1,  1, -1],
+        [-1,  1,  1],
+        [ 1, -1, -1],
+        [ 1, -1,  1],
+        [ 1,  1, -1],
+        [ 1,  1,  1],
+    ])
+
+    # triangles
+    # cw  == clockwise, when facing outwards
+    #        from the centre of the mesh
+    triangles_cw = np.array([
+        [0, 4, 6], [0, 6, 2],
+        [1, 3, 5], [3, 7, 5],
+        [0, 1, 4], [1, 5, 4],
+        [2, 6, 7], [2, 7, 3],
+        [0, 2, 1], [1, 2, 3],
+        [4, 5, 7], [4, 7, 6],
+    ])
+
+    # ccw == counter-clockwise
+    triangles_ccw = np.array(triangles_cw)
+    triangles_ccw[:, [1, 2]] = triangles_ccw[:, [2, 1]]
+
+    # face normals
+    fnormals = np.array([
+        [ 0,  0, -1], [ 0,  0, -1],
+        [ 0,  0,  1], [ 0,  0,  1],
+        [ 0, -1,  0], [ 0, -1,  0],
+        [ 0,  1,  0], [ 0,  1,  0],
+        [-1,  0,  0], [-1,  0,  0],
+        [ 1,  0,  0], [ 1,  0,  0],
+    ])
+
+    # vertex normals
+    vnormals = np.zeros((8, 3))
+    for i in range(8):
+        faces = np.where(triangles_cw == i)[0]
+        vnormals[i] = fnormals[faces].sum(axis=0)
+    vnormals = transform.normalise(vnormals)
+
+    cw_nofix  = fslmesh.TriangleMesh(verts, triangles_cw)
+    cw_fix    = fslmesh.TriangleMesh(verts, triangles_cw, fixWinding=True)
+    ccw_nofix = fslmesh.TriangleMesh(verts, triangles_ccw)
+    ccw_fix   = fslmesh.TriangleMesh(verts, triangles_ccw, fixWinding=True)
+
+    # ccw triangles should give correct
+    # normals without unwinding
+    assert np.all(np.isclose(cw_nofix .normals,  -fnormals))
+    assert np.all(np.isclose(cw_nofix .vnormals, -vnormals))
+    assert np.all(np.isclose(cw_fix   .normals,   fnormals))
+    assert np.all(np.isclose(cw_fix   .vnormals,  vnormals))
+    assert np.all(np.isclose(ccw_nofix.normals,   fnormals))
+    assert np.all(np.isclose(ccw_nofix.vnormals, vnormals))
+    assert np.all(np.isclose(ccw_fix  .normals,   fnormals))
+    assert np.all(np.isclose(ccw_fix  .vnormals,  vnormals))
-- 
GitLab