#!/usr/bin/env python
#
# mesh.py - The Mesh class.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module provides the :class:`Mesh` class, which represents a
3D model made of triangles.

See also the following modules:

  .. autosummary::

     fsl.data.vtk
     fsl.data.gifti
     fsl.data.freesurfer

A handful of standalone functions are provided in this module, for doing various
things with meshes:

  .. autosummary::
     :nosignatures:

     calcFaceNormals
     calcVertexNormals
     needsFixing
"""

29

30
import logging
31
32
33
34
import collections

import six
import deprecation
35

36
37
import os.path as op
import numpy   as np
38

39
import fsl.utils.meta      as meta
40
import fsl.utils.notifier  as notifier
41
import fsl.utils.memoize   as memoize
42
import fsl.utils.transform as transform
43
import fsl.data.image      as fslimage
44

45
46
47
48

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


49
50
class Mesh(notifier.Notifier, meta.Meta):
    """The ``Mesh`` class represents a 3D model. A mesh is defined by a
    collection of ``N`` vertices, and ``M`` triangles.  The triangles are
    defined by ``(M, 3)`` indices into the list of vertices.


    A ``Mesh`` instance has the following attributes:


    ============== ====================================================
    ``name``       A name, typically the file name sans-suffix.

    ``dataSource`` Full path to the mesh file (or ``None`` if there is
                   no file associated with this mesh).

    ``vertices``   A ``(n, 3)`` array containing the currently selected
                   vertices. You can assign a vertex set key to this
                   attribute to change the selected vertex set.

    ``bounds``     The lower and upper bounds of the currently selected
                   vertex set.

    ``indices``    A ``(m, 3)`` array containing the vertex indices
                   for ``m`` triangles

    ``normals``    A ``(m, 3)`` array containing face normals for the
                   triangles

    ``vnormals``   A ``(n, 3)`` array containing vertex normals for the
                   current vertices.
    ============== ====================================================


    **Vertex sets**


    A ``Mesh`` object can be associated with multiple sets of vertices, but
    only one set of triangles. Vertices can be added via the
    :meth:`addVertices` method. Each vertex set must be associated with a
    unique key - you can then select the current vertex set via the
    :meth:`vertices` property. Most ``Mesh`` methods will raise a ``KeyError``
    if you have not added any vertex sets, or selected a vertex set.


    **Vertex data**


    A ``Mesh`` object can store vertex-wise data. The following methods can be
    used for adding/retrieving vertex data:

    .. autosummary::
       :nosignatures:

       addVertexData
       getVertexData
       clearVertexData


    **Notification**


    The ``Mesh`` class inherits from the :class:`Notifier` class. Whenever the
    ``Mesh`` vertex set is changed, a notification is emitted via the
    ``Notifier`` interface, with a topic of ``'vertices'``. When this occurs,
    the :meth:`vertices`, :meth:`bounds`, :meth:`normals` and :attr:`vnormals`
    properties will all change so that they return data specific to the newly
    selected vertex set.


    **Metadata**


    The ``Mesh`` class also inherits from the :class:`Meta` class, so
    any metadata associated with the ``Mesh`` may be added via those methods.


    **Geometric queries**


    If the ``trimesh`` library (and its ``rtree`` dependency) is present, the
    following methods may be used to perform geometric queries on a mesh:

    .. autosummary::
       :nosignatures:

       rayIntersection
       planeIntersection
       nearestVertex
    """


    def __init__(self,
                 indices,
                 name='mesh',
                 dataSource=None,
                 vertices=None,
                 fixWinding=False):
        """Create a ``Mesh`` instance.

        Before a ``Mesh`` can be used, some vertices must be added via the
        :meth:`addVertices` method.

        :arg indices:    A list of indices into the vertex data, defining the
                         mesh triangles.

        :arg name:       A name for this ``Mesh``.

        :arg dataSource: The data source for this ``Mesh``.

        :arg vertices:   Initial vertex set to add - given the key
                         ``'default'``.

        :arg fixWinding: Ignored if ``vertices is None``. Passed through to the
                         :meth:`addVertices` method along with ``vertices``.
        """

        self.__name       = name
        self.__dataSource = dataSource

        # Take a private copy of the indices - the
        # addVertices method may re-order them in
        # place when fixing the winding order, and
        # that must not modify the caller's array.
        self.__indices = np.array(indices).reshape((-1, 3))

        # This attribute is used to store
        # the currently selected vertex set,
        # used as a key into all of the
        # dictionaries below.
        self.__selected = None

        # Flag used to keep track of whether
        # the triangle winding order has been
        # "fixed" - see the addVertices method.
        self.__fixed = False

        # All of these are populated
        # in the addVertices method
        self.__vertices = collections.OrderedDict()
        self.__loBounds = collections.OrderedDict()
        self.__hiBounds = collections.OrderedDict()

        # These get populated on
        # normals/vnormals accesses
        self.__faceNormals = collections.OrderedDict()
        self.__vertNormals = collections.OrderedDict()

        # this gets populated in
        # the addVertexData method
        self.__vertexData  = collections.OrderedDict()

        # this gets populated
        # in the trimesh method
        self.__trimesh = collections.OrderedDict()

        # Add initial vertex set if provided,
        # honouring the fixWinding argument
        if vertices is not None:
            self.addVertices(vertices, fixWinding=fixWinding)


    def __repr__(self):
        """Returns a string representation of this ``Mesh`` instance. """
        return '{}({}, {})'.format(type(self).__name__,
                                   self.name,
                                   self.dataSource)


    def __str__(self):
        """Returns a string representation of this ``Mesh`` instance.
        """
        return self.__repr__()


    @property
    def name(self):
        """Returns the name of this ``Mesh``. """
        return self.__name


    @property
    def dataSource(self):
        """Returns the data source of this ``Mesh``. """
        return self.__dataSource


    @property
    def vertices(self):
        """The ``(N, 3)`` vertices of this mesh. """
        return self.__vertices[self.__selected]


    @vertices.setter
    def vertices(self, key):
        """Select the current vertex set - a ``KeyError`` is raised
        if no vertex set with the specified ``key`` has been added.
        A ``'vertices'`` notification is emitted on success.
        """

        # Force a key error if
        # the key is invalid
        self.__vertices[key]
        self.__selected = key

        self.notify(topic='vertices')


    @property
    def indices(self):
        """The ``(M, 3)`` triangles of this mesh. """
        return self.__indices


    @property
    def normals(self):
        """A ``(M, 3)`` array containing surface normals for every
        triangle in the mesh, normalised to unit length. Normals are
        calculated on first access, and cached per vertex set.
        """

        selected = self.__selected
        indices  = self.__indices
        vertices = self.__vertices[selected]
        fnormals = self.__faceNormals.get(selected, None)

        if fnormals is None:
            fnormals = calcFaceNormals(vertices, indices)
            self.__faceNormals[selected] = fnormals

        return fnormals


    @property
    def vnormals(self):
        """A ``(N, 3)`` array containing normals for every vertex
        in the mesh. Normals are calculated on first access, and cached
        per vertex set.
        """

        selected = self.__selected
        indices  = self.__indices
        vertices = self.__vertices[selected]
        vnormals = self.__vertNormals.get(selected, None)

        if vnormals is None:
            vnormals = calcVertexNormals(vertices, indices, self.normals)
            self.__vertNormals[selected] = vnormals

        return vnormals


    @deprecation.deprecated(deprecated_in='1.6.0',
                            removed_in='2.0.0',
                            details='Use bounds instead')
    def getBounds(self):
        """Deprecated - use :meth:`bounds` instead. """
        return self.bounds


    @property
    def bounds(self):
        """Returns a tuple of values which define a minimal bounding box that
        will contain all of the currently selected vertices in this
        ``Mesh`` instance. The bounding box is arranged like so:

            ``((xlow, ylow, zlow), (xhigh, yhigh, zhigh))``
        """

        lo = self.__loBounds[self.__selected]
        hi = self.__hiBounds[self.__selected]

        return lo, hi


    def addVertices(self, vertices, key=None, select=True, fixWinding=False):
        """Adds a set of vertices to this ``Mesh``.

        :arg vertices:   A `(n, 3)` array containing ``n`` vertices, compatible
                         with the indices specified in :meth:`__init__`.

        :arg key:        A key for this vertex set. If ``None`` defaults to
                         ``'default'``. If a vertex set with this key already
                         exists, it is replaced.

        :arg select:     If ``True`` (the default), this vertex set is
                         made the currently selected vertex set.

        :arg fixWinding: Defaults to ``False``. If ``True``, the vertex
                         winding order of every triangle is fixed so they
                         all have outward-facing normal vectors. The winding
                         order is only ever fixed once, as all vertex sets
                         share the same triangles.
        """

        if key is None:
            key = 'default'

        vertices = np.asarray(vertices)
        lo       = vertices.min(axis=0)
        hi       = vertices.max(axis=0)

        self.__vertices[key] = vertices
        self.__loBounds[key] = lo
        self.__hiBounds[key] = hi

        # If a vertex set with this key was added
        # previously, anything derived from it is stale
        self.__faceNormals.pop(key, None)
        self.__vertNormals.pop(key, None)
        self.__trimesh    .pop(key, None)

        # indices already fixed?
        if fixWinding and (not self.__fixed):
            indices  = self.__indices

            # Use the normals of the vertex set being
            # added - it may not be the selected one
            fnormals = calcFaceNormals(vertices, indices)
            self.__faceNormals[key] = fnormals
            self.__fixed = True

            # See needsFixing documentation
            if needsFixing(vertices, indices, fnormals, lo, hi):

                indices[:, [1, 2]] = indices[:, [2, 1]]

                for k in list(self.__faceNormals.keys()):
                    self.__faceNormals[k] = self.__faceNormals[k] * -1

                # Vertex normals and trimesh objects were derived
                # from the old winding order - they are rebuilt
                # lazily on next access.
                self.__vertNormals.clear()
                self.__trimesh    .clear()

        # Select last, so that listeners notified via
        # the vertices setter see the final (fixed) state
        if select:
            self.vertices = key


    def selectedVertices(self):
        """Returns the key of the currently selected vertex set. """
        return self.__selected


    def addVertexData(self, key, vdata):
        """Adds a vertex-wise data set to the ``Mesh``. It can be retrieved
        by passing the specified ``key`` to the :meth:`getVertexData` method.
        The data is stored as a 2D array, with one row per vertex.
        """
        vdata = np.asarray(vdata)
        self.__vertexData[key] = vdata.reshape(vdata.shape[0], -1)


    def getVertexData(self, key):
        """Returns the vertex data for the given ``key`` from the
        internal vertex data cache. If there is no vertex data with the
        given key, a ``KeyError`` is raised.
        """

        return self.__vertexData[key]


    def clearVertexData(self):
        """Clears the internal vertex data cache - see the
        :meth:`addVertexData` and :meth:`getVertexData` methods.
        """
        self.__vertexData = collections.OrderedDict()


    def trimesh(self):
        """Reference to a ``trimesh.Trimesh`` object which can be used for
        geometric operations on the mesh, for the currently selected vertex
        set. ``Trimesh`` objects are created on demand, and cached for each
        vertex set.

        If the ``trimesh`` or ``rtree`` libraries are not available, this
        function returns ``None``, and none of the geometric query methods
        will do anything.
        """

        # trimesh is an optional dependency - rtree
        # is a dependency of trimesh which is a
        # wrapper around libspatialindex, without
        # which trimesh can't be used for calculating
        # ray-mesh intersections.
        try:
            import trimesh
            import rtree   # noqa
        except ImportError:
            log.warning('trimesh is not available')
            return None

        # Cached per vertex set - this must not be
        # memoized on the method as a whole, otherwise
        # selecting a different vertex set would keep
        # returning the first Trimesh that was created.
        tm = self.__trimesh.get(self.__selected, None)

        if tm is None:
            tm = trimesh.Trimesh(self.vertices,
                                 self.indices,
                                 process=False,
                                 validate=False)

            self.__trimesh[self.__selected] = tm

        return tm


    def rayIntersection(self, origins, directions, vertices=False):
        """Calculate the intersection between the mesh, and the rays defined by
        ``origins`` and ``directions``.

        :arg origins:    Sequence of ray origins
        :arg directions: Sequence of ray directions
        :arg vertices:   Currently ignored.
        :returns:        A tuple containing:

                           - A ``(n, 3)`` array containing the coordinates
                             where the mesh was intersected by each of the
                             ``n`` rays.

                           - A ``(n,)`` array containing the indices of the
                             triangles that were intersected by each of the
                             ``n`` rays.
        """

        trimesh = self.trimesh()

        if trimesh is None:
            return np.zeros((0, 3)), np.zeros((0,), dtype=int)

        tris, rays, locs = trimesh.ray.intersects_id(
            origins,
            directions,
            return_locations=True,
            multiple_hits=False)

        if len(tris) == 0:
            return np.zeros((0, 3)), np.zeros((0,), dtype=int)

        # sort by ray. I'm Not sure if this is
        # needed - does trimesh do it for us?
        # (argsort already returns an integer array)
        rayIdxs = np.argsort(rays)
        locs    = locs[rayIdxs]
        tris    = tris[rayIdxs]

        return locs, tris


    def nearestVertex(self, points):
        """Identifies the nearest vertex to each of the provided points.

        :arg points: A ``(n, 3)`` array containing the points to query.

        :returns:    A tuple containing:

                      - A ``(n, 3)`` array containing the nearest vertex for
                        for each of the ``n`` input points.

                      - A ``(n,)`` array containing the indices of each vertex.

                      - A ``(n,)`` array containing the distance from each
                        point to the nearest vertex.
        """

        trimesh = self.trimesh()

        if trimesh is None:
            return (np.zeros((0, 3)),
                    np.zeros((0, ), dtype=int),
                    np.zeros((0, )))

        dists, idxs = trimesh.nearest.vertex(points)
        verts       = self.vertices[idxs, :]

        return verts, idxs, dists


    def planeIntersection(self,
                          normal,
                          origin,
                          distances=False):
        """Calculate the intersection of this ``Mesh`` with
        the plane defined by ``normal`` and ``origin``.

        :arg normal:    Vector defining the plane orientation

        :arg origin:    Point defining the plane location

        :arg distances: If ``True``, barycentric coordinates for each
                        intersection line vertex are calculated and returned,
                        giving their respective distance from the intersected
                        triangle vertices.

        :returns:       A tuple containing
                          - A ``(m, 2, 3)`` array containing ``m`` vertices:
                            of a set of lines, defining the plane intersection

                          - A ``(m,)`` array containing the indices of the
                            ``m`` triangles that were intersected.

                          - (if ``distances is True``) A ``(m, 2, 3)`` array
                            containing the barycentric coordinates of each
                            line vertex with respect to its intersected
                            triangle.
        """

        trimesh = self.trimesh()

        # No trimesh - return empty results with the
        # same arity/shapes as a successful query
        if trimesh is None:
            lines = np.zeros((0, 2, 3))
            faces = np.zeros((0,), dtype=int)
            if distances:
                return lines, faces, np.zeros((0, 2, 3))
            return lines, faces

        import trimesh.intersections as tmint
        import trimesh.triangles     as tmtri

        lines, faces = tmint.mesh_plane(
            trimesh,
            plane_normal=normal,
            plane_origin=origin,
            return_faces=True)

        if not distances:
            return lines, faces

        # Calculate the barycentric coordinates
        # (distance from triangle vertices) for
        # each intersection line. Each line has
        # two end points, hence each triangle is
        # repeated twice.
        triangles = self.vertices[self.indices[faces]].repeat(2, axis=0)
        points    = lines.reshape((-1, 3))

        if triangles.size > 0:
            dists = tmtri.points_to_barycentric(triangles, points)
            dists = dists.reshape((-1, 2, 3))
        else:
            dists = np.zeros((0, 2, 3))

        return lines, faces, dists


554
555
556
def calcFaceNormals(vertices, indices):
    """Computes a unit-length normal vector for each triangle of a mesh.

    The normal of a triangle ``(a, b, c)`` is the normalised cross product
    ``(b - a) x (c - a)``, so its direction depends on the triangle's
    winding order.

    :arg vertices: A ``(n, 3)`` array containing the mesh vertices.
    :arg indices:  A ``(m, 3)`` array containing the mesh triangles.
    :returns:      A ``(m, 3)`` array containing normals for every triangle in
                   the mesh.
    """

    # Gather the three corners of every triangle
    first, second, third = (vertices[indices[:, col]] for col in (0, 1, 2))

    edgeA = second - first
    edgeB = third  - first

    return transform.normalise(np.cross(edgeA, edgeB))
572

573

574
575
576
def calcVertexNormals(vertices, indices, fnormals):
    """Calculates vertex normals for the mesh described by ``vertices``
    and ``indices``.

    The normal of each vertex is the sum of the normals of every triangle
    which contains that vertex, normalised to unit length.

    :arg vertices: A ``(n, 3)`` array containing the mesh vertices.
    :arg indices:  A ``(m, 3)`` array containing the mesh triangles.
    :arg fnormals: A ``(m, 3)`` array containing the face/triangle normals.
    :returns:      A ``(n, 3)`` array containing normals for every vertex in
                   the mesh.
    """

    indices  = np.asarray(indices)
    fnormals = np.asarray(fnormals)

    # np.float was removed in NumPy 1.24 - it
    # was an alias for the builtin (64-bit) float
    vnormals = np.zeros((vertices.shape[0], 3), dtype=np.float64)

    # Accumulate the face normals for each vertex.
    # np.add.at is unbuffered, so a vertex which is
    # referenced many times receives every contribution.
    # Flattening the indices row-wise and repeating each
    # face normal three times accumulates in exactly the
    # same order as visiting the triangles one at a time.
    np.add.at(vnormals,
              indices.reshape(-1),
              np.repeat(fnormals, 3, axis=0))

    # normalise to unit length
    return transform.normalise(vnormals)
601
602


603
604
605
def needsFixing(vertices, indices, fnormals, loBounds, hiBounds):
    """Determines whether the triangle winding order, for the mesh described by
    ``vertices`` and ``indices``, needs to be flipped.

    If this function returns ``True``, the given ``indices`` and ``fnormals``
    need to be adjusted so that all face normals are facing outwards from the
    centre of the mesh. The necessary adjustments are as follows::

        indices[:, [1, 2]] = indices[:, [2, 1]]
        fnormals           = fnormals * -1

    Only vertices which are part of at least one triangle are considered.
    If the mesh has no triangles, ``False`` is returned.

    :arg vertices: A ``(n, 3)`` array containing the mesh vertices.
    :arg indices:  A ``(m, 3)`` array containing the mesh triangles.
    :arg fnormals: A ``(m, 3)`` array containing the face/triangle normals.
    :arg loBounds: A ``(3, )`` array containing the low vertex bounds.
    :arg hiBounds: A ``(3, )`` array containing the high vertex bounds.

    :returns:      ``True`` if the ``indices`` and ``fnormals`` need to be
                   adjusted, ``False`` otherwise.
    """

    # An unreferenced vertex has no face normal
    # to test, so only consider vertices which
    # belong to a triangle. No triangles means
    # there is nothing to flip.
    candidates = np.unique(indices)
    if candidates.size == 0:
        return False

    # Define a viewpoint which is
    # far away from the mesh.
    camera = loBounds - (hiBounds - loBounds)

    # Find the (referenced) vertex nearest to the
    # viewpoint - squared distances give the same
    # ordering as true distances.
    dists = np.sum((vertices[candidates] - camera) ** 2, axis=1)
    ivert = candidates[np.argmin(dists)]
    vert  = vertices[ivert]

    # Pick a triangle that
    # this vertex is in and
    # get its face normal
    itri = np.where(indices == ivert)[0][0]
    n    = fnormals[itri, :]

    # The angle between the normal, and a vector
    # from the vertex to the camera, should be less
    # than 90 degrees (positive dot product). If it
    # isn't, we need to flip the triangle winding
    # order. Only the sign matters, so the vector
    # does not need to be normalised.
    return bool(np.dot(n, camera - vert) < 0)
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733


class TriangleMesh(Mesh):
    """Deprecated - use :class:`fsl.data.mesh.Mesh`, or one of its sub-classes
    instead.
    """


    @deprecation.deprecated(deprecated_in='1.6.0',
                            removed_in='2.0.0',
                            details='Use fsl.data.mesh.Mesh, or one '
                                    'of its sub-classes instead')
    def __init__(self, data, indices=None, fixWinding=False):
        """Create a ``TriangleMesh``.

        :arg data:       Either the path to a VTK polydata file, or an array
                         of vertices.
        :arg indices:    Triangle indices. Ignored if ``data`` is a file
                         path, in which case the indices are taken from the
                         loaded file.
        :arg fixWinding: Passed through to :meth:`Mesh.addVertices`.
        """

        import fsl.data.vtk as fslvtk

        if isinstance(data, six.string_types):
            name       = op.basename(data)
            dataSource = data
            mesh       = fslvtk.VTKMesh(data, fixWinding=False)
            vertices   = mesh.vertices
            indices    = mesh.indices

        else:
            name       = 'TriangleMesh'
            dataSource = None
            vertices   = data

        Mesh.__init__(self, indices, name=name, dataSource=dataSource)
        self.addVertices(vertices, 'default', fixWinding=fixWinding)


    @deprecation.deprecated(deprecated_in='1.6.0',
                            removed_in='2.0.0',
                            details='Use the Mesh class instead')
    def loadVertexData(self, dataSource, vertexData=None):
        """Loads vertex-wise data and adds it to this mesh (via
        :meth:`addVertexData`), keyed by ``dataSource``.

        :arg dataSource: Path to a white-space delimited text file. Also
                         used as the key for the vertex data.
        :arg vertexData: Vertex data to add. If provided, ``dataSource`` is
                         not read, and is only used as the key.
        :raises ValueError: If the number of rows in the data does not match
                            the number of vertices.
        :returns:        The vertex data. Data loaded from file is returned
                         as a ``(nvertices, k)`` array.
        """

        nvertices = self.vertices.shape[0]
        loaded    = vertexData is None

        # Currently only white-space delimited
        # text files are supported
        if loaded:
            vertexData = np.loadtxt(dataSource)

        if vertexData.shape[0] != nvertices:
            raise ValueError('Incompatible size: {}'.format(dataSource))

        # addVertexData stores data as (nvertices, k) - return
        # loaded data in the same shape, so this result matches
        # what later getVertexData calls return from the cache.
        if loaded:
            vertexData = vertexData.reshape(nvertices, -1)

        self.addVertexData(dataSource, vertexData)

        return vertexData


    @deprecation.deprecated(deprecated_in='1.6.0',
                            removed_in='2.0.0',
                            details='Use the Mesh class instead')
    def getVertexData(self, dataSource):
        """Returns the vertex data keyed by ``dataSource``, loading it
        via :meth:`loadVertexData` if it has not already been added.
        """
        try:
            return Mesh.getVertexData(self, dataSource)
        except KeyError:
            return self.loadVertexData(dataSource)


@deprecation.deprecated(details='Use fsl.data.vtk.loadVTKPolydataFile instead',
                        deprecated_in='1.6.0',
                        removed_in='2.0.0')
def loadVTKPolydataFile(*args, **kwargs):
    """Deprecated - forwards all arguments unchanged to
    :func:`fsl.data.vtk.loadVTKPolydataFile`, which should be used instead.
    """
    from fsl.data import vtk
    forward = vtk.loadVTKPolydataFile
    return forward(*args, **kwargs)


@deprecation.deprecated(details='Use fsl.data.vtk.getFIRSTPrefix instead',
                        deprecated_in='1.6.0',
                        removed_in='2.0.0')
def getFIRSTPrefix(*args, **kwargs):
    """Deprecated - forwards all arguments unchanged to
    :func:`fsl.data.vtk.getFIRSTPrefix`, which should be used instead.
    """
    from fsl.data import vtk
    forward = vtk.getFIRSTPrefix
    return forward(*args, **kwargs)


@deprecation.deprecated(details='Use fsl.data.vtk.findReferenceImage instead',
                        deprecated_in='1.6.0',
                        removed_in='2.0.0')
def findReferenceImage(*args, **kwargs):
    """Deprecated - forwards all arguments unchanged to
    :func:`fsl.data.vtk.findReferenceImage`, which should be used instead.
    """
    from fsl.data import vtk
    forward = vtk.findReferenceImage
    return forward(*args, **kwargs)