diff --git a/fsl/data/mesh.py b/fsl/data/mesh.py
index 92bf0f248bcf758e65c69629ec00843dcc0f8fc7..8fdab521ddabce40e98448ca1f6342ce546e49dc 100644
--- a/fsl/data/mesh.py
+++ b/fsl/data/mesh.py
@@ -263,67 +263,41 @@ class TriangleMesh(object):
         return self.__trimesh
 
 
-    def rayIntersection(self, origins, directions, sort=False, vertices=False):
+    def rayIntersection(self, origins, directions, vertices=False):
         """Calculate the intersection between the mesh, and the rays defined by
         ``origins`` and ``directions``.
 
         :arg origins:    Sequence of ray origins
         :arg directions: Sequence of ray directions
-        :arg sort:       By default, the calculated intersection coordinates
-                         are not sorted. If ``sort`` is set to ``True`` will
-                         be sorted according to increasing distance from the
-                         ray origin.
-        :arg vertices:   By default, the returned coordinates are where the
-                         ray intersects with triangles of the mesh. If
-                         ``vertices`` is set to ``True``, the returned
-                         coordinates will be the vertices that were nearest
-                         to the ray intersection points.
-
-        :returns:        A list-of-lists, one for each input ray, containing
-                         the points where the ray intersected the mesh.
+        :returns:        A tuple containing:
+
+                           - A ``(n, 3)`` array containing the coordinates
+                             where the mesh was intersected by each of the
+                             ``n`` rays.
+
+                           - A ``(n,)`` array containing the indices of the
+                             triangles that were intersected by each of the
+                             ``n`` rays.
         """
 
-        hits    = [[] for o in origins]
         trimesh = self.trimesh()
 
         if trimesh is None:
-            return hits
+            return np.zeros((0, 3)), np.zeros((0,))
 
-        tris, rays, locs = self.trimesh().ray.intersects_id(
+        tris, rays, locs = trimesh.ray.intersects_id(
             origins,
             directions,
             return_locations=True,
-            mulltiple_hits=False)
-
-        # group by input ray
-        dists = [[] for o in origins]
-
-        for tri, ray, loc in zip(tris, rays, locs):
-            hits[ray].append(loc)
-
-            if sort:
-                dists[ray].append(transform.veclength(loc - origins[ray]))
-
-        # returned locations are not
-        # sorted, so we need to sort
-        # them by distance from the
-        # origin
-        if sort:
-            for i in range(len(hits)):
-
-                rayHits  = hits[ i]
-                rayDists = dists[i]
-
-                if len(rayHits) > 0:
-                    rayDists, rayHits = zip(*sorted(zip(rayDists, rayHits)))
-                    hits[ i] = rayHits
-                    dists[i] = rayDists
+            multiple_hits=False)
 
-        # TODO
-        if vertices:
-            pass
+        # sort by ray. I'm Not sure if this is
+        # needed - does trimesh do it for us?
+        rayIdxs = np.argsort(rays)
+        locs    = locs[rayIdxs]
+        tris    = tris[rayIdxs]
 
-        return hits
+        return locs, tris
 
 
     def getBounds(self):