Skip to content

Commit d284f00

Browse files
authored
Merge pull request #221 from weka-io/add-ensure_same_defaults
Add `ensure_same_defaults`
2 parents 3153299 + 3ade4a9 commit d284f00

File tree

3 files changed

+107
-0
lines changed

3 files changed

+107
-0
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88

9+
### Added
10+
- `ensure_same_defaults` decorator for setting one function's defaults as source of truth for other function
11+
912
## [0.4.0] - 2019-11-14
1013

1114
### Changed

easypy/decorations.py

+50
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from functools import wraps, partial, update_wrapper
66
from operator import attrgetter
77
from abc import ABCMeta, abstractmethod
8+
import inspect
9+
10+
from easypy.exceptions import TException
811

912

1013
def parametrizeable_decorator(deco):
@@ -138,3 +141,50 @@ def foo(self):
138141
def wrapper(func):
139142
return LazyDecoratorDescriptor(decorator_factory, func, cached)
140143
return wrapper
144+
145+
146+
class DefaultsMismatch(TException):
147+
template = 'The defaults of {func} differ from those of {source_of_truth} in params {param_names}'
148+
149+
150+
def ensure_same_defaults(source_of_truth, ignore=()):
151+
"""
152+
Ensure the decorated function has the same default as the source of truth in optional parameters shared by both
153+
154+
:param source_of_truth: A function to check the defaults against.
155+
:param ignore: A list of parameters to ignore even if they exist and have defaults in both functions.
156+
:raises DefaultsMismatch: When the defaults are different.
157+
158+
>>> def foo(a=1, b=2, c=3):
159+
... ...
160+
>>> @ensure_same_defaults(foo)
161+
... def bar(a=1, b=2, c=3): # these defaults are verified by the decorator
162+
... ...
163+
"""
164+
165+
sot_signature = inspect.signature(source_of_truth)
166+
params_with_defaults = [
167+
param for param in sot_signature.parameters.values()
168+
if param.default is not param.empty
169+
and param.name not in ignore]
170+
171+
def gen_mismatches(func):
172+
signature = inspect.signature(func)
173+
for sot_param in params_with_defaults:
174+
param = signature.parameters.get(sot_param.name)
175+
if param is None:
176+
continue
177+
if param.default is param.empty:
178+
continue
179+
if sot_param.default != param.default:
180+
yield sot_param.name
181+
182+
def wrapper(func):
183+
mismatches = list(gen_mismatches(func))
184+
if mismatches:
185+
raise DefaultsMismatch(
186+
func=func,
187+
source_of_truth=source_of_truth,
188+
param_names=mismatches)
189+
return func
190+
return wrapper

tests/test_decorations.py

+54
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from functools import wraps
44

55
from easypy.decorations import lazy_decorator
6+
from easypy.decorations import ensure_same_defaults, DefaultsMismatch
67
from easypy.misc import kwargs_resilient
78

89

@@ -136,3 +137,56 @@ def counter(self):
136137
foo2.ts += 1
137138
assert [foo1.inc(), foo2.inc()] == [2, 2]
138139
assert [foo1.counter, foo2.counter] == [1, 2] # foo1 was not updated since last sync - only foo2
140+
141+
142+
def test_ensure_same_defaults():
143+
def foo(a=1, b=2, c=3):
144+
return a, b, c
145+
146+
147+
@ensure_same_defaults(foo)
148+
def bar(a=1, b=2, c=3):
149+
return a, b, c
150+
151+
# Test we did not change the actual function
152+
assert foo() == bar()
153+
assert foo(4, 5, 6) == bar(4, 5, 6)
154+
155+
156+
with pytest.raises(DefaultsMismatch) as exc:
157+
@ensure_same_defaults(foo)
158+
def baz(a=1, b=3, c=2):
159+
pass
160+
assert exc.value.param_names == ['b', 'c']
161+
162+
163+
def test_ensure_same_defaults_skipping_params_with_no_default():
164+
@ensure_same_defaults(lambda a=1, b=2: ...)
165+
def foo(a, b=2):
166+
pass
167+
168+
@ensure_same_defaults(lambda a, b=2: ...)
169+
def foo(a=1, b=2):
170+
pass
171+
172+
@ensure_same_defaults(lambda a, b=2: ...)
173+
def foo(a=1, c=3):
174+
pass
175+
176+
with pytest.raises(DefaultsMismatch) as exc:
177+
@ensure_same_defaults(lambda a, b=2: ...)
178+
def foo(a, b=4):
179+
pass
180+
assert exc.value.param_names == ['b']
181+
182+
183+
def test_ensure_same_defaults_ignore():
184+
@ensure_same_defaults(lambda a=1, b=2: ..., ignore=('b',))
185+
def foo(a=1, b=3):
186+
pass
187+
188+
with pytest.raises(DefaultsMismatch) as exc:
189+
@ensure_same_defaults(lambda a=1, b=2: ..., ignore=('b',))
190+
def foo(a=2, b=3):
191+
pass
192+
assert exc.value.param_names == ['a']

0 commit comments

Comments
 (0)