Skip to content

Commit b0dc258

Browse files
committed
Added partial function
1 parent 72e9725 commit b0dc258

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

patsy/design_info.py

+102
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from patsy.constraint import linear_constraint
3737
from patsy.contrasts import ContrastMatrix
3838
from patsy.desc import ModelDesc, Term
39+
from collections import OrderedDict
3940

4041
class FactorInfo(object):
4142
"""A FactorInfo object is a simple class that provides some metadata about
@@ -684,6 +685,49 @@ def var_names(self, eval_env=0):
684685
else:
685686
return {}
686687

688+
def partial(self, columns, product=False):
689+
"""Returns a partial prediction array where only the variables in the
690+
dict ``columns`` are tranformed per the :class:`DesignInfo`
691+
transformations. The terms that are not influenced by ``columns``
692+
return as zero.
693+
694+
This is useful to perform a partial prediction on unseen data and to
695+
view marginal differences in factors.
696+
697+
:arg columns: A dict with the keys as the column names for the marginal
698+
predictions desired and values as the marginal values to be predicted.
699+
700+
:arg product: When `True`, the resturned numpy array represents the
701+
Cartesian product of the values ``columns``.
702+
703+
:returns: A numpy array of the partial design matrix.
704+
"""
705+
from .highlevel import dmatrix
706+
if product:
707+
columns = _column_product(columns)
708+
rows = None
709+
for col in columns:
710+
if rows and rows != len(columns[col]):
711+
raise ValueError('all columns must be of same length')
712+
rows = len(columns[col])
713+
parts = []
714+
for term, subterm in six.iteritems(self.term_codings):
715+
term_vars = term.var_names()
716+
present = True
717+
for term_var in term_vars:
718+
if term_var not in columns:
719+
present = False
720+
if present and (term.name() != 'Intercept'):
721+
# This seems like an inelegent way to not having the Intercept
722+
# in the output
723+
di = self.subset('0 + {}'.format(term.name()))
724+
parts.append(dmatrix(di, columns))
725+
else:
726+
num_columns = np.sum(s.num_columns for s in subterm)
727+
dm = np.zeros((rows, num_columns))
728+
parts.append(dm)
729+
return np.hstack(parts)
730+
687731
@classmethod
688732
def from_array(cls, array_like, default_column_prefix="column"):
689733
"""Find or construct a DesignInfo appropriate for a given array_like.
@@ -1230,3 +1274,61 @@ def test_design_matrix():
12301274
repr(DesignMatrix(np.zeros((1, 0))))
12311275
repr(DesignMatrix(np.zeros((0, 1))))
12321276
repr(DesignMatrix(np.zeros((0, 0))))
1277+
1278+
1279+
def test_DesignInfo_partial():
1280+
from .highlevel import dmatrix
1281+
from numpy.testing import assert_allclose
1282+
a = np.array(['a', 'b', 'a', 'b', 'a', 'a', 'b', 'a'])
1283+
b = np.array([1, 3, 2, 4, 1, 3, 1, 1])
1284+
c = np.array([4, 3, 2, 1, 6, 4, 2, 1])
1285+
dm = dmatrix('a + bs(b, df=3, degree=3) + np.log(c)')
1286+
x = np.zeros((3, 6))
1287+
x[1, 1] = 1
1288+
y = dm.design_info.partial({'a': ['a', 'b', 'a']})
1289+
assert_allclose(x, y)
1290+
1291+
x = np.zeros((2, 6))
1292+
x[1, 1] = 1
1293+
x[1, 5] = np.log(3)
1294+
y = dm.design_info.partial({'a': ['a', 'b'], 'c': [1, 3]})
1295+
assert_allclose(x, y)
1296+
1297+
x = np.zeros((4, 6))
1298+
x[2, 1] = 1
1299+
x[3, 1] = 1
1300+
x[1, 5] = np.log(3)
1301+
x[3, 5] = np.log(3)
1302+
y = dm.design_info.partial({'a': ['a', 'b'], 'c': [1, 3]}, product=True)
1303+
assert_allclose(x, y)
1304+
1305+
dm = dmatrix('a * b')
1306+
y = dm.design_info.partial({'a': ['a', 'b'], 'b': [1, 3]})
1307+
x = np.array([[0, 0, 1, 0], [0, 1, 3, 3]])
1308+
assert_allclose(x, y)
1309+
1310+
from nose.tools import assert_raises
1311+
assert_raises(ValueError, dm.design_info.partial, {'a': ['a', 'b'],
1312+
'b': [1, 2, 3]})
1313+
1314+
1315+
def _column_product(columns):
1316+
from itertools import product
1317+
cols = []
1318+
values = []
1319+
for col, value in six.iteritems(columns):
1320+
cols.append(col)
1321+
values.append(value)
1322+
values = [value for value in product(*values)]
1323+
values = [value for value in zip(*values)]
1324+
return OrderedDict([(col, list(value))
1325+
for col, value in zip(cols, values)])
1326+
1327+
1328+
def test_column_product():
1329+
x = OrderedDict([('a', [1, 2, 3]), ('b', ['a', 'b'])])
1330+
y = OrderedDict([('a', [1, 1, 2, 2, 3, 3]),
1331+
('b', ['a', 'b', 'a', 'b', 'a', 'b'])])
1332+
x = _column_product(x)
1333+
assert x['a'] == y['a']
1334+
assert x['b'] == y['b']

0 commit comments

Comments
 (0)