Commit 91f0507f authored by Paul McCarthy's avatar Paul McCarthy 🚵
Browse files

TEST: Test non-numeric conditions

parent 8d8b5cd8
......@@ -7,6 +7,7 @@
import io
import textwrap as tw
import pytest
......@@ -20,62 +21,62 @@ import funpack.expression as expression
from . import gen_DataTable, gen_DataTableFromDataFrame
_test_data = """
index 10 20 30
1 1 2 3
2 4 5 6
3 7 9
4 10 11 12
5 13 15
""".strip().replace(' ', '\t')
_test_cols = {
10 : '10',
20 : '20',
30 : '30',
}
_expr_tests = [
( 'v10 == 1', [1, 0, 0, 0, 0]),
( 'v10 != 1', [0, 1, 1, 1, 1]),
( 'v10 > 7', [0, 0, 0, 1, 1]),
( 'v10 >= 7', [0, 0, 1, 1, 1]),
( 'v10 < 7', [1, 1, 0, 0, 0]),
( 'v10 <= 7', [1, 1, 1, 0, 0]),
( '~(v10 == 1)', [0, 1, 1, 1, 1]),
( '~(v10 != 1)', [1, 0, 0, 0, 0]),
( 'all(v10 != 1)', [0, 1, 1, 1, 1]),
( 'any(v10 != 1)', [0, 1, 1, 1, 1]),
('~(all(v10 != 1))', [1, 0, 0, 0, 0]),
('~(any(v10 != 1))', [1, 0, 0, 0, 0]),
('all(~(v10 != 1))', [1, 0, 0, 0, 0]),
('any(~(v10 != 1))', [1, 0, 0, 0, 0]),
('v20 == na', [0, 0, 1, 0, 1]),
('v20 != na', [1, 1, 0, 1, 0]),
( 'v10 >= 7 && v30 < 10', [0, 0, 1, 0, 0]),
('~(v10 >= 7 && v30 < 10)', [1, 1, 0, 1, 1]),
( 'v10 >= 4 && v30 < 10 || v20 != na', [1, 1, 1, 1, 0]),
('(v10 >= 4 && v30 < 10) || v20 != na', [1, 1, 1, 1, 0]),
( 'v10 >= 4 && (v30 < 10 || v20 != na)', [0, 1, 1, 1, 0]),
# bad
('10 == 1', 'error'),
('10 ==', 'error'),
('v10', 'error'),
('v10 ==', 'error'),
('v10 1', 'error'),
('v10 == 1 &&', 'error'),
('v10 == 1 && 24', 'error'),
('abcde', 'error'),
]
def test_Expression():
_test_data = tw.dedent("""
index 10 20 30
1 1 2 3
2 4 5 6
3 7 9
4 10 11 12
5 13 15
""").strip().replace(' ', '\t')
_test_cols = {
10 : '10',
20 : '20',
30 : '30',
}
_expr_tests = [
( 'v10 == 1', [1, 0, 0, 0, 0]),
( 'v10 != 1', [0, 1, 1, 1, 1]),
( 'v10 > 7', [0, 0, 0, 1, 1]),
( 'v10 >= 7', [0, 0, 1, 1, 1]),
( 'v10 < 7', [1, 1, 0, 0, 0]),
( 'v10 <= 7', [1, 1, 1, 0, 0]),
( '~(v10 == 1)', [0, 1, 1, 1, 1]),
( '~(v10 != 1)', [1, 0, 0, 0, 0]),
( 'all(v10 != 1)', [0, 1, 1, 1, 1]),
( 'any(v10 != 1)', [0, 1, 1, 1, 1]),
('~(all(v10 != 1))', [1, 0, 0, 0, 0]),
('~(any(v10 != 1))', [1, 0, 0, 0, 0]),
('all(~(v10 != 1))', [1, 0, 0, 0, 0]),
('v20 == na', [0, 0, 1, 0, 1]),
('v20 != na', [1, 1, 0, 1, 0]),
( 'v10 >= 7 && v30 < 10', [0, 0, 1, 0, 0]),
('~(v10 >= 7 && v30 < 10)', [1, 1, 0, 1, 1]),
( 'v10 >= 4 && v30 < 10 || v20 != na', [1, 1, 1, 1, 0]),
('(v10 >= 4 && v30 < 10) || v20 != na', [1, 1, 1, 1, 0]),
( 'v10 >= 4 && (v30 < 10 || v20 != na)', [0, 1, 1, 1, 0]),
# bad
('10 == 1', 'error'),
('10 ==', 'error'),
('v10', 'error'),
('v10 ==', 'error'),
('v10 1', 'error'),
('v10 == 1 &&', 'error'),
('v10 == 1 && 24', 'error'),
('abcde', 'error'),
]
def vine(e):
vs = []
if 'v10' in e: vs.append(10)
......@@ -96,16 +97,16 @@ def test_Expression():
e = expression.Expression(expr)
assert sorted(e.variables) == sorted(vine(expr))
coldata = {vid : col for vid, col in _test_cols.items()}
result = e.evaluate(dt, coldata)
result = e.evaluate(dt, _test_cols)
assert len(result) == len(expected)
assert all([bool(r) == bool(e) for r, e in zip(result, expected)])
def test_Expression_multiple_columns():
data = np.random.randint(1, 100, (100, 6))
data[:, 0] = np.arange(1, 101)
data = np.random.randint(1, 100, (20, 6))
data[:, 0] = np.arange(1, 21)
cols = ['eid', '1-0.0', '1-1.0', '1-2.0', '2-0.0', '2-1.0']
df = pd.DataFrame({c : d for c, d in zip(cols, data.T)}).set_index('eid')
dtable = gen_DataTableFromDataFrame(df)
......@@ -113,7 +114,7 @@ def test_Expression_multiple_columns():
# one column
e = expression.Expression('v1 > 50')
result = e.evaluate(dtable, {1 : '1-0.0'})
assert np.all(result == (data[:, 1] > 50))
assert np.all(result.flatten() == (data[:, 1] > 50))
# multiple columns
e = expression.Expression('v1 <= 75')
......@@ -230,3 +231,54 @@ def test_calculateExpressionEvaluationOrder():
vids = [1, 2, 3]
exprs = ['v2 == 1']
expression.calculateExpressionEvaluationOrder(vids, makexprs(exprs))
def test_Expresssion_non_numeric():
test_data = tw.dedent("""
index,10,20,30
1,a,abc,a123
2,b,def,a 101
3,c,,b252
4,d,jkl,b745
5,e,,b254
""").strip()
test_cols = {
10 : '10',
20 : '20',
30 : '30',
}
expr_tests = [
('v10 == "a"', [1, 0, 0, 0, 0]),
("v10 == 'a'", [1, 0, 0, 0, 0]),
('v20 == na', [0, 0, 1, 0, 1]),
('v20 != na', [1, 1, 0, 1, 0]),
('v30 == "a 101"', [0, 1, 0, 0, 0]),
('v30 contains "b2"', [0, 0, 1, 0, 1]),
('v30 contains "b22"', [0, 0, 0, 0, 0]),
# bad
('v10 == a', 'error'),
('v10 == "a', 'error'),
('v10 == a"', 'error'),
]
data = pd.read_csv(io.StringIO(test_data), sep=',')
dt = gen_DataTableFromDataFrame(data)
for expr, expected in expr_tests:
if expected == 'error':
with pytest.raises(pp.ParseException):
e = expression.Expression(expr)
continue
else:
e = expression.Expression(expr)
result = e.evaluate(dt, test_cols)
assert len(result) == len(expected)
assert all([bool(r) == bool(e) for r, e in zip(result, expected)])
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment