#!/usr/bin/env python
#
# mesh.py - The TriangleMesh class.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module provides the :class:`TriangleMesh` class, which represents a
3D model made of triangles.

.. note:: I/O support is very limited - currently, the only supported file
          type is the VTK legacy file format, containing the ``POLYDATA``
          dataset. The :class:`TriangleMesh` class assumes that every polygon
          defined in an input file is a triangle (i.e. refers to three
          vertices).

          See http://www.vtk.org/wp-content/uploads/2015/04/file-formats.pdf
          for an overview of the VTK legacy file format.

          In the future, I may or may not add support for more complex meshes.
"""

22

23
24
import logging

25
26
import os.path as op
import numpy   as np
27

Paul McCarthy's avatar
Paul McCarthy committed
28
29
import six

30
31
import fsl.utils.transform as transform

32
33
from . import image as fslimage

34
35
36
37

log = logging.getLogger(__name__)


38
class TriangleMesh(object):
    """The ``TriangleMesh`` class represents a 3D model. A mesh is defined by a
    collection of ``N`` vertices, and ``M`` triangles.  The triangles are
    defined by ``(M, 3)`` indices into the list of vertices.


    A ``TriangleMesh`` instance has the following attributes:


    ============== ====================================================
    ``name``       A name - the base name of the mesh file, or
                   ``'TriangleMesh'`` if the mesh was created from
                   in-memory vertex data.

    ``dataSource`` Path to the mesh file, as passed to
                   :meth:`__init__` (or ``None`` if there is no file
                   associated with this mesh).

    ``vertices``   A :math:`N\\times 3` ``numpy`` array containing
                   the vertices.

    ``indices``    A :math:`M\\times 3` ``numpy`` array containing
                   the vertex indices for :math:`M` triangles.

    ``normals``    A :math:`M\\times 3` ``numpy`` array containing
                   face normals.

    ``vnormals``   A :math:`N\\times 3` ``numpy`` array containing
                   vertex normals.
    ============== ====================================================


    And the following methods:

    .. autosummary::
       :nosignatures:

       getBounds
       loadVertexData
       getVertexData
       clearVertexData
    """


    def __init__(self, data, indices=None, fixWinding=False):
        """Create a ``TriangleMesh`` instance.

        :arg data:       Can either be a file name, or a :math:`N\\times 3`
                         ``numpy`` array (or array-like) containing vertex
                         data. If ``data`` is a file name, it is passed to
                         the :func:`loadVTKPolydataFile` function, and the
                         ``indices`` argument is ignored.

        :arg indices:    A list of indices into the vertex data, defining
                         the triangles. If not provided, consecutive
                         vertices are assumed to form triangles.

        :arg fixWinding: Defaults to ``False``. If ``True``, the vertex
                         winding order of every triangle is fixed so they
                         all have outward-facing normal vectors.

        :raises RuntimeError: If ``data`` is a file which contains polygons
                              that are not triangles.
        """

        if isinstance(data, six.string_types):
            infile = data
            data, lengths, indices = loadVTKPolydataFile(infile)

            if np.any(lengths != 3):
                raise RuntimeError('All polygons in VTK file must be '
                                   'triangles ({})'.format(infile))

            self.name       = op.basename(infile)
            self.dataSource = infile
        else:
            self.name       = 'TriangleMesh'
            self.dataSource = None

        # Convert to an array before looking at the
        # shape, so that plain sequences of vertices
        # are accepted as well as numpy arrays.
        vertices = np.array(data)

        if indices is None:
            indices = np.arange(vertices.shape[0])

        self.__vertices   = vertices
        self.__indices    = np.array(indices).reshape((-1, 3))

        # Cache of {dataSource : vertexData} mappings -
        # see loadVertexData/getVertexData.
        self.__vertexData = {}

        # Face/vertex normals are calculated
        # on first access, and then cached.
        self.__faceNormals = None
        self.__vertNormals = None
        self.__loBounds    = self.vertices.min(axis=0)
        self.__hiBounds    = self.vertices.max(axis=0)

        if fixWinding:
            self.__fixWindingOrder()


    def __repr__(self):
        """Returns a string representation of this ``TriangleMesh`` instance.
        """
        return '{}({}, {})'.format(type(self).__name__,
                                   self.name,
                                   self.dataSource)


    def __str__(self):
        """Returns a string representation of this ``TriangleMesh`` instance.
        """
        return self.__repr__()


    @property
    def vertices(self):
        """The ``(N, 3)`` vertices of this mesh. """
        return self.__vertices


    @property
    def indices(self):
        """The ``(M, 3)`` triangles of this mesh. """
        return self.__indices


    def __fixWindingOrder(self):
        """Called by :meth:`__init__` if ``fixWinding is True``.  Fixes the
        mesh triangle winding order so that all face normals are facing
        outwards from the centre of the mesh.

        Only a single triangle is tested - if its normal faces inwards, the
        winding order of *every* triangle is flipped.
        """

        # Define a viewpoint which is
        # far away from the mesh.
        fnormals = self.normals
        camera   = self.__loBounds - (self.__hiBounds - self.__loBounds)

        # Find the nearest vertex
        # to the viewpoint
        dists = np.sqrt(np.sum((self.vertices - camera) ** 2, axis=1))
        ivert = np.argmin(dists)
        vert  = self.vertices[ivert]

        # Pick a triangle that
        # this vertex is in and
        # get its face normal
        itri = np.where(self.indices == ivert)[0][0]
        n    = fnormals[itri, :]

        # Make sure the angle between the
        # normal, and a vector from the
        # vertex to the camera is positive
        # If it isn't, flip the triangle
        # winding order (and the cached
        # face normals along with it).
        if np.dot(n, transform.normalise(camera - vert)) < 0:
            self.indices[:, [1, 2]] = self.indices[:, [2, 1]]
            self.__faceNormals     *= -1


    @property
    def normals(self):
        """A ``(M, 3)`` array containing surface normals for every
        triangle in the mesh, normalised to unit length.

        The normals are calculated on first access, and cached.
        """

        if self.__faceNormals is not None:
            return self.__faceNormals

        v0 = self.vertices[self.indices[:, 0]]
        v1 = self.vertices[self.indices[:, 1]]
        v2 = self.vertices[self.indices[:, 2]]

        n = np.cross((v1 - v0), (v2 - v0))

        self.__faceNormals = transform.normalise(n)

        return self.__faceNormals


    @property
    def vnormals(self):
        """A ``(N, 3)`` array containing normals for every vertex
        in the mesh.

        The normal of a vertex is the (normalised) sum of the face normals
        of all triangles which refer to that vertex. The normals are
        calculated on first access, and cached.
        """
        if self.__vertNormals is not None:
            return self.__vertNormals

        # per-face normals
        fnormals = self.normals

        # np.float64 is used here, as the np.float
        # alias has been removed from numpy.
        vnormals = np.zeros((self.vertices.shape[0], 3), dtype=np.float64)

        # Add the normal of every triangle onto each
        # of its three vertices. np.add.at performs
        # unbuffered addition, so a vertex which is
        # referred to by many triangles receives a
        # contribution from all of them. The (M, 1, 3)
        # normals are broadcast against the (M, 3)
        # index array.
        np.add.at(vnormals, self.indices, fnormals[:, np.newaxis, :])

        # normalise to unit length
        self.__vertNormals = transform.normalise(vnormals)

        return self.__vertNormals


    def getBounds(self):
        """Returns a tuple of values which define a minimal bounding box that
        will contain all vertices in this ``TriangleMesh`` instance. The
        bounding box is arranged like so:

            ``((xlow, ylow, zlow), (xhigh, yhigh, zhigh))``
        """
        return (self.__loBounds, self.__hiBounds)


    def loadVertexData(self, dataSource, vertexData=None):
        """Attempts to load scalar data associated with each vertex of this
        ``TriangleMesh`` from the given ``dataSource``. The data is returned,
        and also stored in an internal cache so it can be retrieved later
        via the :meth:`getVertexData` method.

        This method may be overridden by sub-classes.

        :arg dataSource: Path to the vertex data to load. Also used as the
                         key under which the data is cached.
        :arg vertexData: The vertex data itself, if it has already been
                         loaded. If provided, it is cached as-is, and
                         ``dataSource`` is not read.

        :returns: A ``(M, N)`` array, which contains ``N`` data points
                  for ``M`` vertices.

        :raises ValueError: If the number of rows in the vertex data does
                            not match the number of vertices in this mesh.
        """

        nvertices = self.vertices.shape[0]

        # Currently only white-space delimited
        # text files are supported
        if vertexData is None:
            vertexData = np.loadtxt(dataSource)

            # reshape returns a new array rather than
            # working in-place, so the result must be
            # kept - otherwise files containing a single
            # value per vertex would be left as 1D.
            vertexData = vertexData.reshape(nvertices, -1)

        if vertexData.shape[0] != nvertices:
            raise ValueError('Incompatible size: {}'.format(dataSource))

        self.__vertexData[dataSource] = vertexData

        return vertexData


    def getVertexData(self, dataSource):
        """Returns the vertex data for the given ``dataSource`` from the
        internal vertex data cache. If the given ``dataSource`` is not
        in the cache, it is loaded via :meth:`loadVertexData`.
        """

        try:             return self.__vertexData[dataSource]
        except KeyError: return self.loadVertexData(dataSource)


    def clearVertexData(self):
        """Clears the internal vertex data cache - see the
        :meth:`loadVertexData` and :meth:`getVertexData`  methods.
        """

        self.__vertexData = {}
291
292


293
# File name suffixes (including the leading dot) of mesh files.
ALLOWED_EXTENSIONS = ['.vtk']
"""File extensions of files which could contain :class:`TriangleMesh` data.
"""


# Kept parallel to ALLOWED_EXTENSIONS - one description per extension.
EXTENSION_DESCRIPTIONS = ['VTK polygon model file']
"""A description for each of the extensions in :data:`ALLOWED_EXTENSIONS`,
in the same order.
"""


def loadVTKPolydataFile(infile):
    """Loads a vtk legacy file containing a ``POLYDATA`` data set.

    The parser is deliberately simple, and expects a fixed layout: the
    ``DATASET POLYDATA`` declaration on the fourth line, the points header
    on the fifth line, one vertex per line after that, then the polygons
    header, followed by one polygon per line.

    :arg infile: Name of a file to load from.

    :returns: a tuple containing three values:

                - A :math:`N\\times 3` ``numpy`` array containing :math:`N`
                  vertices.
                - A 1D ``numpy`` array containing the lengths of each polygon.
                - A 1D ``numpy`` array containing the vertex indices for all
                  polygons.

    :raises ValueError: If the file does not declare a ``POLYDATA`` data set.
    """

    with open(infile, 'rt') as f:
        lines = [line.strip() for line in f.readlines()]

    if lines[3] != 'DATASET POLYDATA':
        raise ValueError('Only the POLYDATA data type is supported')

    # The points header has the form "POINTS <count> <type>"
    nVertices  = int(lines[4].split()[1])

    # The polygons header follows the vertex lines, and has the
    # form "POLYGONS <count> <size>", where size is the total
    # number of values on the polygon lines - one length value
    # per polygon, plus all of the vertex indices.
    polyHeader = lines[5 + nVertices].split()
    nPolygons  = int(polyHeader[1])
    nIndices   = int(polyHeader[2]) - nPolygons

    vertices       = np.zeros((nVertices, 3), dtype=np.float32)
    polygonLengths = np.zeros(nPolygons, dtype=np.uint32)
    indices        = np.zeros(nIndices, dtype=np.uint32)

    for row, vertLine in enumerate(lines[5:5 + nVertices]):
        vertices[row, :] = [float(tok) for tok in vertLine.split()]

    # The indices of all polygons are packed
    # end-to-end into one flat array.
    firstPoly = 6 + nVertices
    start     = 0
    for i in range(nPolygons):

        tokens            = lines[firstPoly + i].split()
        polygonLengths[i] = int(tokens[0])

        end                = start + polygonLengths[i]
        indices[start:end] = [int(tok) for tok in tokens[1:]]
        start              = end

    return vertices, polygonLengths, indices
351
352
353
354
355


def getFIRSTPrefix(modelfile):
    """If the given ``vtk`` file was generated by `FIRST
    <https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/FIRST>`_, this function
    will return the file prefix. Otherwise a ``ValueError`` will be
    raised.

    A file is considered to be a FIRST file if its name ends with
    ``first.vtk``. The prefix is the part of the file name (without any
    directory) which precedes the last ``-`` character - an empty string
    if the file name does not contain a ``-``.
    """

    if not modelfile.endswith('first.vtk'):
        raise ValueError('Not a first vtk file: {}'.format(modelfile))

    # rpartition splits on the last '-', and gives an
    # empty head when there is no '-' in the name.
    head, _, _ = op.basename(modelfile).rpartition('-')

    return head


def findReferenceImage(modelfile):
    """Given a ``vtk`` file, attempts to find a corresponding ``NIFTI``
    image file. Returns the result of ``fslimage.addExt`` for the first
    candidate which exists, or ``None`` if no image was found.

    Currently this function will only return an image for ``vtk`` files
    generated by `FIRST <https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/FIRST>`_.

    Candidate images are looked for in the same directory as the ``vtk``
    file. The candidates are the prefix returned by :func:`getFIRSTPrefix`
    and, if that prefix ends with ``_first``, the prefix with that suffix
    removed.
    """

    dirname = op.dirname(modelfile)

    # Anything that is not a FIRST model
    # file has no reference image.
    try:
        prefix = getFIRSTPrefix(modelfile)
    except ValueError:
        return None

    candidates = [prefix]

    if prefix.endswith('_first'):
        candidates.append(prefix[:-len('_first')])

    for candidate in candidates:
        try:
            return fslimage.addExt(op.join(dirname, candidate),
                                   mustExist=True)
        except fslimage.PathError:
            # No image with this prefix - try the next one.
            pass

    return None