Skip to content

Commit

Permalink
Merge pull request #9 from llllllllll/wrap-and-unwrap
Browse files Browse the repository at this point in the history
ENH: py2 compat wrappers
  • Loading branch information
Scott Sanderson authored Dec 6, 2017
2 parents 5760456 + b460a9b commit ea9a207
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 7 deletions.
40 changes: 37 additions & 3 deletions interface/compat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import sys
from itertools import repeat

Expand All @@ -6,17 +7,50 @@
PY2 = version_info.major == 2
PY3 = version_info.major == 3

if PY2: # pragma: nocover
if PY2: # pragma: nocover-py3
from funcsigs import signature, Parameter

@functools.wraps(functools.wraps)
def wraps(func, *args, **kwargs):
outer_decorator = functools.wraps(func, *args, **kwargs)

def decorator(f):
wrapped = outer_decorator(f)
wrapped.__wrapped__ = func
return wrapped

return decorator

def raise_from(e, from_):
raise e

def viewkeys(d):
return d.viewkeys()

else: # pragma: nocover
from inspect import signature, Parameter
def unwrap(func, stop=None):
# NOTE: implementation is taken from CPython/Lib/inspect.py, Python 3.6
if stop is None:
def _is_wrapper(f):
return hasattr(f, '__wrapped__')
else:
def _is_wrapper(f):
return hasattr(f, '__wrapped__') and not stop(f)
f = func # remember the original func for error reporting
memo = {id(f)} # Memoise by id to tolerate non-hashable objects
while _is_wrapper(func):
func = func.__wrapped__
id_func = id(func)
if id_func in memo:
raise ValueError('wrapper loop when unwrapping {!r}'.format(f))
memo.add(id_func)
return func


else: # pragma: nocover-py2
from inspect import signature, Parameter, unwrap

wraps = functools.wraps

exec("def raise_from(e, from_):" # pragma: nocover
" raise e from from_")

Expand Down
50 changes: 49 additions & 1 deletion interface/tests/test_interface.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from textwrap import dedent

from ..compat import PY3
from ..compat import PY3, wraps
from ..interface import implements, InvalidImplementation, Interface, default


Expand Down Expand Up @@ -672,3 +672,51 @@ def default_classmethod(cls, x):
Consider changing the implementation of default_method or making these attributes part of HasDefault.""" # noqa
assert second == expected_second


def test_wrapped_implementation():
class I(Interface): # pragma: nocover
def f(self, a, b, c):
pass

def wrapping_decorator(f):
@wraps(f)
def inner(*args, **kwargs): # pragma: nocover
pass

return inner

class C(implements(I)): # pragma: nocover
@wrapping_decorator
def f(self, a, b, c):
pass


def test_wrapped_implementation_incompatible():
class I(Interface): # pragma: nocover
def f(self, a, b, c):
pass

def wrapping_decorator(f):
@wraps(f)
def inner(*args, **kwargs): # pragma: nocover
pass

return inner

with pytest.raises(InvalidImplementation) as e:
class C(implements(I)): # pragma: nocover
@wrapping_decorator
def f(self, a, b): # missing ``c``
pass

actual_message = str(e.value)
expected_message = dedent(
"""
class C failed to implement interface I:
The following methods of I were implemented with invalid signatures:
- f(self, a, b) != f(self, a, b, c)"""
)

assert actual_message == expected_message
13 changes: 13 additions & 0 deletions interface/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from ..utils import is_a, unique

from ..compat import wraps, unwrap


def test_unique():
assert list(unique(iter([1, 3, 1, 2, 3]))) == [1, 3, 2]
Expand All @@ -8,3 +10,14 @@ def test_unique():
def test_is_a():
assert is_a(int)(5)
assert not is_a(str)(5)


def test_wrap_and_unwrap():
def f(a, b, c): # pragma: nocover
pass

@wraps(f)
def g(*args): # pragma: nocover
pass

assert unwrap(g) is f
10 changes: 7 additions & 3 deletions interface/typed_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""
import types

from .compat import signature
from .compat import signature, unwrap
from .default import default


Expand Down Expand Up @@ -48,7 +48,7 @@ def __str__(self):
BUILTIN_FUNCTION_TYPES = (types.FunctionType, types.BuiltinFunctionType)


def extract_func(obj):
def _inner_extract_func(obj):
if isinstance(obj, BUILTIN_FUNCTION_TYPES):
# Fast path, since this is the most likely case.
return obj
Expand All @@ -57,6 +57,10 @@ def extract_func(obj):
elif isinstance(obj, property):
return obj.fget
elif isinstance(obj, default):
return extract_func(obj.implementation)
return _inner_extract_func(obj.implementation)
else:
return obj


def extract_func(obj):
return unwrap(_inner_extract_func(obj))

0 comments on commit ea9a207

Please sign in to comment.