Commit 47b5d440 authored by Paul McCarthy's avatar Paul McCarthy 🚵
Browse files

RF,ENH: removeSubjects logic updated to perform visit/instance testing in

parallel, and to handle expressions which do not combine their results across
columns
parent eb0171d5
......@@ -9,7 +9,6 @@ the data importing stage of the ``funpack`` sequence
"""
import os.path as op
import itertools as it
import functools as ft
import multiprocessing.dummy as mpd
......@@ -258,56 +257,98 @@ def removeSubjects(dtable, exclude=None, exprs=None):
else: mask = np.zeros(orignrows, dtype=np.bool)
if exprs is not None:
# Parse the expressions, and get a
# list of all variables that are
# mentioned in them.
exprs = list(it.chain(*[e.split(',') for e in exprs]))
exprs = [expression.Expression(e) for e in exprs]
vids = list(set(it.chain(*[e.variables for e in exprs])))
# list of the variables that are
# mentioned in each of them,
exprs = list(it.chain(*[e.split(',') for e in exprs]))
exprs = [expression.Expression(e) for e in exprs]
vids = [list(e.variables) for e in exprs]
# Build a list of the visits and
# instances in the data for each
# variable used in the expression.
# variable used in each expression.
try:
visits = [dtable.visits( v) for v in vids]
instances = [dtable.instances(v) for v in vids]
visits = [[dtable.visits( v) for v in evs] for evs in vids]
instances = [[dtable.instances(v) for v in evs] for evs in vids]
except KeyError as e:
raise RuntimeError('Unknown variable used in exclude expression: '
'{} ({})'.format(exprs, e))
# Calculate the intersection of visits/
# instances across all variables - we
# evaluate expressions for each visit/
# instance, and only where a visit/
# instance is present for all variables.
# evaluate an expression only on visits/
# instances present for all variables
# in that expression. All other visits/
# instances are not considered.
def intersection(a, b):
return set(a).intersection(b)
intersection = ft.partial(ft.reduce, intersection)
visits = [intersection(evis) for evis in visits]
instances = [intersection(eis) for eis in instances]
# Build a {vid : [column]} dict for
# each expression, as we need such
# a dict to evaluate them.
exprcols = []
for i in range(len(exprs)):
evs = vids[ i]
evis = visits[ i]
eis = instances[i]
cols = collections.defaultdict(list)
for evid, evisit, einstance in it.product(evs, evis, eis):
cols[evid].extend(dtable.columns(evid, evisit, einstance))
exprcols.append(cols)
# List which will contain one boolean
# numpy array for each subject include
# expression.
exprmasks = []
if len(visits) > 0: visits = ft.reduce(intersection, visits)
if len(instances) > 0: instances = ft.reduce(intersection, instances)
# evalute each expression in parallel
with dtable.pool() as pool:
for i, expr in enumerate(exprs):
# A subject will be retained if *any*
# expression for *any* visit/instance
# evaluates to true.
exprmasks = []
cols = exprcols[i]
for visit, instance in it.product(visits, instances):
if len(cols) == 0:
log.debug('Ignoring expression (%s) - no associated '
'columns are present', str(expr))
continue
colnames = {v : [c.name for c in vcols]
for v, vcols in cols.items()}
cols = list(it.chain(*cols.values()))
subtable = dtable.subtable(cols)
# build a dict of { vid : column } mappings
# for each variable used in the expression
cols = [dtable.columns(v, visit, instance)[0] for v in vids]
cols = {v : c.name for v, c in zip(vids, cols)}
log.debug('Evaluating expression (%s) on columns %s',
expr, colnames)
with dtable.pool() as pool:
for e in exprs:
exprmasks.append(pool.apply_async(
e.evaluate, (dtable, cols, )))
exprmasks.append(pool.apply_async(
expr.evaluate, (subtable, colnames)))
# wait for each expression to complete,
# then combine them using logical OR.
# wait for each expression to complete
exprmasks = [e.get() for e in exprmasks]
mask = ft.reduce(lambda a, b: a | b, exprmasks, mask)
mask = np.array(mask)
# any result which was not combined using
# any() or all() defaults to being combined
# with any(). For example, if "v123 >= 2"
# is applied to columns 123-0.0, 123-1.0,
# and 123-2.0, the final result will be
# a 1D boolean array containing True where
# any of the three columns were >= 2.
for i, em in enumerate(exprmasks):
if len(em.shape) == 2:
exprmasks[i] = em.any(axis=1)
# Finally, all expressions are combined
# in the same manner - i.e. rows which
# passed *any* of the expressions
# are included
mask = ft.reduce(lambda a, b: a | b, exprmasks, mask)
mask = np.array(mask)
# Flag subjects to drop
if exclude is not None:
......@@ -407,7 +448,7 @@ def columnsToLoad(datafiles,
# Turn the unknonwVars list
# into a list of variable IDs
unknownVids = list(sorted(set([c.vid for c in unknownVars])))
unknownVids = list(sorted({c.vid for c in unknownVars}))
if isinstance(datafiles, six.string_types):
datafiles = [datafiles]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment