Commit 9044c20a authored by Paul McCarthy's avatar Paul McCarthy 🚵
Browse files

TEST: Test new any/all operations

parent 06c38d68
......@@ -13,10 +13,11 @@ import pytest
import pyparsing as pp
import pandas as pd
import numpy as np
import funpack.expression as expression
from . import gen_DataTableFromDataFrame
from . import gen_DataTable, gen_DataTableFromDataFrame
_test_data = """
......@@ -36,14 +37,20 @@ _test_cols = {
_expr_tests = [
( 'v10 == 1', [1, 0, 0, 0, 0]),
( 'v10 != 1', [0, 1, 1, 1, 1]),
( 'v10 > 7', [0, 0, 0, 1, 1]),
( 'v10 >= 7', [0, 0, 1, 1, 1]),
( 'v10 < 7', [1, 1, 0, 0, 0]),
( 'v10 <= 7', [1, 1, 1, 0, 0]),
('~(v10 == 1)', [0, 1, 1, 1, 1]),
('~(v10 != 1)', [1, 0, 0, 0, 0]),
( 'v10 == 1', [1, 0, 0, 0, 0]),
( 'v10 != 1', [0, 1, 1, 1, 1]),
( 'v10 > 7', [0, 0, 0, 1, 1]),
( 'v10 >= 7', [0, 0, 1, 1, 1]),
( 'v10 < 7', [1, 1, 0, 0, 0]),
( 'v10 <= 7', [1, 1, 1, 0, 0]),
( '~(v10 == 1)', [0, 1, 1, 1, 1]),
( '~(v10 != 1)', [1, 0, 0, 0, 0]),
( 'all(v10 != 1)', [0, 1, 1, 1, 1]),
( 'any(v10 != 1)', [0, 1, 1, 1, 1]),
('~(all(v10 != 1))', [1, 0, 0, 0, 0]),
('~(any(v10 != 1))', [1, 0, 0, 0, 0]),
('all(~(v10 != 1))', [1, 0, 0, 0, 0]),
('any(~(v10 != 1))', [1, 0, 0, 0, 0]),
('v20 == na', [0, 0, 1, 0, 1]),
('v20 != na', [1, 1, 0, 1, 0]),
......@@ -89,7 +96,6 @@ def test_Expression():
e = expression.Expression(expr)
assert sorted(e.variables) == sorted(vine(expr))
coldata = {vid : col for vid, col in _test_cols.items()}
result = e.evaluate(dt, coldata)
......@@ -97,7 +103,58 @@ def test_Expression():
assert all([bool(r) == bool(e) for r, e in zip(result, expected)])
def test_calculaetExpressionEvaluationOrder():
def test_Expression_multiple_columns():
data = np.random.randint(1, 100, (100, 6))
data[:, 0] = np.arange(1, 101)
cols = ['eid', '1-0.0', '1-1.0', '1-2.0', '2-0.0', '2-1.0']
df = pd.DataFrame({c : d for c, d in zip(cols, data.T)}).set_index('eid')
dtable = gen_DataTableFromDataFrame(df)
# one column
e = expression.Expression('v1 > 50')
result = e.evaluate(dtable, {1 : '1-0.0'})
assert np.all(result == (data[:, 1] > 50))
# multiple columns
e = expression.Expression('v1 <= 75')
result = e.evaluate(dtable, {1 : ['1-0.0', '1-1.0']})
assert np.all(result == (data[:, 1:3] <= 75))
e = expression.Expression('v1 != 23')
result = e.evaluate(dtable, {1 : ['1-0.0', '1-1.0', '1-2.0']})
assert np.all(result == (data[:, 1:4] != 23))
# multi-column multi-var - number
# of columns per var must match
e = expression.Expression('v1 > 50 && v2 < 50')
result = e.evaluate(dtable, {1 : ['1-0.0', '1-1.0'],
2 : ['2-0.0', '2-1.0']})
exp = (data[:, 1:3] > 50) & (data[:, 4:] < 50)
assert np.all(result == exp)
# use any/all to collapse across columns
e = expression.Expression('all(v1 > 50 && v2 < 50)')
result = e.evaluate(dtable, {1 : ['1-0.0', '1-1.0'],
2 : ['2-0.0', '2-1.0']})
exp = ((data[:, 1:3] > 50) & (data[:, 4:] < 50)).all(axis=1)
assert np.all(result == exp)
e = expression.Expression('any(v1 > 50 && v2 < 50)')
result = e.evaluate(dtable, {1 : ['1-0.0', '1-1.0'],
2 : ['2-0.0', '2-1.0']})
exp = ((data[:, 1:3] > 50) & (data[:, 4:] < 50)).any(axis=1)
assert np.all(result == exp)
e = expression.Expression('any(v1 > 50) && all(v2 < 50)')
result = e.evaluate(dtable, {1 : ['1-0.0', '1-1.0'],
2 : ['2-0.0', '2-1.0']})
exp = (data[:, 1:3] > 50).any(axis=1) & (data[:, 4:] < 50).all(axis=1)
assert np.all(result == exp)
def test_calculateExpressionEvaluationOrder():
def makexprs(exprstrs):
exprs = []
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment