diff --git a/fsl/data/gifti.py b/fsl/data/gifti.py index c27054ab8ba3a7be26f2f98d54b78b9609c07b0d..7bc088d1735cce7c02309b4be66e30d10cc5b86a 100644 --- a/fsl/data/gifti.py +++ b/fsl/data/gifti.py @@ -49,7 +49,7 @@ class GiftiSurface(mesh.TriangleMesh): """ - def __init__(self, infile): + def __init__(self, infile, fixWinding=False): """Load the given GIFTI file using ``nibabel``, and extracts surface data using the :func:`loadGiftiSurface` function. @@ -61,7 +61,7 @@ class GiftiSurface(mesh.TriangleMesh): surfimg, vertices, indices = loadGiftiSurface(infile) - mesh.TriangleMesh.__init__(self, vertices, indices) + mesh.TriangleMesh.__init__(self, vertices, indices, fixWinding) name = fslpath.removeExt(op.basename(infile), ALLOWED_EXTENSIONS) infile = op.abspath(infile) diff --git a/fsl/data/mesh.py b/fsl/data/mesh.py index 82023debde2dc67f7d64bbdc6547c09b92b42955..c3c701232406c37d3bdd0ddff8e54cd63d3fab28 100644 --- a/fsl/data/mesh.py +++ b/fsl/data/mesh.py @@ -27,6 +27,8 @@ import numpy as np import six +import fsl.utils.transform as transform + from . import image as fslimage @@ -34,9 +36,9 @@ log = logging.getLogger(__name__) class TriangleMesh(object): - """The ``TriangleMesh`` class represents a 3D model. A mesh is defined by - a collection of vertices and indices. The indices index into the list of - vertices, and define a set of triangles which make the model. + """The ``TriangleMesh`` class represents a 3D model. A mesh is defined by a + collection of ``N`` vertices, and ``M`` triangles. The triangles are + defined by ``(M, 3)) indices into the list of vertices. A ``TriangleMesh`` instance has the following attributes: @@ -74,16 +76,20 @@ class TriangleMesh(object): """ - def __init__(self, data, indices=None): + def __init__(self, data, indices=None, fixWinding=False): """Create a ``TriangleMesh`` instance. - :arg data: Can either be a file name, or a :math:`N\\times 3` - ``numpy`` array containing vertex data. If ``data`` is - a file name, it is passed to the - :func:`loadVTKPolydataFile` function. + :arg data: Can either be a file name, or a :math:`N\\times 3` + ``numpy`` array containing vertex data. If ``data`` + is a file name, it is passed to the + :func:`loadVTKPolydataFile` function. + + :arg indices: A list of indices into the vertex data, defining + the triangles. - :arg indices: A list of indices into the vertex data, defining - the triangles. + :arg fixWinding: Defaults to ``False``. If ``True``, the vertex + winding order of every triangle is is fixed so they + all have outward-facing normal vectors. """ if isinstance(data, six.string_types): @@ -103,8 +109,8 @@ class TriangleMesh(object): if indices is None: indices = np.arange(data.shape[0]) - self.vertices = np.array(data) - self.indices = np.array(indices).reshape((-1, 3)) + self.__vertices = np.array(data) + self.__indices = np.array(indices).reshape((-1, 3)) self.__vertexData = {} self.__faceNormals = None @@ -112,6 +118,9 @@ class TriangleMesh(object): self.__loBounds = self.vertices.min(axis=0) self.__hiBounds = self.vertices.max(axis=0) + if fixWinding: + self.__fixWindingOrder() + def __repr__(self): """Returns a string representation of this ``TriangleMesh`` instance. @@ -127,35 +136,73 @@ class TriangleMesh(object): @property - def normals(self): - """Returns a ``(M, 3)`` array containing normals for every triangle - in the mesh. + def vertices(self): + """The ``(N, 3)`` vertices of this mesh. """ + return self.__vertices + + + @property + def indices(self): + """The ``(M, 3)`` triangles of this mesh. """ + return self.__indices + + + def __fixWindingOrder(self): + """Called by :meth:`__init__` if ``fixWinding is True``. Fixes the + mesh triangle winding order so that all face normals are facing + outwards from the centre of the mesh. """ - if self.__faceNormals is not None: - return self.__faceNormals - v1 = self.vertices[self.indices[:, 0]] - v2 = self.vertices[self.indices[:, 1]] + # Define a viewpoint which is + # far away from the mesh. + fnormals = self.normals + camera = self.__loBounds - (self.__hiBounds - self.__loBounds) + + # Find the nearest vertex + # to the viewpoint + dists = np.sqrt(np.sum((self.vertices - camera) ** 2, axis=1)) + ivert = np.argmin(dists) + vert = self.vertices[ivert] + + # Pick a triangle that + # this vertex in and + # ges its face normal + itri = np.where(self.indices == ivert)[0][0] + n = fnormals[itri, :] + + # Make sure the angle between the + # normal, and a vector from the + # vertex to the camera is positive + # If it isn't, flip the triangle + # winding order. + if np.dot(n, transform.normalise(camera - vert)) < 0: + self.indices[:, [1, 2]] = self.indices[:, [2, 1]] + self.__faceNormals *= -1 - fnormals = np.cross(v1, v2) - # TODO make fast, and make sure that this actually works. - for i in range(self.indices.shape[0]): + @property + def normals(self): + """A ``(M, 3)`` array containing surface normals for every + triangle in the mesh, normalised to unit length. + """ - p0, p1, p2 = self.vertices[self.indices[i], :] - n = fnormals[i, :] + if self.__faceNormals is not None: + return self.__faceNormals + + v0 = self.vertices[self.indices[:, 0]] + v1 = self.vertices[self.indices[:, 1]] + v2 = self.vertices[self.indices[:, 2]] - if np.dot(n, np.cross(p1 - p0, p2 - p0)) > 0: - fnormals[i, :] *= -1 + n = np.cross((v1 - v0), (v2 - v0)) - self.__faceNormals = fnormals + self.__faceNormals = transform.normalise(n) return self.__faceNormals @property def vnormals(self): - """Returns a ``(N, 3)`` array containing normals for every vertex + """A ``(N, 3)`` array containing normals for every vertex in the mesh. """ if self.__vertNormals is not None: @@ -163,10 +210,12 @@ class TriangleMesh(object): # per-face normals fnormals = self.normals - vnormals = np.zeros((self.vertices.shape[0], 3), dtype=np.float) - # TODO make fast + # TODO make fast. I can't figure + # out how to use np.add.at to + # accumulate the face normals for + # each vertex. for i in range(self.indices.shape[0]): v0, v1, v2 = self.indices[i] @@ -176,10 +225,7 @@ class TriangleMesh(object): vnormals[v2, :] += fnormals[i] # normalise to unit length - lens = np.sqrt(np.sum(vnormals * vnormals, axis=1)) - vnormals = (vnormals.T / lens).T - - self.__vertNormals = vnormals + self.__vertNormals = transform.normalise(vnormals) return self.__vertNormals