|
36 | 36 | from patsy.constraint import linear_constraint
|
37 | 37 | from patsy.contrasts import ContrastMatrix
|
38 | 38 | from patsy.desc import ModelDesc, Term
|
| 39 | +from collections import OrderedDict |
39 | 40 |
|
40 | 41 | class FactorInfo(object):
|
41 | 42 | """A FactorInfo object is a simple class that provides some metadata about
|
@@ -684,6 +685,49 @@ def var_names(self, eval_env=0):
|
684 | 685 | else:
|
685 | 686 | return {}
|
686 | 687 |
|
def partial(self, columns, product=False):
    """Return a partial design matrix built only from the given variables.

    Only the terms whose variables are all present in ``columns`` are
    transformed per this :class:`DesignInfo`'s transformations. Every
    other term (and the intercept) is returned as columns of zeros, so the
    result always has one block of columns per term in
    ``self.term_codings``.

    This is useful to perform a partial prediction on unseen data and to
    view marginal differences in factors.

    :arg columns: A dict with the keys as the column names for the marginal
      predictions desired and values as the marginal values to be
      predicted. All values must have the same length.

    :arg product: When ``True``, the returned numpy array represents the
      Cartesian product of the values in ``columns``.

    :returns: A numpy array of the partial design matrix.

    :raises ValueError: If ``columns`` is empty, or if its values do not
      all have the same length.
    """
    if product:
        columns = _column_product(columns)

    rows = None
    for name in columns:
        length = len(columns[name])
        # Compare against None explicitly: a plain truthiness test would
        # treat a zero-length column as "no length seen yet" and let a
        # later column of a different length slip through.
        if rows is not None and rows != length:
            raise ValueError('all columns must be of same length')
        rows = length
    if rows is None:
        # Without at least one column the number of output rows is
        # undefined; fail clearly here instead of inside np.zeros.
        raise ValueError('columns must contain at least one variable')

    # Imported inside the function rather than at module level (presumably
    # to avoid a circular import with .highlevel), and only after the
    # input has been validated so that bad input fails first.
    from .highlevel import dmatrix

    parts = []
    for term, subterms in self.term_codings.items():
        present = all(var in columns for var in term.var_names())
        # A term with no variables is vacuously "present", so the
        # intercept is excluded by name and falls through to the zero
        # block below.
        if present and term.name() != 'Intercept':
            # Subset to this single term; the leading "0 +" keeps an
            # intercept column out of the sub-design.
            di = self.subset('0 + {}'.format(term.name()))
            parts.append(dmatrix(di, columns))
        else:
            # Builtin sum: passing a generator to np.sum is deprecated.
            num_columns = sum(s.num_columns for s in subterms)
            parts.append(np.zeros((rows, num_columns)))
    return np.hstack(parts)
| 730 | + |
687 | 731 | @classmethod
|
688 | 732 | def from_array(cls, array_like, default_column_prefix="column"):
|
689 | 733 | """Find or construct a DesignInfo appropriate for a given array_like.
|
@@ -1230,3 +1274,61 @@ def test_design_matrix():
|
1230 | 1274 | repr(DesignMatrix(np.zeros((1, 0))))
|
1231 | 1275 | repr(DesignMatrix(np.zeros((0, 1))))
|
1232 | 1276 | repr(DesignMatrix(np.zeros((0, 0))))
|
| 1277 | + |
| 1278 | + |
def test_DesignInfo_partial():
    """Exercise DesignInfo.partial on additive and interaction designs.

    Checks a single supplied variable, two supplied variables, the
    Cartesian-product mode, an ``a * b`` interaction design, and the
    ValueError raised for columns of unequal length.
    """
    from numpy.testing import assert_allclose
    from .highlevel import dmatrix

    # These three locals are referenced by name from the formula strings
    # below, so they must keep exactly these names.
    a = np.array(['a', 'b', 'a', 'b', 'a', 'a', 'b', 'a'])
    b = np.array([1, 3, 2, 4, 1, 3, 1, 1])
    c = np.array([4, 3, 2, 1, 6, 4, 2, 1])
    dm = dmatrix('a + bs(b, df=3, degree=3) + np.log(c)')

    # Only the categorical variable supplied: every other term stays zero.
    expected = np.zeros((3, 6))
    expected[1, 1] = 1
    actual = dm.design_info.partial({'a': ['a', 'b', 'a']})
    assert_allclose(expected, actual)

    # Categorical variable plus the log-transformed variable.
    expected = np.zeros((2, 6))
    expected[1, 1] = 1
    expected[1, 5] = np.log(3)
    actual = dm.design_info.partial({'a': ['a', 'b'], 'c': [1, 3]})
    assert_allclose(expected, actual)

    # Same inputs expanded to their Cartesian product (2 x 2 = 4 rows).
    expected = np.zeros((4, 6))
    expected[[2, 3], 1] = 1
    expected[[1, 3], 5] = np.log(3)
    actual = dm.design_info.partial({'a': ['a', 'b'], 'c': [1, 3]},
                                    product=True)
    assert_allclose(expected, actual)

    # Interaction design: main effects and the a:b column are all filled.
    dm = dmatrix('a * b')
    expected = np.array([[0, 0, 1, 0], [0, 1, 3, 3]])
    actual = dm.design_info.partial({'a': ['a', 'b'], 'b': [1, 3]})
    assert_allclose(expected, actual)

    # Columns of differing lengths must be rejected.
    from nose.tools import assert_raises
    mismatched = {'a': ['a', 'b'], 'b': [1, 2, 3]}
    assert_raises(ValueError, dm.design_info.partial, mismatched)
| 1313 | + |
| 1314 | + |
| 1315 | +def _column_product(columns): |
| 1316 | + from itertools import product |
| 1317 | + cols = [] |
| 1318 | + values = [] |
| 1319 | + for col, value in six.iteritems(columns): |
| 1320 | + cols.append(col) |
| 1321 | + values.append(value) |
| 1322 | + values = [value for value in product(*values)] |
| 1323 | + values = [value for value in zip(*values)] |
| 1324 | + return OrderedDict([(col, list(value)) |
| 1325 | + for col, value in zip(cols, values)]) |
| 1326 | + |
| 1327 | + |
def test_column_product():
    """Check that _column_product expands two columns into all pairings."""
    source = OrderedDict([('a', [1, 2, 3]), ('b', ['a', 'b'])])
    expected = OrderedDict([('a', [1, 1, 2, 2, 3, 3]),
                            ('b', ['a', 'b', 'a', 'b', 'a', 'b'])])
    result = _column_product(source)
    # Compare column by column; the first column varies slowest.
    for key in ('a', 'b'):
        assert result[key] == expected[key]
0 commit comments