Commit 78fc6f9f authored by Paul McCarthy's avatar Paul McCarthy 🚵
Browse files

TEST: Test column combine in subject expressions

parent 47b5d440
......@@ -238,14 +238,17 @@ def gen_DataTable(cols, *a, **kwa):
return gen_DataTableFromDataFrame(data, *a, **kwa)
def gen_DataTableFromDataFrame(df, tables=None, pool=None):
def gen_DataTableFromDataFrame(df, tables=None, pool=None, variables=None):
if variables is None:
variables = list(range(1, len(df.columns) + 1))
variables = list(range(1, len(df.columns) + 1))
colobjs = [datatable.Column(None, df.index.name, 0, 0, 0, 0)] + \
[datatable.Column(None, n, v, v, 0, 0)
for v, n in zip(variables, df.columns)]
if tables is None:
variables = list(set(variables))
vartable, proctable, cattable, uvs = gen_tables(variables)
else:
vartable, proctable, cattable = tables
......
......@@ -151,6 +151,12 @@ def test_Expression_multiple_columns():
exp = (data[:, 1:3] > 50).any(axis=1) & (data[:, 4:] < 50).all(axis=1)
assert np.all(result == exp)
# column length mismatch - error
e = expression.Expression('v1 > 50 && v2 < 50')
with pytest.raises(ValueError):
result = e.evaluate(dtable, {1 : ['1-0.0', '1-1.0', '1-2.0'],
2 : ['2-0.0', '2-1.0']})
......
......@@ -5,11 +5,14 @@
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
import pytest
import pandas as pd
import numpy as np
import funpack.importing as importing
from . import gen_DataTable
from . import gen_DataTable, gen_DataTableFromDataFrame
def test_removeSubjects():
......@@ -54,3 +57,51 @@ def test_removeSubjects():
mask = (data[0, :] > 5) | (data[1, :] == 9)
assert np.all(dtable.index == np.where(mask)[0] + 1)
assert np.all(dtable[:, :] == data[:, mask].T)
def test_removeSubjects_multiple_columns():
    """Exercise ``importing.removeSubjects`` with multi-column variables.

    The fixture has variable 1 spread over three columns and variable 2
    over two columns. The test checks that:

      - combining the two variables directly (differing column counts)
        raises ``ValueError``
      - columns can be collapsed within a variable via ``all``/``any``
      - an expression on an un-collapsed variable is treated as ``any``
      - several expressions are ORed together

    In each passing scenario the subjects left in the table index after
    the call must be exactly those for which the expected mask is true.
    """

    def make_fixture():
        # Fresh random table on every call: 500 subjects (eid 1..500),
        # five data columns holding integers in [1, 9].
        names   = ['eid', '1-0.0', '1-1.0', '1-2.0', '2-0.0', '2-1.0']
        var_ids = [1, 1, 1, 2, 2]
        raw     = np.random.randint(1, 10, (6, 500))
        raw[0, :] = np.arange(1, 501)
        frame   = pd.DataFrame(dict(zip(names, raw))).set_index('eid')
        values  = raw[1:, :].T
        return gen_DataTableFromDataFrame(frame, variables=var_ids), values

    # Row-wise reductions across the columns of one variable.
    def every_col(mask):
        return mask.all(axis=1)

    def some_col(mask):
        return mask.any(axis=1)

    # combine vars with ncolumn
    # mismatch - error
    table, _ = make_fixture()
    with pytest.raises(ValueError):
        importing.removeSubjects(table, exprs=['v1 > 2 && v2 < 7'])

    # Each scenario: (expressions, function giving the expected keep-mask)
    scenarios = [
        # combine columns within var
        (['all(v1 > 2) && any(v2 < 7)'],
         lambda d: every_col(d[:, :3] > 2) & some_col(d[:, 3:] < 7)),

        # no combining columns - should
        # default to any
        (['v1 > 6'],
         lambda d: some_col(d[:, :3] > 6)),

        # multiple expressions - ORed together
        (['v1 > 6', 'v2 < 4'],
         lambda d: some_col(d[:, :3] > 6) | some_col(d[:, 3:] < 4)),
    ]

    for expressions, keep_mask in scenarios:
        table, values = make_fixture()
        # Work out the expected index before the table is modified.
        expected = table.index[keep_mask(values)]
        importing.removeSubjects(table, exprs=expressions)
        assert (table.index == expected).all()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment