diff --git a/fsl/data/mesh.py b/fsl/data/mesh.py index 92bf0f248bcf758e65c69629ec00843dcc0f8fc7..8fdab521ddabce40e98448ca1f6342ce546e49dc 100644 --- a/fsl/data/mesh.py +++ b/fsl/data/mesh.py @@ -263,67 +263,41 @@ class TriangleMesh(object): return self.__trimesh - def rayIntersection(self, origins, directions, sort=False, vertices=False): + def rayIntersection(self, origins, directions, vertices=False): """Calculate the intersection between the mesh, and the rays defined by ``origins`` and ``directions``. :arg origins: Sequence of ray origins :arg directions: Sequence of ray directions - :arg sort: By default, the calculated intersection coordinates - are not sorted. If ``sort`` is set to ``True`` will - be sorted according to increasing distance from the - ray origin. - :arg vertices: By default, the returned coordinates are where the - ray intersects with triangles of the mesh. If - ``vertices`` is set to ``True``, the returned - coordinates will be the vertices that were nearest - to the ray intersection points. - - :returns: A list-of-lists, one for each input ray, containing - the points where the ray intersected the mesh. + :returns: A tuple containing: + + - A ``(n, 3)`` array containing the coordinates + where the mesh was intersected by each of the + ``n`` rays. + + - A ``(n,)`` array containing the indices of the + triangles that were intersected by each of the + ``n`` rays. """ - hits = [[] for o in origins] trimesh = self.trimesh() if trimesh is None: - return hits + return np.zeros((0, 3)), np.zeros((0,)) - tris, rays, locs = self.trimesh().ray.intersects_id( + tris, rays, locs = trimesh.ray.intersects_id( origins, directions, return_locations=True, - mulltiple_hits=False) - - # group by input ray - dists = [[] for o in origins] - - for tri, ray, loc in zip(tris, rays, locs): - hits[ray].append(loc) - - if sort: - dists[ray].append(transform.veclength(loc - origins[ray])) - - # returned locations are not - # sorted, so we need to sort - # them by distance from the - # origin - if sort: - for i in range(len(hits)): - - rayHits = hits[ i] - rayDists = dists[i] - - if len(rayHits) > 0: - rayDists, rayHits = zip(*sorted(zip(rayDists, rayHits))) - hits[ i] = rayHits - dists[i] = rayDists + multiple_hits=False) - # TODO - if vertices: - pass + # sort by ray. I'm Not sure if this is + # needed - does trimesh do it for us? + rayIdxs = np.argsort(rays) + locs = locs[rayIdxs] + tris = tris[rayIdxs] - return hits + return locs, tris def getBounds(self):