From a1beb2c4d5466f60ca008a8bde42f5610b38e4a3 Mon Sep 17 00:00:00 2001 From: Alessio Bogon <778703+youtux@users.noreply.github.com> Date: Sun, 21 Jan 2024 14:07:01 +0100 Subject: [PATCH] Use WeakKeyDictionary instead of storing private attribute `__scenario__` --- src/pytest_bdd/generation.py | 10 ++++++++-- src/pytest_bdd/scenario.py | 12 +++++++----- src/pytest_bdd/utils.py | 11 +++++++++++ tests/feature/test_description.py | 7 +++++-- 4 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/pytest_bdd/generation.py b/src/pytest_bdd/generation.py index aa24bf29..bc6d1cf2 100644 --- a/src/pytest_bdd/generation.py +++ b/src/pytest_bdd/generation.py @@ -9,7 +9,13 @@ from mako.lookup import TemplateLookup from .feature import get_features -from .scenario import inject_fixturedefs_for_step, make_python_docstring, make_python_name, make_string_literal +from .scenario import ( + inject_fixturedefs_for_step, + make_python_docstring, + make_python_name, + make_string_literal, + scenario_wrapper_template_registry, +) from .steps import get_step_fixture_name from .types import STEP_TYPES @@ -177,7 +183,7 @@ def _show_missing_code_main(config: Config, session: Session) -> None: features, scenarios, steps = parse_feature_files(config.option.features) for item in session.items: - if scenario := getattr(item.obj, "__scenario__", None): + if (scenario := scenario_wrapper_template_registry.get(item.obj)) is not None: if scenario in scenarios: scenarios.remove(scenario) for step in scenario.steps: diff --git a/src/pytest_bdd/scenario.py b/src/pytest_bdd/scenario.py index d3a066eb..926100be 100644 --- a/src/pytest_bdd/scenario.py +++ b/src/pytest_bdd/scenario.py @@ -17,6 +17,7 @@ import os import re from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, TypeVar, cast +from weakref import WeakKeyDictionary import pytest from _pytest.fixtures import FixtureDef, FixtureManager, FixtureRequest, call_fixture_func @@ -26,7 +27,7 @@ from . import exceptions from .feature import get_feature, get_features from .steps import StepFunctionContext, get_step_fixture_name, inject_fixture, step_function_context_registry -from .utils import CONFIG_STACK, get_args, get_caller_module_locals, get_caller_module_path +from .utils import CONFIG_STACK, get_args, get_caller_module_locals, get_caller_module_path, registry_get_safe if TYPE_CHECKING: from _pytest.mark.structures import ParameterSet @@ -42,6 +43,8 @@ PYTHON_REPLACE_REGEX = re.compile(r"\W") ALPHA_REGEX = re.compile(r"^\d+_*") +scenario_wrapper_template_registry: WeakKeyDictionary[Callable[..., Any], ScenarioTemplate] = WeakKeyDictionary() + def find_fixturedefs_for_step(step: Step, fixturemanager: FixtureManager, nodeid: str) -> Iterable[FixtureDef[Any]]: """Find the fixture defs that can parse a step.""" @@ -237,8 +240,7 @@ def scenario_wrapper(request: FixtureRequest, _pytest_bdd_example: dict[str, str scenario_wrapper.__doc__ = f"{feature_name}: {scenario_name}" - # TODO: Use a WeakKeyDictionary to store the scenario object instead of attaching it to the function - scenario_wrapper.__scenario__ = templated_scenario + scenario_wrapper_template_registry[scenario_wrapper] = templated_scenario return cast(Callable[P, T], scenario_wrapper) return decorator @@ -359,9 +361,9 @@ def scenarios(*feature_paths: str, **kwargs: Any) -> None: found = False module_scenarios = frozenset( - (attr.__scenario__.feature.filename, attr.__scenario__.name) + (s.feature.filename, s.name) for name, attr in caller_locals.items() - if hasattr(attr, "__scenario__") + if (s := registry_get_safe(scenario_wrapper_template_registry, attr)) is not None ) for feature in get_features(abs_feature_paths): diff --git a/src/pytest_bdd/utils.py b/src/pytest_bdd/utils.py index 067e8d81..52aeeb91 100644 --- a/src/pytest_bdd/utils.py +++ b/src/pytest_bdd/utils.py @@ -7,6 +7,7 @@ from inspect import getframeinfo, signature from sys import _getframe from typing import TYPE_CHECKING, TypeVar, cast +from weakref import WeakKeyDictionary if TYPE_CHECKING: from typing import Any, Callable @@ -82,3 +83,13 @@ def setdefault(obj: object, name: str, default: T) -> T: except AttributeError: setattr(obj, name, default) return default + + +def registry_get_safe(registry: WeakKeyDictionary[Any, T], key: Any, default=None) -> T | None: + """Get a value from a registry, or None if the key is not in the registry. + It ensures that this works even if the key cannot be weak-referenced (normally this would raise a TypeError). + """ + try: + return registry.get(key, default) + except TypeError: + return None diff --git a/tests/feature/test_description.py b/tests/feature/test_description.py index 5d0dcb96..c6f637fc 100644 --- a/tests/feature/test_description.py +++ b/tests/feature/test_description.py @@ -33,6 +33,7 @@ def test_description(pytester): """\ import textwrap from pytest_bdd import given, scenario + from pytest_bdd.scenario import scenario_wrapper_template_registry @scenario("description.feature", "Description") def test_description(): @@ -44,7 +45,8 @@ def _(): return "bar" def test_feature_description(): - assert test_description.__scenario__.feature.description == textwrap.dedent( + scenario = scenario_wrapper_template_registry[test_description] + assert scenario.feature.description == textwrap.dedent( \"\"\"\\ In order to achieve something I want something @@ -55,7 +57,8 @@ def test_feature_description(): ) def test_scenario_description(): - assert test_description.__scenario__.description == textwrap.dedent( + scenario = scenario_wrapper_template_registry[test_description] + assert scenario.description == textwrap.dedent( \"\"\"\\ Also, the scenario can have a description.