mesh.py 21.8 KB
Newer Older
1
2
#!/usr/bin/env python
#
3
# mesh.py - The Mesh class.
4
5
6
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
7
"""This module provides the :class:`Mesh` class, which represents a
8
3D model made of triangles.
9

10
See also the following modules:
11

12
  .. autosummary::
13

14
15
16
17
18
19
20
21
22
23
24
25
26
     fsl.data.vtk
     fsl.data.gifti
     fsl.data.freesurfer

A handful of standalone functions are provided in this module, for doing various
things with meshes:

  .. autosummary::
     :nosignatures:

     calcFaceNormals
     calcVertexNormals
     needsFixing
27
28
"""

29

30
import logging
31
32
33
34
import collections

import six
import deprecation
35

36
37
import os.path as op
import numpy   as np
38

39
import fsl.utils.meta      as meta
40
import fsl.utils.notifier  as notifier
41
import fsl.utils.memoize   as memoize
42
import fsl.utils.transform as transform
43
import fsl.data.image      as fslimage
44

45
46
47
48

log = logging.getLogger(__name__)


49
50
class Mesh(notifier.Notifier, meta.Meta):
    """The ``Mesh`` class represents a 3D model. A mesh is defined by a
51
    collection of ``N`` vertices, and ``M`` triangles.  The triangles are
Paul McCarthy's avatar
Paul McCarthy committed
52
    defined by ``(M, 3)`` indices into the list of vertices.
53
54


55
    A ``Mesh`` instance has the following attributes:
56

57

58
59
    ============== ====================================================
    ``name``       A name, typically the file name sans-suffix.
60

61
62
    ``dataSource`` Full path to the mesh file (or ``None`` if there is
                   no file associated with this mesh).
63

64
65
66
67
68
    ``vertices``   A ``(n, 3)`` array containing the currently selected
                   vertices. You can assign  a vertex set key to this
                   attribute to change the selected vertex set.

    ``bounds``     The lower and upper bounds
69

70
71
    ``indices``    A ``(m, 3)`` array containing the vertex indices
                   for ``m`` triangles
72

73
74
    ``normals``    A  ``(m, 3)`` array containing face normals for the
                   triangles
75

76
77
    ``vnormals``   A ``(n, 3)`` array containing vertex normals for
                   the current vertices.
78
79
    ============== ====================================================

80

81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
    **Vertex sets**


    A ``Mesh`` object can be associated with multiple sets of vertices, but
    only one set of triangles. Vertices can be added via the
    :meth:`addVertices` method. Each vertex set must be associated with a
    unique key - you can then select the current vertex set via the
    :meth:`vertices` property. Most ``Mesh`` methods will raise a ``KeyError``
    if you have not added any vertex sets, or selected a vertex set.


    **Vertex data**


    A ``Mesh`` object can store vertex-wise data. The following methods can be
    used for adding/retrieving vertex data:
97
98
99
100

    .. autosummary::
       :nosignatures:

101
       addVertexData
102
103
       getVertexData
       clearVertexData
104

105

106
    **Notification**
107

108

109
110
111
112
113
114
    The ``Mesh`` class inherits from the :class:`Notifier` class. Whenever the
    ``Mesh`` vertex set is changed, a notification is emitted via the
    ``Notifier`` interface, with a topic of ``'vertices'``. When this occurs,
    the :meth:`vertices`, :meth:`bounds`, :meth:`normals` and :attr:`vnormals`
    properties will all change so that they return data specific to the newly
    selected vertex set.
115
116


117
    **Metadata**
118

119

120
121
    The ``Mesh`` class also inherits from the :class:`Meta` class, so
    any metadata associated with the ``Mesh`` may be added via those methods.
122

123

124
    **Geometric queries**
125

126

127
128
    If the ``trimesh`` library is present, the following methods may be used
    to perform geometric queries on a mesh:
129

130
131
    .. autosummary::
       :nosignatures:
132

133
134
135
136
       rayIntersection
       planeIntersection
       nearestVertex
    """
137
138


139
140
141
142
143
    def __init__(self, indices, name='mesh', dataSource=None):
        """Create a ``Mesh`` instance.

        No vertices are stored at construction time - at least one vertex
        set must be added via the :meth:`addVertices` method before the
        ``Mesh`` can be used.

        :arg indices:    Triangle definitions - anything which can be
                         converted into a ``numpy`` array and reshaped to
                         ``(M, 3)``, containing indices into the vertex
                         data.

        :arg name:       A name for this ``Mesh``.

        :arg dataSource: The data source for this ``Mesh``.
        """

        # Triangles are always stored as a (M, 3) array
        self.__indices    = np.asarray(indices).reshape((-1, 3))
        self.__dataSource = dataSource
        self.__name       = name

        # Key of the currently selected vertex
        # set - it is used to index all of the
        # per-vertex-set dictionaries below.
        self.__selected = None

        # Set to True once the triangle winding
        # order has been checked/"fixed" - see
        # the addVertices method.
        self.__fixed = False

        # Vertices and bounds for each vertex
        # set - populated by addVertices.
        self.__vertices = collections.OrderedDict()
        self.__loBounds = collections.OrderedDict()
        self.__hiBounds = collections.OrderedDict()

        # Face/vertex normals for each vertex
        # set - populated when the normals/
        # vnormals properties are accessed.
        self.__faceNormals = collections.OrderedDict()
        self.__vertNormals = collections.OrderedDict()

        # Vertex data - populated
        # by addVertexData
        self.__vertexData = collections.OrderedDict()

        # trimesh.Trimesh objects for each vertex
        # set - populated by the trimesh method
        self.__trimesh = collections.OrderedDict()


188
189
190
191
192
193
194
195
196
197
198
199
200
    def __repr__(self):
        """Returns a string of the form ``ClassName(name, dataSource)``
        describing this ``Mesh`` instance.
        """
        clsName = type(self).__name__
        return '{}({}, {})'.format(clsName, self.name, self.dataSource)


    def __str__(self):
        """Returns the same string as :meth:`__repr__`. """
        return self.__repr__()


201
202
203
204
205
206
207
208
209
210
    @property
    def name(self):
        """The name of this ``Mesh`` (read-only). """
        return self.__name


    @property
    def dataSource(self):
        """The data source of this ``Mesh`` (read-only). """
        return self.__dataSource
211
212


213
214
215
    @property
    def vertices(self):
        """The ``(N, 3)`` vertices of the currently selected vertex set.

        A ``KeyError`` is raised if no vertex set is currently selected.
        """
        key = self.__selected
        return self.__vertices[key]


219
    @vertices.setter
    def vertices(self, key):
        """Select the current vertex set, and emit a notification with
        topic ``'vertices'``.

        :arg key: Key of a vertex set which has previously been added.
                  A ``KeyError`` is raised, and the selection is left
                  unchanged, if there is no vertex set with this key.
        """

        # Reject unknown keys before
        # any state is modified
        if key not in self.__vertices:
            raise KeyError(key)

        self.__selected = key
        self.notify(topic='vertices')

232
233
234
235

    @property
    def indices(self):
        """The ``(M, 3)`` array of vertex indices which defines the
        triangles of this mesh. The internal array itself is returned,
        not a copy.
        """
        return self.__indices
237
238


239
240
241
242
243
    @property
    def normals(self):
        """A ``(M, 3)`` array containing a normal vector for every triangle
        in the mesh, for the currently selected vertex set.

        The normals are calculated by :func:`calcFaceNormals` on first
        access, and cached for each vertex set. A ``KeyError`` is raised
        if no vertex set is selected.
        """

        key   = self.__selected

        # Looked up before the cache so that an
        # invalid selection always raises KeyError
        verts = self.__vertices[key]

        cached = self.__faceNormals.get(key, None)
        if cached is not None:
            return cached

        computed                = calcFaceNormals(verts, self.__indices)
        self.__faceNormals[key] = computed

        return computed
255
256
257
258


    @property
    def vnormals(self):
        """A ``(N, 3)`` array containing a normal vector for every vertex
        in the currently selected vertex set.

        The normals are calculated by :func:`calcVertexNormals` (from the
        face :meth:`normals`) on first access, and cached for each vertex
        set. A ``KeyError`` is raised if no vertex set is selected.
        """

        key   = self.__selected

        # Looked up before the cache so that an
        # invalid selection always raises KeyError
        verts = self.__vertices[key]

        cached = self.__vertNormals.get(key, None)
        if cached is not None:
            return cached

        computed = calcVertexNormals(verts, self.__indices, self.normals)
        self.__vertNormals[key] = computed

        return computed
273
274


275
276
277
278
279
280
281
282
    @deprecation.deprecated(deprecated_in='1.6.0',
                            removed_in='2.0.0',
                            details='Use bounds instead')
    def getBounds(self):
        """Deprecated - use the :meth:`bounds` property instead.

        :returns: The value of the :meth:`bounds` property.
        """
        return self.bounds


283
284
    @property
    def bounds(self):
        """The bounding box of the currently selected vertex set, as
        calculated when the vertices were added. It is returned as a
        tuple containing the low and high bounds:

            ``((xlow, ylow, zlow), (xhigh, yhigh, zhigh))``

        A ``KeyError`` is raised if no vertex set is selected.
        """
        key = self.__selected
        return self.__loBounds[key], self.__hiBounds[key]
296
297


298
299
    def addVertices(self, vertices, key, select=True, fixWinding=False):
        """Adds a set of vertices to this ``Mesh``.

        :arg vertices:   A ``(n, 3)`` array containing ``n`` vertices,
                         compatible with the indices specified in
                         :meth:`__init__`.

        :arg key:        A key for this vertex set. If a vertex set with
                         this key already exists, it is replaced, and any
                         normals cached for it are discarded.

        :arg select:     If ``True`` (the default), this vertex set is
                         made the currently selected vertex set, and a
                         ``'vertices'`` notification is emitted.

        :arg fixWinding: Defaults to ``False``. If ``True``, the vertex
                         winding order of every triangle is fixed so they
                         all have outward-facing normal vectors. The check
                         is only performed once for a ``Mesh`` - the
                         first time vertices are added with
                         ``fixWinding=True`` - and is based on the vertices
                         passed to that call.
        """

        vertices = np.asarray(vertices)
        lo       = vertices.min(axis=0)
        hi       = vertices.max(axis=0)

        # If the key is being re-used, anything
        # which was derived from the old vertices
        # is no longer valid.
        self.__faceNormals.pop(key, None)
        self.__vertNormals.pop(key, None)
        self.__trimesh.pop(    key, None)

        self.__vertices[key] = vertices
        self.__loBounds[key] = lo
        self.__hiBounds[key] = hi

        # indices already fixed?
        if fixWinding and (not self.__fixed):

            indices = self.indices

            # Use the face normals of the vertex set
            # that is being added - it is not necessarily
            # the selected one (select may be False, or
            # no vertex set may be selected yet).
            fnormals                = calcFaceNormals(vertices, indices)
            self.__faceNormals[key] = fnormals

            needsFix     = needsFixing(vertices, indices, fnormals, lo, hi)
            self.__fixed = True

            # See needsFixing documentation
            if needsFix:

                # The triangles are shared by all vertex
                # sets, and are flipped in-place.
                indices[:, [1, 2]] = indices[:, [2, 1]]

                # Reversing the winding order negates every
                # face normal, and therefore every vertex
                # normal too, for all vertex sets.
                for cache in (self.__faceNormals, self.__vertNormals):
                    for k, n in list(cache.items()):
                        cache[k] = n * -1

                # Any trimesh objects were created
                # with the old winding order.
                self.__trimesh.clear()

        # Select (and notify) last, so that listeners
        # see the mesh after the winding order has
        # been fixed.
        if select:
            self.vertices = key


    def addVertexData(self, key, vdata):
        """Adds a vertex-wise data set to the ``Mesh``. It can be retrieved
        by passing the specified ``key`` to the :meth:`getVertexData` method.

        :arg key:   Identifier for the data set. Any data previously stored
                    under the same key is replaced.

        :arg vdata: The data to store. It is stored as-is - it is not
                    copied, converted, or checked against the number of
                    vertices.
        """
        self.__vertexData[key] = vdata
346
347


348
349
350
351
    def getVertexData(self, key):
        """Returns the vertex data which was stored under ``key`` via the
        :meth:`addVertexData` method.

        :arg key: Key identifying the vertex data set. A ``KeyError`` is
                  raised if there is no vertex data with this key.
        """
        return self.__vertexData[key]
355
356
357
358


    def clearVertexData(self):
        """Discards all vertex data which has been added to this ``Mesh`` -
        see the :meth:`addVertexData` and :meth:`getVertexData` methods.
        """
        # The store is replaced with a new, empty
        # dictionary, rather than emptied in-place.
        self.__vertexData = collections.OrderedDict()
362
363


364
365
366
367
368
369
    @memoize.Instanceify(memoize.memoize)
    def trimesh(self):
        """Returns a ``trimesh.Trimesh`` object for the currently selected
        vertex set, which can be used for geometric operations on the mesh.

        If the ``trimesh`` or ``rtree`` libraries cannot be imported, a
        warning is logged and ``None`` is returned - in this case the
        geometric query methods will only return empty results.
        """

        # NOTE(review): this method is wrapped in a
        # memoizing decorator. If that decorator caches
        # on the call arguments alone, the first result
        # would be returned regardless of which vertex
        # set is selected later - confirm against
        # fsl.utils.memoize.

        # Both libraries are optional. rtree is
        # imported purely to test that it is present -
        # it wraps libspatialindex, which trimesh needs
        # in order to calculate ray-mesh intersections.
        try:
            import trimesh
            import rtree   # noqa
        except ImportError:
            log.warning('trimesh is not available')
            return None

        key = self.__selected
        tm  = self.__trimesh.get(key, None)

        if tm is not None:
            return tm

        # Pass the vertices/triangles
        # through without modification
        tm = trimesh.Trimesh(self.vertices,
                             self.indices,
                             process=False,
                             validate=False)

        self.__trimesh[key] = tm

        return tm
397
398


399
    def rayIntersection(self, origins, directions, vertices=False):
        """Calculate the intersection between the mesh, and the rays defined by
        ``origins`` and ``directions``.

        :arg origins:    Sequence of ray origins
        :arg directions: Sequence of ray directions
        :arg vertices:   Currently unused by this method.

        :returns:        A tuple containing:

                           - A ``(n, 3)`` array containing the coordinates
                             where the mesh was intersected by each of the
                             ``n`` rays which hit it.

                           - A ``(n,)`` array containing the indices of the
                             triangles that were intersected by each of the
                             ``n`` rays.

                         Both arrays are ordered by ray index. Both are
                         empty if no ray intersected the mesh, or if
                         ``trimesh`` is not available.
        """

        tm = self.trimesh()

        if tm is None:
            return np.zeros((0, 3)), np.zeros((0,))

        # Only the first hit for each ray is requested
        tris, rays, locs = tm.ray.intersects_id(
            origins,
            directions,
            return_locations=True,
            multiple_hits=False)

        if len(tris) == 0:
            return np.zeros((0, 3)), np.zeros((0,))

        # Sort the hits by ray - it is not clear whether
        # trimesh guarantees this. np.argsort already
        # returns an integer index array, so no cast is
        # needed (the np.int alias which was used here
        # no longer exists in recent numpy versions).
        order = np.argsort(rays)

        return locs[order], tris[order]
437
438


439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
    def nearestVertex(self, points):
        """Finds the vertex, in the currently selected vertex set, which is
        closest to each of the given points.

        :arg points: A ``(n, 3)`` array containing the points to query.

        :returns:    A tuple containing:

                      - A ``(n, 3)`` array containing the nearest vertex
                        for each of the ``n`` input points.

                      - A ``(n,)`` array containing the indices of each vertex.

                      - A ``(n,)`` array containing the distance from each
                        point to the nearest vertex.

                     If ``trimesh`` is not available, all three arrays
                     are empty.
        """

        tm = self.trimesh()

        if tm is None:
            return np.zeros((0, 3)), np.zeros((0, )), np.zeros((0, ))

        distances, vertIdxs = tm.nearest.vertex(points)
        nearest             = self.vertices[vertIdxs, :]

        return nearest, vertIdxs, distances


466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
    def planeIntersection(self,
                          normal,
                          origin,
                          distances=False):
        """Calculate the intersection of this ``Mesh`` with
        the plane defined by ``normal`` and ``origin``.

        :arg normal:    Vector defining the plane orientation

        :arg origin:    Point defining the plane location

        :arg distances: If ``True``, barycentric coordinates for each
                        intersection line vertex are calculated and returned,
                        giving their respective distance from the intersected
                        triangle vertices.

        :returns:       A tuple containing

                          - A ``(m, 2, 3)`` array containing the vertices
                            of a set of ``m`` lines, defining the plane
                            intersection

                          - A ``(m,)`` array containing the indices of the
                            ``m`` triangles that were intersected.

                          - (if ``distances is True``) A ``(m, 2, 3)`` array
                            containing the barycentric coordinates of each
                            line vertex with respect to its intersected
                            triangle.

                        If ``trimesh`` is not available, the same number
                        of arrays is returned, with ``m == 0``.
        """

        tm = self.trimesh()

        # No trimesh - return empty results which have
        # the same structure as a real result, so that
        # callers can unpack/index them in the same way.
        if tm is None:
            lines = np.zeros((0, 2, 3))
            faces = np.zeros((0,), dtype=int)
            if distances:
                return lines, faces, np.zeros((0, 2, 3))
            return lines, faces

        import trimesh.intersections as tmint
        import trimesh.triangles     as tmtri

        lines, faces = tmint.mesh_plane(
            tm,
            plane_normal=normal,
            plane_origin=origin,
            return_faces=True)

        if not distances:
            return lines, faces

        # Calculate the barycentric coordinates
        # (distance from triangle vertices) for
        # each intersection line. Every line has
        # two end points, so each intersected
        # triangle is repeated twice, giving one
        # triangle per point.
        triangles = self.vertices[self.indices[faces]].repeat(2, axis=0)
        points    = lines.reshape((-1, 3))

        if triangles.size > 0:
            dists = tmtri.points_to_barycentric(triangles, points)
            dists = dists.reshape((-1, 2, 3))
        else:
            dists = np.zeros((0, 2, 3))

        return lines, faces, dists


528
529
530
def calcFaceNormals(vertices, indices):
    """Calculates a normal vector for every triangle of the mesh described
    by ``vertices`` and ``indices``.

    Each normal is the cross product of two edges of the triangle (from its
    first vertex to its second, and from its first vertex to its third),
    passed through ``transform.normalise``.

    :arg vertices: A ``(n, 3)`` array containing the mesh vertices.
    :arg indices:  A ``(m, 3)`` array containing the mesh triangles.
    :returns:      A ``(m, 3)`` array containing normals for every triangle in
                   the mesh.
    """

    corner = vertices[indices[:, 0]]
    edge1  = vertices[indices[:, 1]] - corner
    edge2  = vertices[indices[:, 2]] - corner

    return transform.normalise(np.cross(edge1, edge2))
546

547

548
549
550
def calcVertexNormals(vertices, indices, fnormals):
    """Calculates vertex normals for the mesh described by ``vertices``
    and ``indices``.

    The normal for a vertex is the sum of the normals of every triangle
    which the vertex is a part of, passed through ``transform.normalise``.

    :arg vertices: A ``(n, 3)`` array containing the mesh vertices.
    :arg indices:  A ``(m, 3)`` array containing the mesh triangles.
    :arg fnormals: A ``(m, 3)`` array containing the face/triangle normals.
    :returns:      A ``(n, 3)`` array containing normals for every vertex in
                   the mesh.
    """

    indices  = np.asarray(indices)

    # Always work with a (m, 3) array, even
    # if a single normal was passed in as a
    # (3,) vector.
    fnormals = np.asarray(fnormals).reshape((-1, 3))

    # (the np.float alias no longer exists
    # in recent versions of numpy)
    vnormals = np.zeros((vertices.shape[0], 3), dtype=float)

    # Add the normal of each triangle to each
    # of its three vertices. np.add.at performs
    # unbuffered accumulation, so a vertex which
    # is in several triangles receives all of
    # their normals (a fancy-indexed += would
    # only apply one of them).
    np.add.at(vnormals, indices, fnormals[:, np.newaxis, :])

    # normalise to unit length
    return transform.normalise(vnormals)
575
576


577
578
579
def needsFixing(vertices, indices, fnormals, loBounds, hiBounds):
    """Tests whether the triangle winding order of the mesh described by
    ``vertices`` and ``indices`` needs to be flipped.

    A viewpoint is placed outside of the mesh bounding box, and the normal
    of one triangle containing the vertex closest to that viewpoint is
    tested - if it faces away from the viewpoint, ``True`` is returned. In
    that case the given ``indices`` and ``fnormals`` need to be adjusted so
    that all face normals are facing outwards from the centre of the mesh.
    The necessary adjustments are as follows::

        indices[:, [1, 2]] = indices[:, [2, 1]]
        fnormals           = fnormals * -1

    :arg vertices: A ``(n, 3)`` array containing the mesh vertices.
    :arg indices:  A ``(m, 3)`` array containing the mesh triangles.
    :arg fnormals: A ``(m, 3)`` array containing the face/triangle normals.
    :arg loBounds: A ``(3, )`` array containing the low vertex bounds.
    :arg hiBounds: A ``(3, )`` array containing the high vertex bounds.

    :returns:      ``True`` if the ``indices`` and ``fnormals`` need to be
                   adjusted, ``False`` otherwise.
    """

    # The viewpoint is offset from the low
    # corner of the bounding box by the
    # full extent of the bounding box.
    extent    = hiBounds - loBounds
    viewpoint = loBounds - extent

    # Index of the vertex which is
    # closest to the viewpoint
    offsets = vertices - viewpoint
    nearest = np.argmin(np.sqrt(np.sum(offsets ** 2, axis=1)))

    # The first triangle which uses
    # that vertex, and its face normal
    triangle = np.where(indices == nearest)[0][0]
    fnormal  = fnormals[triangle, :]

    # The winding order needs to be flipped
    # if the normal points away from the
    # viewpoint, i.e. its dot product with
    # the vertex -> viewpoint direction is
    # negative.
    toViewpoint = transform.normalise(viewpoint - vertices[nearest])

    return np.dot(fnormal, toViewpoint) < 0
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707


class TriangleMesh(Mesh):
    """Deprecated - use :class:`fsl.data.mesh.Mesh`, or one of its sub-classes
    instead.
    """


    @deprecation.deprecated(deprecated_in='1.6.0',
                            removed_in='2.0.0',
                            details='Use fsl.data.mesh.Mesh, or one '
                                    'of its sub-classes instead')
    def __init__(self, data, indices=None, fixWinding=False):
        """Create a ``TriangleMesh``.

        :arg data:       Either the path to a file, which is loaded as a
                         ``fsl.data.vtk.VTKMesh``, or an array of vertices.

        :arg indices:    Triangle indices. Ignored (the indices from the
                         file are used) if ``data`` is a file path.

        :arg fixWinding: Passed through to :meth:`Mesh.addVertices`.
        """

        import fsl.data.vtk as fslvtk

        if isinstance(data, six.string_types):
            name       = op.basename(data)
            dataSource = data

            # The winding order is fixed (if requested)
            # by the addVertices call below, so it is
            # not fixed by the VTKMesh.
            mesh       = fslvtk.VTKMesh(data, fixWinding=False)
            vertices   = mesh.vertices
            indices    = mesh.indices

        else:
            name       = 'TriangleMesh'
            dataSource = None
            vertices   = data

        Mesh.__init__(self, indices, name=name, dataSource=dataSource)
        self.addVertices(vertices, 'default', fixWinding=fixWinding)


    @deprecation.deprecated(deprecated_in='1.6.0',
                            removed_in='2.0.0',
                            details='Use the Mesh class instead')
    def loadVertexData(self, dataSource, vertexData=None):
        """Adds vertex data to this mesh, stored under the key
        ``dataSource``.

        :arg dataSource: Key for the vertex data. If ``vertexData`` is
                         not provided, this must be a white-space
                         delimited text file to load the data from.

        :arg vertexData: Array containing the vertex data, with one row
                         for each vertex. Loaded from ``dataSource`` if
                         not provided.

        :returns:        The vertex data. A ``ValueError`` is raised if
                         it is not compatible with the number of vertices
                         in the mesh.
        """

        nvertices = self.vertices.shape[0]

        # Currently only white-space delimited
        # text files are supported
        if vertexData is None:
            vertexData = np.loadtxt(dataSource)

            # reshape returns a new array, rather than
            # working in-place - the result must be kept
            # so that single-column data has a shape of
            # (nvertices, 1).
            vertexData = vertexData.reshape(nvertices, -1)

        if vertexData.shape[0] != nvertices:
            raise ValueError('Incompatible size: {}'.format(dataSource))

        self.addVertexData(dataSource, vertexData)

        return vertexData


    @deprecation.deprecated(deprecated_in='1.6.0',
                            removed_in='2.0.0',
                            details='Use the Mesh class instead')
    def getVertexData(self, dataSource):
        """Returns the vertex data stored under the key ``dataSource``. If
        there is none, an attempt is made to load it from ``dataSource``
        via :meth:`loadVertexData`.
        """
        try:
            return Mesh.getVertexData(self, dataSource)
        except KeyError:
            return self.loadVertexData(dataSource)


@deprecation.deprecated(deprecated_in='1.6.0',
                        removed_in='2.0.0',
                        details='Use fsl.data.vtk.loadVTKPolydataFile instead')
def loadVTKPolydataFile(*args, **kwargs):
    """Deprecated - use :func:`fsl.data.vtk.loadVTKPolydataFile` instead.

    All arguments are passed through unchanged, and the result of the
    ``fsl.data.vtk`` function is returned.
    """
    # Imported at call time rather than at module level -
    # presumably to avoid a circular import (TODO confirm).
    import fsl.data.vtk as fslvtk
    return fslvtk.loadVTKPolydataFile(*args, **kwargs)


@deprecation.deprecated(deprecated_in='1.6.0',
                        removed_in='2.0.0',
                        details='Use fsl.data.vtk.getFIRSTPrefix instead')
def getFIRSTPrefix(*args, **kwargs):
    """Deprecated - use :func:`fsl.data.vtk.getFIRSTPrefix` instead.

    All arguments are passed through unchanged, and the result of the
    ``fsl.data.vtk`` function is returned.
    """
    # Imported at call time rather than at module level -
    # presumably to avoid a circular import (TODO confirm).
    import fsl.data.vtk as fslvtk
    return fslvtk.getFIRSTPrefix(*args, **kwargs)


@deprecation.deprecated(deprecated_in='1.6.0',
                        removed_in='2.0.0',
                        details='Use fsl.data.vtk.findReferenceImage instead')
def findReferenceImage(*args, **kwargs):
    """Deprecated - use :func:`fsl.data.vtk.findReferenceImage` instead.

    All arguments are passed through unchanged, and the result of the
    ``fsl.data.vtk`` function is returned.
    """
    # Imported at call time rather than at module level -
    # presumably to avoid a circular import (TODO confirm).
    import fsl.data.vtk as fslvtk
    return fslvtk.findReferenceImage(*args, **kwargs)