#!/usr/bin/env python
#
# mesh.py - The TriangleMesh class.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module provides the :class:`TriangleMesh` class, which represents a
3D model made of triangles.

.. note:: I/O support is very limited - currently, the only supported file
          type is the VTK legacy file format, containing the ``POLYDATA``
          dataset. The :class:`TriangleMesh` class requires that every
          polygon defined in an input file is a triangle (i.e. refers to
          three vertices).

          See http://www.vtk.org/wp-content/uploads/2015/04/file-formats.pdf
          for an overview of the VTK legacy file format.

          In the future, I may or may not add support for more complex meshes.
"""

22

23
24
import logging

25
26
import os.path as op
import numpy   as np
27

Paul McCarthy's avatar
Paul McCarthy committed
28
29
import six

30
import fsl.utils.meta      as meta
31
import fsl.utils.memoize   as memoize
32
33
import fsl.utils.transform as transform

34
35
from . import image as fslimage

36
37
38
39

# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)


40
class TriangleMesh(meta.Meta):
    """The ``TriangleMesh`` class represents a 3D model. A mesh is defined by a
    collection of ``N`` vertices, and ``M`` triangles.  The triangles are
    defined by ``(M, 3)`` indices into the list of vertices.


    A ``TriangleMesh`` instance has the following attributes:


    ============== ====================================================
    ``name``       The base name (including suffix) of the mesh file,
                   or ``'TriangleMesh'`` if the mesh was created from
                   in-memory data.

    ``dataSource`` Full path to the mesh file (or ``None`` if there is
                   no file associated with this mesh).

    ``vertices``   A :math:`N\\times 3` ``numpy`` array containing
                   the vertices.

    ``indices``    A :math:`M\\times 3` ``numpy`` array containing
                   the vertex indices for :math:`M` triangles

    ``normals``    A :math:`M\\times 3` ``numpy`` array containing
                   face normals.

    ``vnormals``   A :math:`N\\times 3` ``numpy`` array containing
                   vertex normals.
    ============== ====================================================


    And the following methods:

    .. autosummary::
       :nosignatures:

       getBounds
       loadVertexData
       getVertexData
       clearVertexData
       trimesh
       rayIntersection
       nearestVertex
       planeIntersection
    """


    def __init__(self, data, indices=None, fixWinding=False):
        """Create a ``TriangleMesh`` instance.

        :arg data:       Can either be a file name, or a :math:`N\\times 3`
                         ``numpy`` array (or array-like) containing vertex
                         data. If ``data`` is a file name, it is passed to
                         the :func:`loadVTKPolydataFile` function.

        :arg indices:    A list of indices into the vertex data, defining
                         the triangles. Ignored if ``data`` is a file name.
                         If ``None``, every three consecutive vertices are
                         taken to form one triangle.

        :arg fixWinding: Defaults to ``False``. If ``True``, the vertex
                         winding order of every triangle is fixed so they
                         all have outward-facing normal vectors.

        :raises RuntimeError: If ``data`` is a file containing polygons
                              which are not triangles.
        """

        if isinstance(data, six.string_types):
            infile = data
            data, lengths, indices = loadVTKPolydataFile(infile)

            if np.any(lengths != 3):
                raise RuntimeError('All polygons in VTK file must be '
                                   'triangles ({})'.format(infile))

            self.name       = op.basename(infile)
            self.dataSource = infile
        else:
            self.name       = 'TriangleMesh'
            self.dataSource = None

        # Convert before querying the shape, so that
        # plain sequences (e.g. nested lists) work too.
        vertices = np.array(data)

        if indices is None:
            indices = np.arange(vertices.shape[0])

        self.__vertices    = vertices
        self.__indices     = np.array(indices).reshape((-1, 3))
        self.__vertexData  = {}
        self.__faceNormals = None
        self.__vertNormals = None
        self.__trimesh     = None
        self.__loBounds    = self.vertices.min(axis=0)
        self.__hiBounds    = self.vertices.max(axis=0)

        if fixWinding:
            self.__fixWindingOrder()


    def __repr__(self):
        """Returns a string representation of this ``TriangleMesh`` instance.
        """
        return '{}({}, {})'.format(type(self).__name__,
                                   self.name,
                                   self.dataSource)


    def __str__(self):
        """Returns a string representation of this ``TriangleMesh`` instance.
        """
        return self.__repr__()


    def __fixWindingOrder(self):
        """Called by :meth:`__init__` if ``fixWinding is True``.  Fixes the
        mesh triangle winding order so that all face normals are facing
        outwards from the centre of the mesh.

        Only a single triangle is tested - one containing the vertex which
        is closest to a viewpoint outside of the mesh bounding box. If the
        normal of that triangle faces away from the viewpoint, the winding
        order of *every* triangle is flipped. This therefore assumes that
        all triangles already share a consistent winding order.
        """

        # Define a viewpoint which is
        # far away from the mesh.
        fnormals = self.normals
        camera   = self.__loBounds - (self.__hiBounds - self.__loBounds)

        # Find the nearest vertex to the viewpoint. Only
        # vertices which are referenced by at least one
        # triangle are considered, as we need to look up
        # a triangle that contains the chosen vertex.
        used  = np.unique(self.indices)
        dists = np.sqrt(np.sum((self.vertices[used] - camera) ** 2, axis=1))
        ivert = used[np.argmin(dists)]
        vert  = self.vertices[ivert]

        # Pick a triangle that this vertex
        # is in, and get its face normal
        itri = np.where(self.indices == ivert)[0][0]
        n    = fnormals[itri, :]

        # Make sure that the normal, and a vector from the
        # vertex to the camera, point in the same general
        # direction (non-negative dot product). If they
        # don't, flip the winding order of every triangle.
        if np.dot(n, transform.normalise(camera - vert)) < 0:
            self.indices[:, [1, 2]] = self.indices[:, [2, 1]]
            self.__faceNormals     *= -1


    @property
    def vertices(self):
        """The ``(N, 3)`` vertices of this mesh. """
        return self.__vertices


    @property
    def indices(self):
        """The ``(M, 3)`` triangles of this mesh. """
        return self.__indices


    @property
    def normals(self):
        """A ``(M, 3)`` array containing surface normals for every
        triangle in the mesh, normalised to unit length. Calculated on
        first access, and cached thereafter.
        """

        if self.__faceNormals is not None:
            return self.__faceNormals

        v0 = self.vertices[self.indices[:, 0]]
        v1 = self.vertices[self.indices[:, 1]]
        v2 = self.vertices[self.indices[:, 2]]

        # The cross product of two triangle edges is perpendicular
        # to the triangle; its sign depends on the winding order.
        n = np.cross((v1 - v0), (v2 - v0))

        self.__faceNormals = transform.normalise(n)

        return self.__faceNormals


    @property
    def vnormals(self):
        """A ``(N, 3)`` array containing normals for every vertex
        in the mesh. Each vertex normal is the normalised sum of the
        face normals of all triangles which contain that vertex.
        Calculated on first access, and cached thereafter.
        """
        if self.__vertNormals is not None:
            return self.__vertNormals

        # per-face normals
        fnormals = self.normals
        vnormals = np.zeros((self.vertices.shape[0], 3), dtype=np.float64)

        # Accumulate every face normal onto each of the three vertices
        # of that face. np.add.at is unbuffered, so a vertex which is
        # shared by several triangles receives a contribution from each
        # of them (plain fancy-index += would keep only one of them).
        np.add.at(vnormals,
                  self.indices.ravel(),
                  fnormals.repeat(3, axis=0))

        # normalise to unit length
        self.__vertNormals = transform.normalise(vnormals)

        return self.__vertNormals


    def getBounds(self):
        """Returns a tuple of values which define a minimal bounding box that
        will contain all vertices in this ``TriangleMesh`` instance. The
        bounding box is arranged like so:

            ``((xlow, ylow, zlow), (xhigh, yhigh, zhigh))``
        """
        return (self.__loBounds, self.__hiBounds)


    def loadVertexData(self, dataSource, vertexData=None):
        """Attempts to load scalar data associated with each vertex of this
        ``TriangleMesh`` from the given ``dataSource``. The data is returned,
        and also stored in an internal cache so it can be retrieved later
        via the :meth:`getVertexData` method.

        This method may be overridden by sub-classes.

        :arg dataSource: Path to the vertex data to load
        :arg vertexData: The vertex data itself, if it has already been
                         loaded.

        :returns: A ``(N, K)`` array, which contains ``K`` data points
                  for each of the ``N`` vertices.

        :raises ValueError: If the data does not contain one row
                            per vertex.
        """

        nvertices = self.vertices.shape[0]

        # Currently only white-space delimited
        # text files are supported
        if vertexData is None:
            vertexData = np.loadtxt(dataSource)

            # A file containing a single row or column (or value) is
            # loaded as a 1D/0D array - reshape it to (nvertices, K).
            # reshape returns a new array, so the result must be
            # re-assigned.
            if vertexData.ndim < 2:
                try:
                    vertexData = vertexData.reshape(nvertices, -1)
                except ValueError:
                    raise ValueError(
                        'Incompatible size: {}'.format(dataSource))

        if vertexData.shape[0] != nvertices:
            raise ValueError('Incompatible size: {}'.format(dataSource))

        self.__vertexData[dataSource] = vertexData

        return vertexData


    def getVertexData(self, dataSource):
        """Returns the vertex data for the given ``dataSource`` from the
        internal vertex data cache. If the given ``dataSource`` is not
        in the cache, it is loaded via :meth:`loadVertexData`.
        """

        try:             return self.__vertexData[dataSource]
        except KeyError: return self.loadVertexData(dataSource)


    def clearVertexData(self):
        """Clears the internal vertex data cache - see the
        :meth:`loadVertexData` and :meth:`getVertexData`  methods.
        """

        self.__vertexData = {}


    @memoize.Instanceify(memoize.memoize)
    def trimesh(self):
        """Reference to a ``trimesh.Trimesh`` object which can be used for
        geometric operations on the mesh.

        If the ``trimesh`` or ``rtree`` libraries are not available, this
        function returns ``None``
        """

        # trimesh is an optional dependency - rtree
        # is a dependency of trimesh which is a
        # wrapper around libspatialindex, without
        # which trimesh can't be used for calculating
        # ray-mesh intersections.
        try:
            import trimesh
            import rtree   # noqa
        except ImportError:
            log.warning('trimesh is not available')
            return None

        # Cache the Trimesh object on the instance. (The previous
        # hasattr(self, '__trimesh') check never matched, because
        # string arguments are not subject to name mangling.)
        if self.__trimesh is None:
            self.__trimesh = trimesh.Trimesh(self.__vertices,
                                             self.__indices,
                                             process=False,
                                             validate=False)

        return self.__trimesh


    def rayIntersection(self, origins, directions, vertices=False):
        """Calculate the intersection between the mesh, and the rays defined by
        ``origins`` and ``directions``.

        :arg origins:    Sequence of ray origins
        :arg directions: Sequence of ray directions
        :arg vertices:   Currently unused.

        :returns:        A tuple containing:

                           - A ``(n, 3)`` array containing the coordinates
                             where the mesh was intersected by each of the
                             ``n`` intersecting rays.

                           - A ``(n,)`` array containing the indices of the
                             triangles that were intersected by each of the
                             ``n`` intersecting rays.

                         Both arrays are empty if ``trimesh`` is not
                         available, or if no ray intersected the mesh.
        """

        trimesh = self.trimesh()

        if trimesh is None:
            return np.zeros((0, 3)), np.zeros((0,))

        tris, rays, locs = trimesh.ray.intersects_id(
            origins,
            directions,
            return_locations=True,
            multiple_hits=False)

        if len(tris) == 0:
            return np.zeros((0, 3)), np.zeros((0,))

        # sort by ray. I'm not sure if this is
        # needed - does trimesh do it for us?
        # (np.int was removed in numpy 1.24 -
        # the builtin int is equivalent.)
        rayIdxs = np.asarray(np.argsort(rays), dtype=int)
        locs    = locs[rayIdxs]
        tris    = tris[rayIdxs]

        return locs, tris


    def nearestVertex(self, points):
        """Identifies the nearest vertex to each of the provided points.

        :arg points: A ``(n, 3)`` array containing the points to query.

        :returns:    A tuple containing:

                      - A ``(n, 3)`` array containing the nearest vertex for
                        for each of the ``n`` input points.

                      - A ``(n,)`` array containing the indices of each vertex.

                      - A ``(n,)`` array containing the distance from each
                        point to the nearest vertex.

                     All three arrays are empty if ``trimesh`` is not
                     available.
        """

        trimesh = self.trimesh()

        if trimesh is None:
            return np.zeros((0, 3)), np.zeros((0, )), np.zeros((0, ))

        dists, idxs = trimesh.nearest.vertex(points)
        verts       = self.vertices[idxs, :]

        return verts, idxs, dists


    def planeIntersection(self,
                          normal,
                          origin,
                          distances=False):
        """Calculate the intersection of this ``TriangleMesh`` with
        the plane defined by ``normal`` and ``origin``.

        :arg normal:    Vector defining the plane orientation

        :arg origin:    Point defining the plane location

        :arg distances: If ``True``, barycentric coordinates for each
                        intersection line vertex are calculated and returned,
                        giving their respective distance from the intersected
                        triangle vertices.

        :returns:       A tuple containing
                          - A ``(m, 2, 3)`` array containing the two end
                            points of each of ``m`` line segments, which
                            together define the plane intersection.

                          - A ``(m,)`` array containing the indices of the
                            ``m`` triangles that were intersected.

                          - (if ``distances is True``) A ``(m, 2, 3)`` array
                            containing the barycentric coordinates of each
                            line vertex with respect to its intersected
                            triangle.

                        All arrays are empty (with the shapes above) if
                        ``trimesh`` is not available.
        """

        trimesh = self.trimesh()

        # Return empty results with the same shapes, and
        # the same number of values, as the normal path.
        if trimesh is None:
            lines = np.zeros((0, 2, 3))
            faces = np.zeros((0,), dtype=int)
            if distances:
                return lines, faces, np.zeros((0, 2, 3))
            return lines, faces

        import trimesh.intersections as tmint
        import trimesh.triangles     as tmtri

        lines, faces = tmint.mesh_plane(
            trimesh,
            plane_normal=normal,
            plane_origin=origin,
            return_faces=True)

        if not distances:
            return lines, faces

        # Calculate the barycentric coordinates
        # (distance from triangle vertices) for
        # each intersection line. Each line has
        # two end points, so each intersected
        # triangle is repeated twice.
        triangles = self.vertices[self.indices[faces]].repeat(2, axis=0)
        points    = lines.reshape((-1, 3))

        if triangles.size > 0:
            dists = tmtri.points_to_barycentric(triangles, points)
            dists = dists.reshape((-1, 2, 3))
        else:
            dists = np.zeros((0, 2, 3))

        return lines, faces, dists


#: A list of file extensions which could contain :class:`TriangleMesh` data.
ALLOWED_EXTENSIONS = ['.vtk']


#: A description for each of the extensions in :data:`ALLOWED_EXTENSIONS`.
EXTENSION_DESCRIPTIONS = ['VTK polygon model file']


def loadVTKPolydataFile(infile):
    """Loads a vtk legacy file containing a ``POLYDATA`` data set.

    The file is expected to be laid out like so - three header lines, the
    ``DATASET POLYDATA`` line, then one vertex per line, and one polygon per
    line::

        ...
        ...
        ...
        DATASET POLYDATA
        POINTS <N> ...
        <x> <y> <z>               (N lines)
        POLYGONS <M> <size>
        <n> <i_0> ... <i_n-1>     (M lines)

    :arg infile: Name of a file to load from.

    :returns: a tuple containing three values:

                - A :math:`N\\times 3` ``numpy`` array containing :math:`N`
                  vertices.
                - A 1D ``numpy`` array containing the lengths of each polygon.
                - A 1D ``numpy`` array containing the vertex indices for all
                  polygons.

    :raises ValueError: If the file does not contain a ``POLYDATA`` data
                        set, if the ``POINTS`` / ``POLYGONS`` sections are
                        not where they are expected, or if the file is
                        truncated.
    """

    with open(infile, 'rt') as f:
        lines = [line.strip() for line in f]

    if len(lines) < 4 or lines[3] != 'DATASET POLYDATA':
        raise ValueError('Only the POLYDATA data type is supported')

    def sectionHeader(lineno, keyword, nfields):
        # Return the whitespace-separated fields of the section header
        # on line ``lineno``, checking that it is a ``keyword`` section
        # with at least ``nfields`` fields. The keyword comparison is
        # case-insensitive, so nothing that parsed before is rejected.
        fields = lines[lineno].split() if lineno < len(lines) else []
        if len(fields) < nfields or fields[0].upper() != keyword:
            raise ValueError('Expected {} section on line {} of {}'.format(
                keyword, lineno + 1, infile))
        return fields

    pointsHeader = sectionHeader(4, 'POINTS', 2)
    nVertices    = int(pointsHeader[1])

    polysHeader  = sectionHeader(5 + nVertices, 'POLYGONS', 3)
    nPolygons    = int(polysHeader[1])

    # The POLYGONS size field counts every polygon's
    # length prefix as well as its vertex indices.
    nIndices = int(polysHeader[2]) - nPolygons

    if len(lines) < 6 + nVertices + nPolygons:
        raise ValueError('Truncated VTK file: {}'.format(infile))

    vertices       = np.zeros((nVertices, 3), dtype=np.float32)
    polygonLengths = np.zeros( nPolygons,     dtype=np.uint32)
    indices        = np.zeros( nIndices,      dtype=np.uint32)

    for i in range(nVertices):
        vertices[i, :] = [float(w) for w in lines[5 + i].split()]

    indexOffset = 0
    for i in range(nPolygons):

        polyLine          = lines[6 + nVertices + i].split()
        polyLen           = int(polyLine[0])
        polygonLengths[i] = polyLen

        start              = indexOffset
        end                = indexOffset + polyLen
        indices[start:end] = [int(w) for w in polyLine[1:]]
        indexOffset        = end

    return vertices, polygonLengths, indices


def getFIRSTPrefix(modelfile):
    """Returns the file prefix of a ``vtk`` file that was generated by
    `FIRST <https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/FIRST>`_.

    The prefix is the part of the file's base name which precedes the last
    ``-`` character (an empty string if there is no ``-``).

    :arg modelfile: Path to a ``vtk`` file.

    :raises ValueError: If ``modelfile`` does not end with ``first.vtk``.
    """

    if modelfile.endswith('first.vtk'):
        head, _, _ = op.basename(modelfile).rpartition('-')
        return head

    raise ValueError('Not a first vtk file: {}'.format(modelfile))


def findReferenceImage(modelfile):
    """Given a ``vtk`` file, attempts to find a corresponding ``NIFTI``
    image file. Return the path to the image, or ``None`` if no image was
    found.

    Currently this function will only return an image for ``vtk`` files
    generated by `FIRST <https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/FIRST>`_.
    The FIRST prefix of the file is looked up in the same directory as the
    ``vtk`` file; if the prefix ends in ``_first``, the prefix with that
    suffix removed is tried as well.
    """

    suffix = '_first'

    try:
        directory  = op.dirname(modelfile)
        candidates = [getFIRSTPrefix(modelfile)]
    except ValueError:
        return None

    if candidates[0].endswith(suffix):
        candidates.append(candidates[0][:-len(suffix)])

    for candidate in candidates:
        try:
            return fslimage.addExt(op.join(directory, candidate),
                                   mustExist=True)
        except fslimage.PathError:
            pass

    return None