From 61e0ccd398e32d7d5e1590ffd037fbf3497599ab Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Sat, 29 Jul 2017 20:03:59 +0100
Subject: [PATCH] Updated veclength/normalise tests

---
 tests/test_transform.py | 43 +++++++++++++++++++++++++++++++++++------
 1 file changed, 37 insertions(+), 6 deletions(-)

diff --git a/tests/test_transform.py b/tests/test_transform.py
index 21b7db564..609df61df 100644
--- a/tests/test_transform.py
+++ b/tests/test_transform.py
@@ -8,6 +8,7 @@
 
 from __future__ import division
 
+import                 random
 import                 glob
 import os.path      as op
 import itertools    as it
@@ -400,7 +401,7 @@ def test_sformToFlirtMatrix():
 
 def test_normalise(seed):
 
-    vectors = -100 + 200 * np.random.random((50, 3))
+    vectors = -100 + 200 * np.random.random((200, 3))
 
     def parallel(v1, v2):
         v1 = v1 / transform.veclength(v1)
@@ -409,25 +410,55 @@ def test_normalise(seed):
         return np.isclose(np.dot(v1, v2), 1)
 
     for v in vectors:
-        vn = transform.normalise(v)
-        vl = transform.veclength(vn)
+
+        vtype = random.choice((list, tuple, np.array))
+        v     = vtype(v)
+        vn    = transform.normalise(v)
+        vl    = transform.veclength(vn)
 
         assert np.isclose(vl, 1.0)
         assert parallel(v, vn)
 
+    # normalise should also be able
+    # to do multiple vectors at once
+    results = transform.normalise(vectors)
+    lengths = transform.veclength(results)
+    pars    = np.zeros(200)
+    for i in range(200):
+
+        v = vectors[i]
+        r = results[i]
+
+        pars[i] = parallel(v, r)
+
+    assert np.all(np.isclose(lengths, 1))
+    assert np.all(pars)
+
 
 def test_veclength(seed):
 
     def l(v):
-        x, y, z = v
-        l       = x * x + y * y + z * z
+        v = np.array(v, copy=False).reshape((-1, 3))
+        x = v[:, 0]
+        y = v[:, 1]
+        z = v[:, 2]
+        l = x * x + y * y + z * z
         return np.sqrt(l)
 
-    vectors = -100 + 200 * np.random.random((50, 3))
+    vectors = -100 + 200 * np.random.random((200, 3))
 
     for v in vectors:
+
+        vtype = random.choice((list, tuple, np.array))
+        v     = vtype(v)
+
         assert np.isclose(transform.veclength(v), l(v))
 
+    # Multiple vectors in parallel
+    result   = transform.veclength(vectors)
+    expected = l(vectors)
+    assert np.all(np.isclose(result, expected))
+
 
 def test_transformNormal(seed):
 
-- 
GitLab