Skip to content

Commit 8bfbafd

Browse files
committed
Clean the EvalEnvironment before pickling to removing patsy's stateful transforms which have different names/qualnames from expected.
1 parent 14e62a7 commit 8bfbafd

11 files changed

+30
-15
lines changed

patsy/eval.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def __hash__(self):
262262
tuple(self._namespace_ids())))
263263

264264
def __getstate__(self):
265+
self.clean()
265266
namespaces = self._namespaces
266267
namespaces = _replace_un_pickleable(namespaces)
267268
return (0, namespaces, self.flags)
@@ -272,6 +273,17 @@ def __setstate__(self, pickle):
272273
self.flags = flags
273274
self._namespaces = _return_un_pickleable(namespaces)
274275

276+
def clean(self):
277+
"""The EvalEnvironment doesn't need the stateful transformation
278+
functions once the design matrix has been built. This will delete
279+
it. Called by __getstate__ to prepare for pickling."""
280+
namespaces = []
281+
for namespace in self._namespaces:
282+
ns = {key: namespace[key] for key in six.iterkeys(namespace) if not
283+
hasattr(namespace[key], '__patsy_stateful_transform__')}
284+
namespaces.append(ns)
285+
self._namespaces = namespaces
286+
275287

276288
class ObjectHolder(object):
277289
def __init__(self, kind, module, name):
@@ -489,7 +501,23 @@ def test_EvalEnvironment_eq():
489501
capture_local_env = lambda: EvalEnvironment.capture(0)
490502
env3 = capture_local_env()
491503
env4 = capture_local_env()
492-
assert env3 != env4 # This fails...
504+
assert env3 != env4
505+
506+
507+
def test_EvalEnvironment_clean():
508+
from patsy.state import center, standardize
509+
from patsy.splines import bs
510+
511+
env1 = EvalEnvironment([{'center': center}])
512+
env2 = EvalEnvironment([{'standardize': standardize}])
513+
env3 = EvalEnvironment([{'bs': bs}])
514+
env1.clean()
515+
env2.clean()
516+
env3.clean()
517+
518+
env1._namespaces == [{}]
519+
env2._namespaces == [{}]
520+
env3._namespaces == [{}]
493521

494522
_builtins_dict = {}
495523
six.exec_("from patsy.builtins import *", {}, _builtins_dict)
@@ -650,6 +678,7 @@ def __setstate__(self, pickle):
650678
self.code = pickle['code']
651679
self.origin = pickle['origin']
652680

681+
653682
def test_EvalFactor_pickle_saves_origin():
654683
from patsy.util import assert_pickled_equals
655684
# The pickling tests use object equality before and after pickling

patsy/mgcv_cubic_splines.py

-6
Original file line numberDiff line numberDiff line change
@@ -730,8 +730,6 @@ def __setstate__(self, pickle):
730730

731731

732732
cr = stateful_transform(CR)
733-
cr.__qualname__ = 'cr'
734-
cr.__name__ = 'cr'
735733

736734

737735
class CC(CubicRegressionSpline):
@@ -774,8 +772,6 @@ def __setstate__(self, pickle):
774772

775773

776774
cc = stateful_transform(CC)
777-
cc.__qualname__ = 'cc'
778-
cc.__name__ = 'cc'
779775

780776

781777
def test_crs_errors():
@@ -978,8 +974,6 @@ def __setstate__(self, pickle):
978974

979975

980976
te = stateful_transform(TE)
981-
te.__qualname__ = 'te'
982-
te.__name__ = 'te'
983977

984978

985979
def test_te_errors():

patsy/splines.py

-2
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,6 @@ def __setstate__(self, pickle):
257257

258258

259259
bs = stateful_transform(BS)
260-
bs.__qualname__ = 'bs'
261-
bs.__name__ = 'bs'
262260

263261

264262
def test_bs_compat():

patsy/state.py

-6
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,6 @@ def __setstate__(self, pickle):
126126

127127

128128
center = stateful_transform(Center)
129-
center.__qualname__ = 'center'
130-
center.__name__ = 'center'
131129

132130
# See:
133131
# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#On-line_algorithm
@@ -196,9 +194,5 @@ def __setstate__(self, pickle):
196194

197195

198196
standardize = stateful_transform(Standardize)
199-
standardize.__qualname__ = 'standardize'
200-
standardize.__name__ = 'standardize'
201197
# R compatibility:
202198
scale = standardize
203-
scale.__qualname__ = 'scale'
204-
scale.__name__ = 'scale'
-29 Bytes
Binary file not shown.
-37 Bytes
Binary file not shown.
Binary file not shown.
-40 Bytes
Binary file not shown.
-27 Bytes
Binary file not shown.
Binary file not shown.
-37 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)