Commit 76b54e2a authored by Paul McCarthy 🚵
Browse files

TEST: Test return values from processing functions

parent a1807e3c
......@@ -5,6 +5,7 @@
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
import itertools as it
import multiprocessing as mp
import textwrap as tw
from unittest import mock
......@@ -217,7 +218,7 @@ def test_parseProcesses_run():
@clear_plugins
def test_process_variable_types():
def test_processData_variable_types():
procfile = tw.dedent("""
Variable\tProcess
......@@ -268,3 +269,69 @@ def test_process_variable_types():
[4], [5], [6]]
assert called_on['all_except'] == [[4, 5, 6]]
assert called_on['all_independent_except'] == [[1], [2], [3]]
@clear_plugins
def test_processData_returnValues():
    """Check that ``processData`` handles every processor return form.

    Four custom processors are registered, each returning a different
    kind of value:

      - ``nothing``:        ``None``
      - ``remove``:         a list of columns
      - ``add``:            a ``(series, vids)`` tuple
      - ``add_and_remove``: a ``(columns, series, vids)`` tuple

    They are applied to variables 1-3, 4-6, 7-9 and 10-12 respectively
    via a processing table, and the test asserts that the columns left
    in the data table are those of variables 4-6 and 10-12, plus new
    columns named after ten times each of variables 4-9.
    """

    @custom.processor()
    def nothing(dtable, vids):
        return None

    @custom.processor()
    def remove(dtable, vids):
        # Return every column of every requested variable.
        return list(it.chain.from_iterable(dtable.columns(v) for v in vids))

    @custom.processor()
    def add(dtable, vids):
        newseries = []
        newvids = []
        for v in vids:
            col = dtable.columns(v)[0]
            data = dtable[:, col.name]
            # The new series is named after a new variable ID (v * 10).
            newseries.append(pd.Series(data + 10,
                                       name='{}-0.0'.format(v * 10)))
            newvids.append(v * 10)
        return newseries, newvids

    @custom.processor()
    def add_and_remove(dtable, vids):
        remcols = []
        newseries = []
        newvids = []
        for v in vids:
            col = dtable.columns(v)[0]
            data = dtable[:, col.name]
            newseries.append(pd.Series(data + 10,
                                       name='{}-0.0'.format(v * 10)))
            newvids.append(v * 10)
            remcols.append(col)
        return remcols, newseries, newvids

    procfile = tw.dedent("""
    Variable\tProcess
    1:3\tremove
    4:6\tadd
    7:9\tadd_and_remove
    10:12\tnothing
    """).strip()

    with tempdir():
        # Close the file explicitly so its contents are flushed to disk
        # before the processing table is loaded from it.
        with open('processing.tsv', 'wt') as f:
            f.write(procfile)

        gen_test_data(12, 50, 'data.tsv')
        proctable = loadtables.loadProcessingTable('processing.tsv')
        vartable, _, cattable = gen_tables(range(1, 13))[:3]
        dtable, _ = importing.importData('data.tsv',
                                         vartable,
                                         proctable,
                                         cattable)
        processing.processData(dtable)

        # NOTE(review): the first column is skipped - presumably the
        # index/ID column; confirm against the data table implementation.
        gotcols = [c.name for c in dtable.allColumns[1:]]
        expcols = ['{}-0.0'.format(v)
                   for v in [4, 5, 6, 10, 11, 12, 40, 50, 60, 70, 80, 90]]

        assert sorted(expcols) == sorted(gotcols)
......@@ -174,6 +174,31 @@ def test_binariseCateorical():
assert names == [r.name for r in remove]
def test_binariseCateorical_no_replace():
    """Call ``binariseCategorical`` with ``replace=False`` on variable 1.

    Random integer data for two variables is written to a file and
    imported. The test asserts that one column and one variable ID are
    returned per unique value of variable 1, that the columns are named
    ``1.<value>``, and that every returned variable ID equals 1.
    """
    columns = ['eid', '1-0.0', '2-0.0']
    variables = [1, 2]

    values = np.random.randint(1, 10, (50, 3))
    values[:, 0] = np.arange(1, 51)

    with tempdir():
        np.savetxt('data.txt',
                   values,
                   delimiter=',',
                   header=','.join(columns))

        tables = gen_tables(variables)[:3]
        vartable, proctable, cattable = tables
        dt, _ = importing.importData(
            'data.txt', vartable, proctable, cattable)

        add, addvids = pfns.binariseCategorical(
            dt, [1], replace=False, nameFormat='{vid}.{value}')

        # Unique values found in the second data column ('1-0.0').
        expected = np.unique(values[:, 1])
        nexpected = len(expected)

        assert len(add) == nexpected
        assert len(addvids) == nexpected

        gotnames = sorted([c.name for c in add])
        expnames = sorted(['1.{}'.format(u) for u in expected])
        assert gotnames == expnames

        assert np.all([v == 1 for v in addvids])
def test_binariseCategorical_nonnumeric():
data = [random.choice(string.ascii_letters[:8]) for i in range(40)]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment