Skip to content

Commit

Permalink
Fix stubtest mypy enum.Flag edge case (#15933)
Browse files Browse the repository at this point in the history
Fix edge-case stubtest crashes when an instance of an enum.Flag that is not a
member of that enum.Flag is used as a parameter default

Fixes #15923.

Note: the test cases I've added reproduce the crash, but only if you're
using a compiled version of mypy. (Some of them only repro the crash on
<=py310, but some repro it on py311+ as well.)

We run stubtest tests in CI with compiled mypy, so they do repro the
crash in the context of our CI.
  • Loading branch information
AlexWaygood authored Aug 23, 2023
1 parent 7141d6b commit 48835a3
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 4 deletions.
2 changes: 1 addition & 1 deletion mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,7 +1553,7 @@ def anytype() -> mypy.types.AnyType:
value: bool | int | str
if isinstance(runtime, bytes):
value = bytes_to_human_readable_repr(runtime)
elif isinstance(runtime, enum.Enum):
elif isinstance(runtime, enum.Enum) and isinstance(runtime.name, str):
value = runtime.name
elif isinstance(runtime, (bool, int, str)):
value = runtime
Expand Down
103 changes: 100 additions & 3 deletions mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(self, name: str) -> None: ...
class Coroutine(Generic[_T_co, _S, _R]): ...
class Iterable(Generic[_T_co]): ...
class Iterator(Iterable[_T_co]): ...
class Mapping(Generic[_K, _V]): ...
class Match(Generic[AnyStr]): ...
class Sequence(Iterable[_T_co]): ...
Expand All @@ -86,7 +87,9 @@ def __init__(self) -> None: pass
def __repr__(self) -> str: pass
class type: ...
class tuple(Sequence[T_co], Generic[T_co]): ...
class tuple(Sequence[T_co], Generic[T_co]):
def __ge__(self, __other: tuple[T_co, ...]) -> bool: pass
class dict(Mapping[KT, VT]): ...
class function: pass
Expand All @@ -105,6 +108,39 @@ def classmethod(f: T) -> T: ...
def staticmethod(f: T) -> T: ...
"""

stubtest_enum_stub = """
import sys
from typing import Any, TypeVar, Iterator
_T = TypeVar('_T')
class EnumMeta(type):
def __len__(self) -> int: pass
def __iter__(self: type[_T]) -> Iterator[_T]: pass
def __reversed__(self: type[_T]) -> Iterator[_T]: pass
def __getitem__(self: type[_T], name: str) -> _T: pass
class Enum(metaclass=EnumMeta):
def __new__(cls: type[_T], value: object) -> _T: pass
def __repr__(self) -> str: pass
def __str__(self) -> str: pass
def __format__(self, format_spec: str) -> str: pass
def __hash__(self) -> Any: pass
def __reduce_ex__(self, proto: Any) -> Any: pass
name: str
value: Any
class Flag(Enum):
def __or__(self: _T, other: _T) -> _T: pass
def __and__(self: _T, other: _T) -> _T: pass
def __xor__(self: _T, other: _T) -> _T: pass
def __invert__(self: _T) -> _T: pass
if sys.version_info >= (3, 11):
__ror__ = __or__
__rand__ = __and__
__rxor__ = __xor__
"""


def run_stubtest(
stub: str, runtime: str, options: list[str], config_file: str | None = None
Expand All @@ -114,6 +150,8 @@ def run_stubtest(
f.write(stubtest_builtins_stub)
with open("typing.pyi", "w") as f:
f.write(stubtest_typing_stub)
with open("enum.pyi", "w") as f:
f.write(stubtest_enum_stub)
with open(f"{TEST_MODULE_NAME}.pyi", "w") as f:
f.write(stub)
with open(f"{TEST_MODULE_NAME}.py", "w") as f:
Expand Down Expand Up @@ -954,23 +992,82 @@ def fizz(self): pass

@collect_cases
def test_enum(self) -> Iterator[Case]:
yield Case(stub="import enum", runtime="import enum", error=None)
yield Case(
stub="""
import enum
class X(enum.Enum):
a: int
b: str
c: str
""",
runtime="""
import enum
class X(enum.Enum):
a = 1
b = "asdf"
c = 2
""",
error="X.c",
)
yield Case(
stub="""
class Flags1(enum.Flag):
a: int
b: int
def foo(x: Flags1 = ...) -> None: ...
""",
runtime="""
class Flags1(enum.Flag):
a = 1
b = 2
def foo(x=Flags1.a|Flags1.b): pass
""",
error=None,
)
yield Case(
stub="""
class Flags2(enum.Flag):
a: int
b: int
def bar(x: Flags2 | None = None) -> None: ...
""",
runtime="""
class Flags2(enum.Flag):
a = 1
b = 2
def bar(x=Flags2.a|Flags2.b): pass
""",
error="bar",
)
yield Case(
stub="""
class Flags3(enum.Flag):
a: int
b: int
def baz(x: Flags3 | None = ...) -> None: ...
""",
runtime="""
class Flags3(enum.Flag):
a = 1
b = 2
def baz(x=Flags3(0)): pass
""",
error=None,
)
yield Case(
stub="""
class Flags4(enum.Flag):
a: int
b: int
def spam(x: Flags4 | None = None) -> None: ...
""",
runtime="""
class Flags4(enum.Flag):
a = 1
b = 2
def spam(x=Flags4(0)): pass
""",
error="spam",
)

@collect_cases
def test_decorator(self) -> Iterator[Case]:
Expand Down

0 comments on commit 48835a3

Please sign in to comment.