#!/usr/bin/env python
#
# mesh.py - The Mesh class.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module provides the :class:`Mesh` class, which represents a
3D model made of triangles.

See also the following modules:

  .. autosummary::

     fsl.data.vtk
     fsl.data.gifti
     fsl.data.freesurfer

A handful of standalone functions are provided in this module, for doing
various things with meshes:

  .. autosummary::
     :nosignatures:

     calcFaceNormals
     calcVertexNormals
     needsFixing
"""

29

30
import logging
31
32
import collections

33
34
import os.path as op
import numpy   as np
35

36
import fsl.utils.meta      as meta
37
import fsl.utils.notifier  as notifier
38
import fsl.utils.transform as transform
39

40
41
42
43

log = logging.getLogger(__name__)


44
45
class Mesh(notifier.Notifier, meta.Meta):
    """The ``Mesh`` class represents a 3D model. A mesh is defined by a
    collection of ``N`` vertices, and ``M`` triangles.  The triangles are
    defined by ``(M, 3)`` indices into the list of vertices.


    A ``Mesh`` instance has the following attributes:


    ============== ======================================================
    ``name``       A name, typically the file name sans-suffix.

    ``dataSource`` Full path to the mesh file (or ``None`` if there is
                   no file associated with this mesh).

    ``nvertices``  The number of vertices in the mesh.

    ``vertices``   A ``(n, 3)`` array containing the currently selected
                   vertices. You can assign a vertex set key to this
                   attribute to change the selected vertex set.

    ``bounds``     The lower and upper bounds

    ``indices``    A ``(m, 3)`` array containing the vertex indices
                   for ``m`` triangles

    ``normals``    A ``(m, 3)`` array containing face normals for the
                   triangles

    ``vnormals``   A ``(n, 3)`` array containing vertex normals for
                   the current vertices.

    ``trimesh``    (if the `trimesh <https://github.com/mikedh/trimesh>`_
                   library is present) A ``trimesh.Trimesh`` object which
                   can be used for geometric queries on the mesh.
    ============== ======================================================


    **Vertex sets**


    A ``Mesh`` object can be associated with multiple sets of vertices, but
    only one set of triangles. Vertices can be added via the
    :meth:`addVertices` method. Each vertex set must be associated with a
    unique key - you can then select the current vertex set via the
    :meth:`vertices` property. Most ``Mesh`` methods will raise a ``KeyError``
    if you have not added any vertex sets, or selected a vertex set. The
    following methods are available for managing vertex sets:

    .. autosummary::
       :nosignatures:

       loadVertices
       addVertices
       selectedVertices
       vertexSets


    **Vertex data**


    A ``Mesh`` object can store vertex-wise data. The following methods can be
    used for adding/retrieving vertex data:

    .. autosummary::
       :nosignatures:

       loadVertexData
       addVertexData
       getVertexData
       vertexDataSets
       clearVertexData


    **Notification**


    The ``Mesh`` class inherits from the :class:`Notifier` class. Whenever the
    ``Mesh`` vertex set is changed, a notification is emitted via the
    ``Notifier`` interface, with a topic of ``'vertices'``. When this occurs,
    the :meth:`vertices`, :meth:`bounds`, :meth:`normals` and :attr:`vnormals`
    properties will all change so that they return data specific to the newly
    selected vertex set.


    **Metadata**


    The ``Mesh`` class also inherits from the :class:`Meta` class, so
    any metadata associated with the ``Mesh`` may be added via those methods.


    **Geometric queries**


    If the ``trimesh`` library is present, the following methods may be used
    to perform geometric queries on a mesh:

    .. autosummary::
       :nosignatures:

       rayIntersection
       planeIntersection
       nearestVertex
    """
149
150


151
152
153
154
155
156
157
158
    def __new__(cls, *args, **kwargs):
        """Allocate a new ``Mesh`` instance.

        ``__new__`` is overridden explicitly because, without it, the
        ``__new__`` methods of the :class:`Meta` and :class:`Notifier`
        base classes would not be called correctly.
        """
        instance = super(Mesh, cls).__new__(cls, *args, **kwargs)
        return instance


159
160
161
162
163
164
    def __init__(self,
                 indices,
                 name='mesh',
                 dataSource=None,
                 vertices=None,
                 fixWinding=False):
        """Create a ``Mesh`` instance.

        A ``Mesh`` cannot be used until at least one vertex set has been
        added, either through the ``vertices`` argument or via the
        :meth:`addVertices` method.

        :arg indices:    A list of indices into the vertex data, defining the
                         mesh triangles. Reshaped to ``(M, 3)``.

        :arg name:       A name for this ``Mesh``.

        :arg dataSource: The data source for this ``Mesh``.

        :arg vertices:   Initial vertex set to add - given the key
                         ``'default'``.

        :arg fixWinding: Ignored if ``vertices is None``. Passed through to the
                         :meth:`addVertices` method along with ``vertices``.
        """

        triangles = np.asarray(indices).reshape((-1, 3))

        self.__name       = name
        self.__dataSource = dataSource
        self.__indices    = triangles

        # The vertex count is inferred from the
        # largest vertex index that is referenced
        # by any triangle.
        self.__nvertices  = triangles.max() + 1

        # Key of the currently selected vertex
        # set - used to index into all of the
        # per-vertex-set dictionaries below.
        self.__selected = None

        # Set to True once the triangle winding
        # order has been checked/fixed - see the
        # addVertices method.
        self.__fixed = False

        # Vertices and bounds for each vertex
        # set, populated by addVertices.
        self.__vertices = collections.OrderedDict()
        self.__loBounds = collections.OrderedDict()
        self.__hiBounds = collections.OrderedDict()

        # Face/vertex normals for each vertex
        # set, populated lazily when the normals
        # and vnormals properties are accessed.
        self.__faceNormals = collections.OrderedDict()
        self.__vertNormals = collections.OrderedDict()

        # Vertex data sets, populated
        # by addVertexData.
        self.__vertexData  = collections.OrderedDict()

        # trimesh.Trimesh objects for each vertex
        # set, populated by the trimesh property.
        self.__trimesh = collections.OrderedDict()

        # Register the initial vertex
        # set, if one was provided.
        if vertices is not None:
            self.addVertices(vertices, fixWinding=fixWinding)
223

224

225
226
227
228
229
230
231
232
233
234
235
236
237
    def __repr__(self):
        """Return a string of the form ``ClassName(name, dataSource)``. """
        clsName = type(self).__name__
        return '{}({}, {})'.format(clsName, self.name, self.dataSource)


    def __str__(self):
        """Return the same string as :meth:`__repr__`. """
        return repr(self)


238
239
240
241
242
243
    @property
    def name(self):
        """The name of this ``Mesh``. """
        return self.__name


    @name.setter
    def name(self, name):
        """Change the name of this ``Mesh``. """
        self.__name = name


250
251
252
253
    @property
    def dataSource(self):
        """The data source of this ``Mesh``, as passed to :meth:`__init__`.
        """
        return self.__dataSource
254
255


256
257
258
259
260
261
    @property
    def nvertices(self):
        """The number of vertices in the mesh, calculated in
        :meth:`__init__` as the largest triangle vertex index plus one.
        """
        return self.__nvertices


262
263
264
    @property
    def vertices(self):
        """The ``(N, 3)`` vertices of the currently selected vertex set. A
        ``KeyError`` is raised if no vertex set is selected.
        """
        return self.__vertices[self.__selected]


    @vertices.setter
    def vertices(self, key):
        """Select the current vertex set - a ``KeyError`` is raised
        if no vertex set with the specified ``key`` has been added.

        If the selection actually changes, a notification is emitted
        through the :class:`.Notifier` interface, with the topic
        ``'vertices'``. Re-selecting the current vertex set does nothing.
        """

        # Reject keys which do not
        # refer to a known vertex set
        if key not in self.__vertices:
            raise KeyError(key)

        # Nothing to do if this vertex
        # set is already selected
        if self.__selected == key:
            return

        self.__selected = key
        self.notify(topic='vertices')
285

286
287
288
289

    @property
    def indices(self):
        """The ``(M, 3)`` array of vertex indices which defines the
        triangles of this mesh.
        """
        return self.__indices
291
292


293
294
295
296
297
    @property
    def normals(self):
        """A ``(M, 3)`` array containing surface normals for every
        triangle in the mesh, normalised to unit length.

        The normals are calculated by :func:`calcFaceNormals` the first
        time they are requested for a given vertex set, and cached for
        subsequent accesses. A ``KeyError`` is raised if no vertex set
        is selected.
        """

        key   = self.__selected

        # Looked up before the cache so that a
        # KeyError is raised when nothing is selected
        verts = self.__vertices[key]

        cached = self.__faceNormals.get(key, None)
        if cached is not None:
            return cached

        computed = calcFaceNormals(verts, self.__indices)
        self.__faceNormals[key] = computed

        return computed
309
310
311
312


    @property
    def vnormals(self):
        """A ``(N, 3)`` array containing normals for every vertex
        in the mesh.

        The normals are calculated by :func:`calcVertexNormals` (from the
        face :meth:`normals`) the first time they are requested for a given
        vertex set, and cached for subsequent accesses. A ``KeyError`` is
        raised if no vertex set is selected.
        """

        key   = self.__selected

        # Looked up before the cache so that a
        # KeyError is raised when nothing is selected
        verts = self.__vertices[key]

        cached = self.__vertNormals.get(key, None)
        if cached is not None:
            return cached

        computed = calcVertexNormals(verts, self.__indices, self.normals)
        self.__vertNormals[key] = computed

        return computed
327
328


329
330
    @property
    def bounds(self):
        """Returns a tuple of values which define a minimal bounding box that
        will contain all of the currently selected vertices in this
        ``Mesh`` instance. The bounding box is arranged like so:

            ``((xlow, ylow, zlow), (xhigh, yhigh, zhigh))``

        A ``KeyError`` is raised if no vertex set is selected.
        """
        key = self.__selected
        return self.__loBounds[key], self.__hiBounds[key]
342
343


344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
    def loadVertices(self, infile, key=None, **kwargs):
        """Loads vertex data from the given ``infile``, and adds it as a vertex
        set with the given ``key``. This implementation supports loading vertex
        data from white-space delimited text files via ``numpy.loadtxt``, but
        sub-classes may override this method to support additional file types.


        :arg infile: File to load data from.

        :arg key:    Key to pass to :meth:`addVertices`. If not provided,
                     set to ``infile`` (converted to an absolute path)

        All of the other arguments are passed through to :meth:`addVertices`.

        :returns:    The loaded vertices.
        """

        path   = op.abspath(infile)
        setKey = path if key is None else key

        return self.addVertices(np.loadtxt(path), setKey, **kwargs)
369
370


371
    def addVertices(self, vertices, key=None, select=True, fixWinding=False):
        """Adds a set of vertices to this ``Mesh``.

        :arg vertices:   A `(n, 3)` array containing ``n`` vertices, compatible
                         with the indices specified in :meth:`__init__`. Any
                         array which can be reshaped to ``(n, 3)`` is accepted.

        :arg key:        A key for this vertex set. If ``None`` defaults to
                         ``'default'``. If a vertex set with this key already
                         exists, it is replaced.

        :arg select:     If ``True`` (the default), this vertex set is
                         made the currently selected vertex set.

        :arg fixWinding: Defaults to ``False``. If ``True``, the vertex
                         winding order of every triangle is fixed so they
                         all have outward-facing normal vectors. This is
                         only ever done once per ``Mesh``, and is based on
                         the vertex set being added.

        :returns:        The vertices, possibly reshaped

        :raises:         ``ValueError`` if the provided ``vertices`` array
                         has the wrong number of vertices.
        """

        if key is None:
            key = 'default'

        vertices = np.asarray(vertices)

        # Don't allow vertices of
        # different size to be added
        try:
            vertices = vertices.reshape(self.nvertices, 3)

        # reshape raised an error -
        # wrong number of vertices
        except ValueError:
            raise ValueError('{}: invalid number of vertices: '
                             '{} != ({}, 3)'.format(key,
                                                    vertices.shape,
                                                    self.nvertices))

        # The bounds are calculated after the
        # reshape, so that they are always
        # per-axis (3,) arrays, even if the
        # vertices were passed in flattened.
        lo = vertices.min(axis=0)
        hi = vertices.max(axis=0)

        self.__vertices[key] = vertices
        self.__loBounds[key] = lo
        self.__hiBounds[key] = hi

        # If a vertex set with this key is being
        # replaced, anything that was cached for
        # the old vertices is no longer valid.
        self.__faceNormals.pop(key, None)
        self.__vertNormals.pop(key, None)
        self.__trimesh    .pop(key, None)

        # indices already fixed?
        if fixWinding and (not self.__fixed):

            # The check is performed on the vertex
            # set being added, regardless of which
            # vertex set (if any) is selected.
            indices  = self.__indices
            fnormals = calcFaceNormals(vertices, indices)

            self.__faceNormals[key] = fnormals
            self.__fixed            = True

            # See needsFixing documentation
            if needsFixing(vertices, indices, fnormals, lo, hi):

                # Flip the winding order on a copy, so
                # that the array which was passed to
                # __init__ is not modified in place.
                flipped            = indices.copy()
                flipped[:, [1, 2]] = flipped[:, [2, 1]]
                self.__indices     = flipped

                # Cached face normals now
                # point the wrong way
                for k in list(self.__faceNormals.keys()):
                    self.__faceNormals[k] = self.__faceNormals[k] * -1

                # Cached vertex normals and trimesh
                # objects were built from the old
                # winding order - they will be
                # re-generated on next access.
                self.__vertNormals.clear()
                self.__trimesh    .clear()

        # Selection (and hence notification) is
        # done last, so that listeners see the
        # mesh in its final state.
        if select:
            self.vertices = key

        return vertices

437

438
    def vertexSets(self):
        """Returns a list of the keys of every vertex set that has been
        added to this ``Mesh``.
        """
        return list(self.__vertices)

442
443
444
445
446
447

    def selectedVertices(self):
        """Returns the key of the vertex set which is currently selected,
        or ``None`` if no vertex set has been selected.
        """
        return self.__selected


448
449
450
451
452
453
454
455
456
457
    def loadVertexData(self, infile, key=None):
        """Loads vertex-wise data from the given ``infile``, and adds it with
        the given ``key``. This implementation supports loading data from
        whitespace-delimited text files via ``numpy.loadtxt``, but sub-classes
        may override this method to support additional file types.

        :arg infile: File to load data from.

        :arg key:    Key to pass to :meth:`addVertexData`. If not provided,
                     set to ``infile`` (converted to an absolute path)

        :returns:    The loaded vertex data.
        """

        path    = op.abspath(infile)
        dataKey = path if key is None else key

        return self.addVertexData(dataKey, np.loadtxt(path))
470
471


472
473
474
    def addVertexData(self, key, vdata):
        """Adds a vertex-wise data set to the ``Mesh``. It can be retrieved
        by passing the specified ``key`` to the :meth:`getVertexData` method.

        :arg key:   Key to store the data under. Any existing data with the
                    same key is replaced.

        :arg vdata: A ``(n,)`` or ``(n, m)`` array (or array-like), where
                    ``n`` is the number of vertices in the mesh.

        :returns:   The vertex data, reshaped to ``(n, m)``.

        :raises:    ``ValueError`` if ``vdata`` is not one- or
                    two-dimensional, or if its first dimension does not
                    match the number of vertices.
        """

        # Accept any array-like (e.g. a list), in
        # the same way that addVertices does
        vdata     = np.asarray(vdata)
        nvertices = self.nvertices

        if vdata.ndim not in (1, 2) or vdata.shape[0] != nvertices:
            raise ValueError('{}: incompatible vertex data '
                             'shape: {}'.format(key, vdata.shape))

        # Always stored as a 2D array, with
        # one row per vertex
        vdata                  = vdata.reshape(nvertices, -1)
        self.__vertexData[key] = vdata

        return vdata
489
490


491
492
493
494
    def getVertexData(self, key):
        """Returns the vertex data for the given ``key`` from the
        internal vertex data cache. If there is no vertex data with the
        given key, a ``KeyError`` is raised.
        """
        vdata = self.__vertexData[key]
        return vdata
498
499
500
501


    def clearVertexData(self):
        """Discards all stored vertex data by replacing the internal
        vertex data cache with a new, empty one - see the
        :meth:`addVertexData` and :meth:`getVertexData` methods.
        """
        emptyCache        = collections.OrderedDict()
        self.__vertexData = emptyCache
505
506


507
508
509
510
511
    def vertexDataSets(self):
        """Returns a list of the keys of every vertex data set that is
        currently stored.
        """
        return list(self.__vertexData)


512
    @property
    def trimesh(self):
        """Reference to a ``trimesh.Trimesh`` object which can be used for
        geometric operations on the mesh.

        One ``Trimesh`` object is created (on first access) and cached for
        each vertex set.

        If the ``trimesh`` or ``rtree`` libraries are not available, a
        warning is logged, this function returns ``None``, and none of the
        geometric query methods will do anything.
        """

        # trimesh is an optional dependency. rtree
        # is a dependency of trimesh - a wrapper
        # around libspatialindex - without which
        # trimesh cannot calculate ray-mesh
        # intersections, so both must be importable.
        try:
            import trimesh
            import rtree   # noqa
        except ImportError:
            log.warning('trimesh is not available')
            return None

        key    = self.__selected
        cached = self.__trimesh.get(key, None)

        if cached is not None:
            return cached

        built = trimesh.Trimesh(self.vertices,
                                self.indices,
                                process=False,
                                validate=False)

        self.__trimesh[key] = built

        return built
545
546


547
    def rayIntersection(self, origins, directions, vertices=False):
        """Calculate the intersection between the mesh, and the rays defined by
        ``origins`` and ``directions``.

        :arg origins:    Sequence of ray origins
        :arg directions: Sequence of ray directions
        :arg vertices:   Currently unused.

        :returns:        A tuple containing:

                           - A ``(n, 3)`` array containing the coordinates
                             where the mesh was intersected by each of the
                             ``n`` rays.

                           - A ``(n,)`` array containing the indices of the
                             triangles that were intersected by each of the
                             ``n`` rays.

                         Both arrays are ordered by ray index. If the
                         ``trimesh`` library is not available, or there are
                         no intersections, both arrays are empty.
        """

        trimesh = self.trimesh

        if trimesh is None:
            return np.zeros((0, 3)), np.zeros((0,))

        tris, rays, locs = trimesh.ray.intersects_id(
            origins,
            directions,
            return_locations=True,
            multiple_hits=False)

        if len(tris) == 0:
            return np.zeros((0, 3)), np.zeros((0,))

        # sort by ray. I'm Not sure if this is
        # needed - does trimesh do it for us?
        #
        # np.argsort already returns an integer
        # index array, so no cast is required (the
        # np.int alias has been removed from numpy).
        rayIdxs = np.argsort(rays)
        locs    = locs[rayIdxs]
        tris    = tris[rayIdxs]

        return locs, tris
585
586


587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
    def nearestVertex(self, points):
        """Identifies the nearest vertex to each of the provided points.

        :arg points: A ``(n, 3)`` array containing the points to query.

        :returns:    A tuple containing:

                      - A ``(n, 3)`` array containing the nearest vertex
                        for each of the ``n`` input points.

                      - A ``(n,)`` array containing the indices of each vertex.

                      - A ``(n,)`` array containing the distance from each
                        point to the nearest vertex.

                     If the ``trimesh`` library is not available, three
                     empty arrays are returned.
        """

        tm = self.trimesh

        if tm is None:
            return np.zeros((0, 3)), np.zeros((0, )), np.zeros((0, ))

        distances, vertIdxs = tm.nearest.vertex(points)

        return self.vertices[vertIdxs, :], vertIdxs, distances


614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
    def planeIntersection(self,
                          normal,
                          origin,
                          distances=False):
        """Calculate the intersection of this ``Mesh`` with
        the plane defined by ``normal`` and ``origin``.

        :arg normal:    Vector defining the plane orientation

        :arg origin:    Point defining the plane location

        :arg distances: If ``True``, barycentric coordinates for each
                        intersection line vertex are calculated and returned,
                        giving their respective distance from the intersected
                        triangle vertices.

        :returns:       A tuple containing

                          - A ``(m, 2, 3)`` array containing the two end
                            points of each of ``m`` lines, which together
                            define the plane intersection

                          - A ``(m,)`` array containing the indices of the
                            ``m`` triangles that were intersected.

                          - (if ``distances is True``) A ``(m, 2, 3)`` array
                            containing the barycentric coordinates of each
                            line vertex with respect to its intersected
                            triangle.

                        If the ``trimesh`` library is not available, empty
                        arrays (``m == 0``) of the above shapes are returned.
        """

        trimesh = self.trimesh

        # Without trimesh, return empty results
        # which have the same structure as a real
        # result, so that callers can unpack/index
        # them in the same way.
        if trimesh is None:
            lines = np.zeros((0, 2, 3))
            faces = np.zeros((0,), dtype=np.intp)

            if distances:
                return lines, faces, np.zeros((0, 2, 3))

            return lines, faces

        import trimesh.intersections as tmint
        import trimesh.triangles     as tmtri

        lines, faces = tmint.mesh_plane(
            trimesh,
            plane_normal=normal,
            plane_origin=origin,
            return_faces=True)

        if not distances:
            return lines, faces

        # Calculate the barycentric coordinates
        # (distance from triangle vertices) for
        # each intersection line. Every intersected
        # triangle is repeated twice, so that there
        # is one triangle for each line end point.
        triangles = self.vertices[self.indices[faces]].repeat(2, axis=0)
        points    = lines.reshape((-1, 3))

        if triangles.size > 0:
            dists = tmtri.points_to_barycentric(triangles, points)
            dists = dists.reshape((-1, 2, 3))
        else:
            dists = np.zeros((0, 2, 3))

        return lines, faces, dists


677
678
679
def calcFaceNormals(vertices, indices):
    """Calculates face normals for the mesh described by ``vertices`` and
    ``indices``.

    The normal of each triangle is the cross product of two of its edges
    (from the first vertex to the second, and from the first vertex to the
    third), passed through ``transform.normalise``.

    :arg vertices: A ``(n, 3)`` array containing the mesh vertices.
    :arg indices:  A ``(m, 3)`` array containing the mesh triangles.
    :returns:      A ``(m, 3)`` array containing normals for every triangle in
                   the mesh.
    """

    # (m, 3, 3) array - the coordinates of
    # the three corners of every triangle
    corners = vertices[indices]

    edge01 = corners[:, 1] - corners[:, 0]
    edge02 = corners[:, 2] - corners[:, 0]

    return transform.normalise(np.cross(edge01, edge02))
695

696

697
698
699
def calcVertexNormals(vertices, indices, fnormals):
    """Calculates vertex normals for the mesh described by ``vertices``
    and ``indices``.

    The normal of each vertex is the sum of the normals of every triangle
    that the vertex is a part of, passed through ``transform.normalise``.

    :arg vertices: A ``(n, 3)`` array containing the mesh vertices.
    :arg indices:  A ``(m, 3)`` array containing the mesh triangles.
    :arg fnormals: A ``(m, 3)`` array containing the face/triangle normals.
    :returns:      A ``(n, 3)`` array containing normals for every vertex in
                   the mesh.
    """

    indices  = np.asarray(indices)
    fnormals = np.asarray(fnormals)
    vnormals = np.zeros((vertices.shape[0], 3), dtype=float)

    # Accumulate the face normals onto the
    # vertices, one triangle corner at a time.
    # np.add.at performs unbuffered addition,
    # so a vertex which is shared by several
    # triangles receives a contribution from
    # every one of them (vnormals[idx] += x
    # would only apply one of them).
    for corner in range(3):
        np.add.at(vnormals, indices[:, corner], fnormals)

    # normalise to unit length
    return transform.normalise(vnormals)
724
725


726
727
728
def needsFixing(vertices, indices, fnormals, loBounds, hiBounds):
    """Determines whether the triangle winding order, for the mesh described by
    ``vertices`` and ``indices``, needs to be flipped.

    If this function returns ``True``, the given ``indices`` and ``fnormals``
    need to be adjusted so that all face normals are facing outwards from the
    centre of the mesh. The necessary adjustments are as follows::

        indices[:, [1, 2]] = indices[:, [2, 1]]
        fnormals           = fnormals * -1

    The decision is based on a single triangle - one which contains the
    vertex that is closest to a viewpoint located outside of the mesh
    bounds.

    :arg vertices: A ``(n, 3)`` array containing the mesh vertices.
    :arg indices:  A ``(m, 3)`` array containing the mesh triangles.
    :arg fnormals: A ``(m, 3)`` array containing the face/triangle normals.
    :arg loBounds: A ``(3, )`` array containing the low vertex bounds.
    :arg hiBounds: A ``(3, )`` array containing the high vertex bounds.

    :returns:      ``True`` if the ``indices`` and ``fnormals`` need to be
                   adjusted, ``False`` otherwise.
    """

    # Place a viewpoint outside of the mesh,
    # one bounding-box length below the low
    # bounds along every axis.
    extent    = hiBounds - loBounds
    viewpoint = loBounds - extent

    # Identify the vertex which is
    # closest to the viewpoint
    offsets  = vertices - viewpoint
    sqdists  = np.sum(offsets ** 2, axis=1)
    nearest  = np.argmin(np.sqrt(sqdists))
    nearVert = vertices[nearest]

    # Take the first triangle which
    # contains that vertex, and look
    # up its face normal
    triRows, _ = np.where(indices == nearest)
    triNormal  = fnormals[triRows[0], :]

    # The normal should point towards the
    # viewpoint, i.e. its dot product with
    # the vertex-to-viewpoint direction
    # should not be negative. If it is, the
    # winding order needs to be flipped.
    toViewpoint = transform.normalise(viewpoint - nearVert)

    return np.dot(triNormal, toViewpoint) < 0