Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added logic to handle modules and user-defined functions
Browse files Browse the repository at this point in the history
thequackdaddy committed Nov 3, 2018

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 5f662a9 commit 807cc93
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions patsy/design_info.py
Original file line number Diff line number Diff line change
@@ -685,7 +685,7 @@ def var_names(self, eval_env=0):
else:
return {}

def partial(self, columns, product=False):
def partial(self, columns, product=False, eval_env=0):
"""Returns a partial prediction array where only the variables in the
dict ``columns`` are tranformed per the :class:`DesignInfo`
transformations. The terms that are not influenced by ``columns``
@@ -703,6 +703,18 @@ def partial(self, columns, product=False):
:returns: A numpy array of the partial design matrix.
"""
from .highlevel import dmatrix
from types import ModuleType

if not eval_env:
from patsy.eval import EvalEnvironment
eval_env = EvalEnvironment.capture(eval_env, reference=1)

# We need to get rid of the non-callable items from the eval_env
namespaces = [{key: value} for ns in eval_env._namespaces
for key, value in six.iteritems(ns)
if callable(value) or isinstance(value, ModuleType)]
eval_env._namespaces = namespaces

if product:
columns = _column_product(columns)
rows = None
@@ -712,7 +724,7 @@ def partial(self, columns, product=False):
rows = len(columns[col])
parts = []
for term, subterm in six.iteritems(self.term_codings):
term_vars = term.var_names()
term_vars = term.var_names(eval_env)
present = True
for term_var in term_vars:
if term_var not in columns:
@@ -1312,6 +1324,16 @@ def test_DesignInfo_partial():
assert_raises(ValueError, dm.design_info.partial, {'a': ['a', 'b'],
'b': [1, 2, 3]})

def some_function(x):
return np.where(x > 2, 1, 2)

dm = dmatrix('1 + some_function(c)')
x = np.array([[0, 2],
[0, 2],
[0, 1]])
y = dm.design_info.partial({'c': np.array([1, 2, 3])})
assert_allclose(x, y)


def _column_product(columns):
from itertools import product

0 comments on commit 807cc93

Please sign in to comment.