mesh.py 14.3 KB
Newer Older
1
2
#!/usr/bin/env python
#
3
# mesh.py - The TriangleMesh class.
4
5
6
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
7
8
"""This module provides the :class:`TriangleMesh` class, which represents a
3D model made of triangles.
9

10
.. note:: I/O support is very limited - currently, the only supported file
11
          type is the VTK legacy file format, containing the ``POLYDATA``
12
13
14
          dataset. The :class:`TriangleMesh` class assumes that every polygon
          defined in an input file is a triangle (i.e. refers to three
          vertices).
15
16
17

          See http://www.vtk.org/wp-content/uploads/2015/04/file-formats.pdf
          for an overview of the VTK legacy file format.
18
19

          In the future, I may or may not add support for more complex meshes.
20
21
"""

22

23
24
import logging

25
26
import os.path as op
import numpy   as np
27

Paul McCarthy's avatar
Paul McCarthy committed
28
29
import six

30
import fsl.utils.memoize   as memoize
31
32
import fsl.utils.transform as transform

33
34
from . import image as fslimage

35
36
37
38

# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)


39
class TriangleMesh(object):
    """The ``TriangleMesh`` class represents a 3D model. A mesh is defined by a
    collection of ``N`` vertices, and ``M`` triangles.  The triangles are
    defined by ``(M, 3)`` indices into the list of vertices.


    A ``TriangleMesh`` instance has the following attributes:


    ============== ====================================================
    ``name``       A name - the base name of the mesh file, or
                   ``'TriangleMesh'`` if there is no file.

    ``dataSource`` Full path to the mesh file (or ``None`` if there is
                   no file associated with this mesh).

    ``vertices``   A :math:`N\\times 3` ``numpy`` array containing
                   the vertices.

    ``indices``    A :math:`M\\times 3` ``numpy`` array containing
                   the vertex indices for :math:`M` triangles

    ``normals``    A :math:`M\\times 3` ``numpy`` array containing
                   face normals.

    ``vnormals``   A :math:`N\\times 3` ``numpy`` array containing
                   vertex normals.
    ============== ====================================================


    And the following methods:

    .. autosummary::
       :nosignatures:

       trimesh
       rayIntersection
       getBounds
       loadVertexData
       getVertexData
       clearVertexData
    """


    def __init__(self, data, indices=None, fixWinding=False):
        """Create a ``TriangleMesh`` instance.

        :arg data:       Can either be a file name, or a :math:`N\\times 3`
                         ``numpy`` array containing vertex data. If ``data``
                         is a file name, it is passed to the
                         :func:`loadVTKPolydataFile` function.

        :arg indices:    A list of indices into the vertex data, defining
                         the triangles. Ignored if ``data`` is a file name.
                         If not provided, consecutive vertices are grouped
                         into triangles.

        :arg fixWinding: Defaults to ``False``. If ``True``, the vertex
                         winding order of every triangle is fixed so they
                         all have outward-facing normal vectors.

        :raises RuntimeError: If ``data`` is a file which contains a
                              polygon that is not a triangle.
        """

        if isinstance(data, six.string_types):
            infile = data
            data, lengths, indices = loadVTKPolydataFile(infile)

            if np.any(lengths != 3):
                raise RuntimeError('All polygons in VTK file must be '
                                   'triangles ({})'.format(infile))

            self.name       = op.basename(infile)
            self.dataSource = infile
        else:
            self.name       = 'TriangleMesh'
            self.dataSource = None

        # Take a copy as an array up front, so
        # that plain sequences are accepted too
        # (.shape is needed just below).
        data = np.array(data)

        if indices is None:
            indices = np.arange(data.shape[0])

        self.__vertices     = data
        self.__indices      = np.array(indices).reshape((-1, 3))

        self.__vertexData = {}

        # Normals and the trimesh object are
        # created lazily, on first access.
        self.__faceNormals = None
        self.__vertNormals = None
        self.__trimesh     = None
        self.__loBounds    = self.vertices.min(axis=0)
        self.__hiBounds    = self.vertices.max(axis=0)

        if fixWinding:
            self.__fixWindingOrder()


    def __repr__(self):
        """Returns a string representation of this ``TriangleMesh`` instance.
        """
        return '{}({}, {})'.format(type(self).__name__,
                                   self.name,
                                   self.dataSource)


    def __str__(self):
        """Returns a string representation of this ``TriangleMesh`` instance.
        """
        return self.__repr__()


    def __fixWindingOrder(self):
        """Called by :meth:`__init__` if ``fixWinding is True``.  Fixes the
        mesh triangle winding order so that all face normals are facing
        outwards from the centre of the mesh.

        A single triangle is tested; if its normal points away from the
        viewpoint, the winding order of *every* triangle is flipped.
        """

        # Define a viewpoint which is outside
        # of the mesh bounding box - offset
        # from the low corner by the box size.
        fnormals = self.normals
        camera   = self.__loBounds - (self.__hiBounds - self.__loBounds)

        # Find the nearest vertex
        # to the viewpoint
        dists = np.sqrt(np.sum((self.vertices - camera) ** 2, axis=1))
        ivert = np.argmin(dists)
        vert  = self.vertices[ivert]

        # Pick a triangle that
        # this vertex is in, and
        # get its face normal
        itri = np.where(self.indices == ivert)[0][0]
        n    = fnormals[itri, :]

        # The normal should point towards the
        # camera, i.e. its dot product with a
        # vector from the vertex to the camera
        # should be positive. If it isn't, flip
        # the triangle winding order.
        if np.dot(n, transform.normalise(camera - vert)) < 0:
            self.indices[:, [1, 2]] = self.indices[:, [2, 1]]
            self.__faceNormals     *= -1

            # Any cached vertex normals were
            # derived from the un-flipped face
            # normals, so must be re-calculated.
            self.__vertNormals      = None


    @property
    def vertices(self):
        """The ``(N, 3)`` vertices of this mesh. """
        return self.__vertices


    @property
    def indices(self):
        """The ``(M, 3)`` triangles of this mesh. """
        return self.__indices


    @property
    def normals(self):
        """A ``(M, 3)`` array containing surface normals for every
        triangle in the mesh, normalised to unit length.

        The normals are calculated on first access, and cached.
        """

        if self.__faceNormals is not None:
            return self.__faceNormals

        v0 = self.vertices[self.indices[:, 0]]
        v1 = self.vertices[self.indices[:, 1]]
        v2 = self.vertices[self.indices[:, 2]]

        # The direction of the normal
        # depends on the winding order
        n = np.cross((v1 - v0), (v2 - v0))

        self.__faceNormals = transform.normalise(n)

        return self.__faceNormals


    @property
    def vnormals(self):
        """A ``(N, 3)`` array containing normals for every vertex
        in the mesh.

        Each vertex normal is the normalised sum of the face normals of
        all triangles which the vertex is a part of. The normals are
        calculated on first access, and cached.
        """
        if self.__vertNormals is not None:
            return self.__vertNormals

        # per-face normals
        fnormals = self.normals
        vnormals = np.zeros((self.vertices.shape[0], 3), dtype=np.float64)

        # Accumulate the face normal of each
        # triangle onto each of its three
        # vertices. np.add.at is unbuffered,
        # so vertices which occur in several
        # triangles are accumulated correctly
        # (a plain "vnormals[idxs] += ..."
        # would only apply one of them).
        for corner in range(3):
            np.add.at(vnormals, self.indices[:, corner], fnormals)

        # normalise to unit length
        self.__vertNormals = transform.normalise(vnormals)

        return self.__vertNormals


    @memoize.Instanceify(memoize.memoize)
    def trimesh(self):
        """Reference to a ``trimesh.Trimesh`` object which can be used for
        geometric operations on the mesh.

        If the ``trimesh`` or ``rtree`` libraries are not available, this
        function returns ``None``
        """

        # trimesh is an optional dependency - rtree
        # is a dependency of trimesh which is a
        # wrapper around libspatialindex, without
        # which trimesh can't be used for calculating
        # ray-mesh intersections.
        try:
            import trimesh
            import rtree   # noqa
        except ImportError:
            log.warning('trimesh is not available')
            return None

        # n.b. do not test for this attribute with
        # hasattr(self, '__trimesh') - the attribute
        # name is mangled, but the string is not, so
        # such a test would never succeed.
        if self.__trimesh is None:
            self.__trimesh = trimesh.Trimesh(self.__vertices,
                                             self.__indices,
                                             process=False,
                                             validate=False)

        return self.__trimesh


    def rayIntersection(self, origins, directions, vertices=False):
        """Calculate the intersection between the mesh, and the rays defined by
        ``origins`` and ``directions``.

        :arg origins:    Sequence of ray origins
        :arg directions: Sequence of ray directions
        :arg vertices:   Currently unused.

        :returns:        A tuple containing:

                           - A ``(n, 3)`` array containing the coordinates
                             where the mesh was intersected by each of the
                             ``n`` rays.

                           - A ``(n,)`` array containing the indices of the
                             triangles that were intersected by each of the
                             ``n`` rays.

                         Two empty arrays are returned if ``trimesh`` is
                         not available (see :meth:`trimesh`), or if there
                         were no intersections.
        """

        trimesh = self.trimesh()

        if trimesh is None:
            return np.zeros((0, 3)), np.zeros((0,))

        tris, rays, locs = trimesh.ray.intersects_id(
            origins,
            directions,
            return_locations=True,
            multiple_hits=False)

        if tris.size == 0:
            return np.zeros((0, 3)), np.zeros((0,))

        # sort by ray. I'm Not sure if this is
        # needed - does trimesh do it for us?
        # (argsort returns integer indices, so
        # no type conversion is required)
        rayIdxs = np.argsort(rays)
        locs    = locs[rayIdxs]
        tris    = tris[rayIdxs]

        return locs, tris


    def getBounds(self):
        """Returns a tuple of values which define a minimal bounding box that
        will contain all vertices in this ``TriangleMesh`` instance. The
        bounding box is arranged like so:

            ``((xlow, ylow, zlow), (xhigh, yhigh, zhigh))``
        """
        return (self.__loBounds, self.__hiBounds)


    def loadVertexData(self, dataSource, vertexData=None):
        """Attempts to load scalar data associated with each vertex of this
        ``TriangleMesh`` from the given ``dataSource``. The data is returned,
        and also stored in an internal cache so it can be retrieved later
        via the :meth:`getVertexData` method.

        This method may be overridden by sub-classes.

        :arg dataSource: Path to the vertex data to load. Also used as the
                         key under which the data is cached.
        :arg vertexData: The vertex data itself, if it has already been
                         loaded. It is cached and returned as-is.

        :returns: The vertex data. When loaded from ``dataSource``, this is
                  a ``(M, N)`` array, which contains ``N`` data points
                  for ``M`` vertices.

        :raises ValueError: If the first dimension of the data does not
                            match the number of vertices in this mesh.
        """

        nvertices = self.vertices.shape[0]
        fromFile  = vertexData is None

        # Currently only white-space delimited
        # text files are supported. loadtxt may
        # return a 0D array for a single value.
        if fromFile:
            vertexData = np.atleast_1d(np.loadtxt(dataSource))

        if vertexData.shape[0] != nvertices:
            raise ValueError('Incompatible size: {}'.format(dataSource))

        # Data with one value per vertex is
        # loaded as a 1D array - make sure
        # that it is always (nvertices, N).
        # reshape returns a new array, so the
        # result must be assigned.
        if fromFile:
            vertexData = vertexData.reshape(nvertices, -1)

        self.__vertexData[dataSource] = vertexData

        return vertexData


    def getVertexData(self, dataSource):
        """Returns the vertex data for the given ``dataSource`` from the
        internal vertex data cache. If the given ``dataSource`` is not
        in the cache, it is loaded via :meth:`loadVertexData`.
        """

        try:
            return self.__vertexData[dataSource]
        except KeyError:
            return self.loadVertexData(dataSource)


    def clearVertexData(self):
        """Clears the internal vertex data cache - see the
        :meth:`loadVertexData` and :meth:`getVertexData`  methods.
        """

        self.__vertexData = {}
364
365


366
ALLOWED_EXTENSIONS     = ['.vtk']
367
368
"""A list of file extensions which could contain :class:`TriangleMesh` data.
"""
369
370
371
372
373
374
375
376
377
378
379
380


EXTENSION_DESCRIPTIONS = ['VTK polygon model file']
"""A description for each of the extensions in :data:`ALLOWED_EXTENSIONS`."""


def loadVTKPolydataFile(infile):
    """Parses a legacy-format VTK file which contains a ``POLYDATA`` data
    set.

    The file is expected to be laid out with ``DATASET POLYDATA`` on the
    fourth line, the vertex count on the fifth line, one vertex per line
    after that, then a polygon header line followed by one polygon per
    line.

    :arg infile: Name of a file to load from.

    :returns: a tuple containing three values:

                - A :math:`N\\times 3` ``numpy`` array containing :math:`N`
                  vertices.
                - A 1D ``numpy`` array containing the lengths of each polygon.
                - A 1D ``numpy`` array containing the vertex indices for all
                  polygons.

    :raises ValueError: If the fourth line of the file is not
                        ``DATASET POLYDATA``.
    """

    with open(infile, 'rt') as f:
        rows = [row.strip() for row in f.readlines()]

    if rows[3] != 'DATASET POLYDATA':
        raise ValueError('Only the POLYDATA data type is supported')

    # Vertex header is on the fifth line; the
    # polygon header follows the vertex lines.
    nVertices  = int(rows[4].split()[1])
    polyHeader = rows[5 + nVertices].split()
    nPolygons  = int(polyHeader[1])

    # The third field of the polygon header counts
    # the per-polygon length values as well as the
    # indices, so remove one entry per polygon.
    nIndices   = int(polyHeader[2]) - nPolygons

    vertices       = np.zeros((nVertices, 3), dtype=np.float32)
    polygonLengths = np.zeros( nPolygons,     dtype=np.uint32)
    indices        = np.zeros( nIndices,      dtype=np.uint32)

    # One vertex per line
    for vi in range(nVertices):
        vertices[vi, :] = [float(tok) for tok in rows[5 + vi].split()]

    # One polygon per line - the polygon length,
    # followed by that many vertex indices, which
    # are packed end-to-end into the indices array.
    firstPoly = 6 + nVertices
    cursor    = 0

    for pi in range(nPolygons):

        fields             = rows[firstPoly + pi].split()
        polygonLengths[pi] = int(fields[0])

        stop                 = cursor + polygonLengths[pi]
        indices[cursor:stop] = [int(tok) for tok in fields[1:]]
        cursor               = stop

    return vertices, polygonLengths, indices
424
425
426
427
428


def getFIRSTPrefix(modelfile):
    """Returns the file prefix of a ``vtk`` file which was generated by
    `FIRST <https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/FIRST>`_.

    The prefix is the part of the file base name which precedes the final
    ``'-'`` character (an empty string if there is no ``'-'``).

    :arg modelfile: Path to the ``vtk`` file.

    :raises ValueError: If the file name does not end with ``first.vtk``.
    """

    if modelfile.endswith('first.vtk'):
        # Everything in the base name
        # before the last hyphen
        return op.basename(modelfile).rpartition('-')[0]

    raise ValueError('Not a first vtk file: {}'.format(modelfile))


def findReferenceImage(modelfile):
    """Searches for a ``NIFTI`` image file which corresponds to the given
    ``vtk`` file. Returns the path to the image, or ``None`` if no image
    was found.

    Currently an image will only be found for ``vtk`` files generated by
    `FIRST <https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/FIRST>`_ - ``None`` is
    returned for any file that :func:`getFIRSTPrefix` rejects.

    The image is looked for in the same directory as the ``vtk`` file,
    named after the FIRST prefix, and then - if the prefix ends with
    ``_first`` - after the prefix with that ending removed.
    """

    try:
        dirname = op.dirname(modelfile)
        prefix  = getFIRSTPrefix(modelfile)

    except ValueError:
        return None

    # Candidate image names, in order of preference
    candidates = [prefix]

    if prefix.endswith('_first'):
        candidates.append(prefix[:-len('_first')])

    for candidate in candidates:
        try:
            return fslimage.addExt(op.join(dirname, candidate),
                                   mustExist=True)
        except fslimage.PathError:
            # No image with this name -
            # try the next candidate
            pass

    return None