diff --git a/fsl/data/mesh.py b/fsl/data/mesh.py index 5e9856dfa145c95d3c4d2188963deda06fae292b..92bf0f248bcf758e65c69629ec00843dcc0f8fc7 100644 --- a/fsl/data/mesh.py +++ b/fsl/data/mesh.py @@ -27,8 +27,7 @@ import numpy as np import six -import trimesh - +import fsl.utils.memoize as memoize import fsl.utils.transform as transform from . import image as fslimage @@ -123,11 +122,6 @@ class TriangleMesh(object): if fixWinding: self.__fixWindingOrder() - self.__trimesh = trimesh.Trimesh(self.__vertices, - self.__indices, - process=False, - validate=False) - def __repr__(self): """Returns a string representation of this ``TriangleMesh`` instance. @@ -142,26 +136,6 @@ class TriangleMesh(object): return self.__repr__() - @property - def vertices(self): - """The ``(N, 3)`` vertices of this mesh. """ - return self.__vertices - - - @property - def indices(self): - """The ``(M, 3)`` triangles of this mesh. """ - return self.__indices - - - @property - def trimesh(self): - """Reference to a ``trimesh.Trimesh`` object which can be used for - geometric operations on the mesh. - """ - return self.__trimesh - - def __fixWindingOrder(self): """Called by :meth:`__init__` if ``fixWinding is True``. Fixes the mesh triangle winding order so that all face normals are facing @@ -195,6 +169,18 @@ class TriangleMesh(object): self.__faceNormals *= -1 + @property + def vertices(self): + """The ``(N, 3)`` vertices of this mesh. """ + return self.__vertices + + + @property + def indices(self): + """The ``(M, 3)`` triangles of this mesh. """ + return self.__indices + + @property def normals(self): """A ``(M, 3)`` array containing surface normals for every @@ -245,6 +231,101 @@ class TriangleMesh(object): return self.__vertNormals + @memoize.Instanceify(memoize.memoize) + def trimesh(self): + """Reference to a ``trimesh.Trimesh`` object which can be used for + geometric operations on the mesh. + + If the ``trimesh`` or ``rtree`` libraries are not available, this + function returns ``None`` + """ + + # trimesh is an optional dependency - rtree + # is a depedendency of trimesh which is a + # wrapper around libspatialindex, without + # which trimesh can't be used for calculating + # ray-mesh intersections. + try: + import trimesh + import rtree # noqa + except ImportError: + log.warning('trimesh is not available') + return None + + if hasattr(self, '__trimesh'): + return self.__trimesh + + self.__trimesh = trimesh.Trimesh(self.__vertices, + self.__indices, + process=False, + validate=False) + + return self.__trimesh + + + def rayIntersection(self, origins, directions, sort=False, vertices=False): + """Calculate the intersection between the mesh, and the rays defined by + ``origins`` and ``directions``. + + :arg origins: Sequence of ray origins + :arg directions: Sequence of ray directions + :arg sort: By default, the calculated intersection coordinates + are not sorted. If ``sort`` is set to ``True`` will + be sorted according to increasing distance from the + ray origin. + :arg vertices: By default, the returned coordinates are where the + ray intersects with triangles of the mesh. If + ``vertices`` is set to ``True``, the returned + coordinates will be the vertices that were nearest + to the ray intersection points. + + :returns: A list-of-lists, one for each input ray, containing + the points where the ray intersected the mesh. + """ + + hits = [[] for o in origins] + trimesh = self.trimesh() + + if trimesh is None: + return hits + + tris, rays, locs = self.trimesh().ray.intersects_id( + origins, + directions, + return_locations=True, + mulltiple_hits=False) + + # group by input ray + dists = [[] for o in origins] + + for tri, ray, loc in zip(tris, rays, locs): + hits[ray].append(loc) + + if sort: + dists[ray].append(transform.veclength(loc - origins[ray])) + + # returned locations are not + # sorted, so we need to sort + # them by distance from the + # origin + if sort: + for i in range(len(hits)): + + rayHits = hits[ i] + rayDists = dists[i] + + if len(rayHits) > 0: + rayDists, rayHits = zip(*sorted(zip(rayDists, rayHits))) + hits[ i] = rayHits + dists[i] = rayDists + + # TODO + if vertices: + pass + + return hits + + def getBounds(self): """Returns a tuple of values which define a minimal bounding box that will contain all vertices in this ``TriangleMesh`` instance. The