Skip to content

Commit ea9a207

Browse files
author
Scott Sanderson
authored
Merge pull request #9 from llllllllll/wrap-and-unwrap
ENH: py2 compat wrappers
2 parents 5760456 + b460a9b commit ea9a207

File tree

4 files changed

+106
-7
lines changed

4 files changed

+106
-7
lines changed

interface/compat.py

+37-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
import sys
23
from itertools import repeat
34

@@ -6,17 +7,50 @@
67
PY2 = version_info.major == 2
78
PY3 = version_info.major == 3
89

9-
if PY2: # pragma: nocover
10+
if PY2: # pragma: nocover-py3
1011
from funcsigs import signature, Parameter
1112

13+
@functools.wraps(functools.wraps)
14+
def wraps(func, *args, **kwargs):
15+
outer_decorator = functools.wraps(func, *args, **kwargs)
16+
17+
def decorator(f):
18+
wrapped = outer_decorator(f)
19+
wrapped.__wrapped__ = func
20+
return wrapped
21+
22+
return decorator
23+
1224
def raise_from(e, from_):
1325
raise e
1426

1527
def viewkeys(d):
1628
return d.viewkeys()
1729

18-
else: # pragma: nocover
19-
from inspect import signature, Parameter
30+
def unwrap(func, stop=None):
31+
# NOTE: implementation is taken from CPython/Lib/inspect.py, Python 3.6
32+
if stop is None:
33+
def _is_wrapper(f):
34+
return hasattr(f, '__wrapped__')
35+
else:
36+
def _is_wrapper(f):
37+
return hasattr(f, '__wrapped__') and not stop(f)
38+
f = func # remember the original func for error reporting
39+
memo = {id(f)} # Memoise by id to tolerate non-hashable objects
40+
while _is_wrapper(func):
41+
func = func.__wrapped__
42+
id_func = id(func)
43+
if id_func in memo:
44+
raise ValueError('wrapper loop when unwrapping {!r}'.format(f))
45+
memo.add(id_func)
46+
return func
47+
48+
49+
else: # pragma: nocover-py2
50+
from inspect import signature, Parameter, unwrap
51+
52+
wraps = functools.wraps
53+
2054
exec("def raise_from(e, from_):" # pragma: nocover
2155
" raise e from from_")
2256

interface/tests/test_interface.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
from textwrap import dedent
33

4-
from ..compat import PY3
4+
from ..compat import PY3, wraps
55
from ..interface import implements, InvalidImplementation, Interface, default
66

77

@@ -672,3 +672,51 @@ def default_classmethod(cls, x):
672672
673673
Consider changing the implementation of default_method or making these attributes part of HasDefault.""" # noqa
674674
assert second == expected_second
675+
676+
677+
def test_wrapped_implementation():
678+
class I(Interface): # pragma: nocover
679+
def f(self, a, b, c):
680+
pass
681+
682+
def wrapping_decorator(f):
683+
@wraps(f)
684+
def inner(*args, **kwargs): # pragma: nocover
685+
pass
686+
687+
return inner
688+
689+
class C(implements(I)): # pragma: nocover
690+
@wrapping_decorator
691+
def f(self, a, b, c):
692+
pass
693+
694+
695+
def test_wrapped_implementation_incompatible():
696+
class I(Interface): # pragma: nocover
697+
def f(self, a, b, c):
698+
pass
699+
700+
def wrapping_decorator(f):
701+
@wraps(f)
702+
def inner(*args, **kwargs): # pragma: nocover
703+
pass
704+
705+
return inner
706+
707+
with pytest.raises(InvalidImplementation) as e:
708+
class C(implements(I)): # pragma: nocover
709+
@wrapping_decorator
710+
def f(self, a, b): # missing ``c``
711+
pass
712+
713+
actual_message = str(e.value)
714+
expected_message = dedent(
715+
"""
716+
class C failed to implement interface I:
717+
718+
The following methods of I were implemented with invalid signatures:
719+
- f(self, a, b) != f(self, a, b, c)"""
720+
)
721+
722+
assert actual_message == expected_message

interface/tests/test_utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from ..utils import is_a, unique
22

3+
from ..compat import wraps, unwrap
4+
35

46
def test_unique():
57
assert list(unique(iter([1, 3, 1, 2, 3]))) == [1, 3, 2]
@@ -8,3 +10,14 @@ def test_unique():
810
def test_is_a():
911
assert is_a(int)(5)
1012
assert not is_a(str)(5)
13+
14+
15+
def test_wrap_and_unwrap():
16+
def f(a, b, c): # pragma: nocover
17+
pass
18+
19+
@wraps(f)
20+
def g(*args): # pragma: nocover
21+
pass
22+
23+
assert unwrap(g) is f

interface/typed_signature.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77
import types
88

9-
from .compat import signature
9+
from .compat import signature, unwrap
1010
from .default import default
1111

1212

@@ -48,7 +48,7 @@ def __str__(self):
4848
BUILTIN_FUNCTION_TYPES = (types.FunctionType, types.BuiltinFunctionType)
4949

5050

51-
def extract_func(obj):
51+
def _inner_extract_func(obj):
5252
if isinstance(obj, BUILTIN_FUNCTION_TYPES):
5353
# Fast path, since this is the most likely case.
5454
return obj
@@ -57,6 +57,10 @@ def extract_func(obj):
5757
elif isinstance(obj, property):
5858
return obj.fget
5959
elif isinstance(obj, default):
60-
return extract_func(obj.implementation)
60+
return _inner_extract_func(obj.implementation)
6161
else:
6262
return obj
63+
64+
65+
def extract_func(obj):
66+
return unwrap(_inner_extract_func(obj))

0 commit comments

Comments
 (0)