Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run all sync/async fixture/test code under the same contextvars.Context #618

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
225 changes: 225 additions & 0 deletions src/anyio/_run_in_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
"""
These utility functions exist to support anyio pytest plugin's context management

Do *not* use them outside of it, for the following reasons:

* context_wrap/context_wrap_async are expected to be used only on respectively
synchronous/asynchronous fixture and test functions, and they are not robust
in the face of unusual ways in which such functions can theoretically be written;
* The wrapping of coroutines by context_wrap_async is likely to have a noticeable
overhead compared to the way that asyncio and trio integrate Context objects in
their library APIs;
"""

from __future__ import annotations

from functools import wraps
from inspect import isasyncgenfunction, isgeneratorfunction
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Awaitable,
Callable,
Coroutine,
Generator,
Protocol,
TypeVar,
overload,
)

_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)
_V_co = TypeVar("_V_co", covariant=True)

if TYPE_CHECKING:
from typing_extensions import ParamSpec

_P = ParamSpec("_P")


class ContextLike(Protocol):
def run(
self, func: Callable[_P, _V_co], /, *args: _P.args, **kwargs: _P.kwargs
) -> _V_co:
raise NotImplementedError


class GeneratorWrapper(Generator[_T_co, _T_contra, _V_co]):
_context: ContextLike
_wrapped: Generator[_T_co, _T_contra, _V_co]

def __init__(
self, context: ContextLike, wrapped: Generator[_T_co, _T_contra, _V_co], /
) -> None:
self._context = context
self._wrapped = wrapped

def send(self, value: _T_contra, /) -> _T_co:
return self._context.run(self._wrapped.send, value)

@overload
def throw(
self,
typ: type[BaseException],
val: object = ...,
tb: TracebackType | None = ...,
/,
) -> _T_co:
...

@overload
def throw(
self,
val: BaseException,
_: None = ...,
tb: TracebackType | None = ...,
/,
) -> _T_co:
...

def throw(
self,
tv: type[BaseException] | BaseException,
v: object = None,
tb: TracebackType | None = None,
/,
) -> _T_co:
if isinstance(tv, BaseException):
return self._context.run(self._wrapped.throw, tv)
else:
return self._context.run(self._wrapped.throw, tv, v, tb)


class AwaitableWrapper(Awaitable[_V_co]):
_context: ContextLike
_wrapped: Awaitable[_V_co]

def __init__(self, context: ContextLike, wrapped: Awaitable[_V_co], /) -> None:
self._context = context
self._wrapped = wrapped

def __await__(self) -> Generator[Any, None, _V_co]:
generator = self._context.run(self._wrapped.__await__)
return GeneratorWrapper(self._context, generator)


class AsyncGeneratorWrapper(AsyncGenerator[_T_co, _T_contra]):
_context: ContextLike
_wrapped: AsyncGenerator[_T_co, _T_contra]

def __init__(
self, context: ContextLike, wrapped: AsyncGenerator[_T_co, _T_contra], /
) -> None:
self._context = context
self._wrapped = wrapped

def asend(self, value: _T_contra, /) -> Awaitable[_T_co]:
awaitable = self._context.run(self._wrapped.asend, value)
return AwaitableWrapper(self._context, awaitable)

@overload
def athrow(
self,
typ: type[BaseException],
val: object = ...,
tb: TracebackType | None = ...,
/,
) -> Awaitable[_T_co]:
...

@overload
def athrow(
self,
val: BaseException,
_: None = ...,
tb: TracebackType | None = ...,
/,
) -> Awaitable[_T_co]:
...

def athrow(
self,
tv: type[BaseException] | BaseException,
v: object = None,
tb: TracebackType | None = None,
/,
) -> Awaitable[_T_co]:
if isinstance(tv, BaseException):
awaitable = self._context.run(self._wrapped.athrow, tv)
else:
awaitable = self._context.run(self._wrapped.athrow, tv, v, tb)
return AwaitableWrapper(self._context, awaitable)


@overload
def context_wrap(
context: ContextLike, func: Callable[_P, Generator[_T_co, _T_contra, _V_co]], /
) -> Callable[_P, Generator[_T_co, _T_contra, _V_co]]:
...


@overload
def context_wrap(
context: ContextLike, func: Callable[_P, _V_co], /
) -> Callable[_P, _V_co]:
...


def context_wrap(context: ContextLike, func: Callable[_P, Any], /) -> Callable[_P, Any]:
if isgeneratorfunction(func):

@wraps(func)
def generator_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Generator:
generator = context.run(func, *args, **kwargs)
result = yield from GeneratorWrapper(context, generator)
return result

return generator_wrapper

else:

@wraps(func)
def func_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
result = context.run(func, *args, **kwargs)
return result

return func_wrapper


@overload
def context_wrap_async(
context: ContextLike, func: Callable[_P, AsyncGenerator[_T_co, _T_contra]], /
) -> Callable[_P, AsyncGenerator[_T_co, _T_contra]]:
...


@overload
def context_wrap_async(
context: ContextLike, func: Callable[_P, Awaitable[_V_co]], /
) -> Callable[_P, Coroutine[_T_co, _T_contra, _V_co]]:
...


def context_wrap_async(
context: ContextLike, func: Callable[_P, Any], /
) -> Callable[_P, Any]:
if isasyncgenfunction(func):

@wraps(func)
def asyncgen_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> AsyncGenerator:
asyncgen = context.run(func, *args, **kwargs)
return AsyncGeneratorWrapper(context, asyncgen)

return asyncgen_wrapper

else:

@wraps(func)
async def coroutine_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
coro = context.run(func, *args, **kwargs)
result = await AwaitableWrapper(context, coro)
return result

return coroutine_wrapper
85 changes: 81 additions & 4 deletions src/anyio/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,72 @@

from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import Context, ContextVar, copy_context
from inspect import isasyncgenfunction, iscoroutinefunction
from typing import Any, Dict, Tuple, cast

import pytest
import sniffio
from _pytest.stash import StashKey

from ._core._eventloop import get_all_backends, get_async_backend
from ._run_in_context import ContextLike, context_wrap, context_wrap_async
from .abc import TestRunner

_current_runner: TestRunner | None = None
_current_reentrancy_token = ContextVar[object]("anyio.pytest_plugin.reentrancy_token")
contextvars_context_key: StashKey[Context] = StashKey()
_test_context_like_key: StashKey[ContextLike] = StashKey()


class _TestContext(ContextLike):
"""Manages reentrancy and transmission of sniffio.current_async_library_cvar"""

def __init__(self, context: Context):
self._context = context
self._reentrancy_token = object()

def _is_already_in_context(self) -> bool:
# if context var is not set to the token, we are in another context
if _current_reentrancy_token.get(None) is not self._reentrancy_token:
return False

# Token value is the same, but we may be in a copy of self._context
test_value = object()
reset_reentrancy = _current_reentrancy_token.set(test_value)
try:
return self._context[_current_reentrancy_token] is test_value
finally:
_current_reentrancy_token.reset(reset_reentrancy)

def run(self, func: Any, /, *args: Any, **kwargs: Any) -> Any:
if self._is_already_in_context():
return func(*args, **kwargs)

return self._context.run(
self._set_context_and_run,
sniffio.current_async_library_cvar.get(None),
func,
*args,
**kwargs,
)

def _set_context_and_run(
self, current_async_library: str | None, func: Any, /, *args: Any, **kwargs: Any
) -> Any:
reset_reentrancy = _current_reentrancy_token.set(self._reentrancy_token)
reset_sniffio = None
if current_async_library is not None:
reset_sniffio = sniffio.current_async_library_cvar.set(
current_async_library
)

try:
return func(*args, **kwargs)
finally:
_current_reentrancy_token.reset(reset_reentrancy)
if reset_sniffio is not None:
sniffio.current_async_library_cvar.reset(reset_sniffio)


def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]:
Expand Down Expand Up @@ -59,27 +115,41 @@ def pytest_configure(config: Any) -> None:
)


def pytest_sessionstart(session: Any) -> None:
context = copy_context()
session.stash[contextvars_context_key] = context
session.stash[_test_context_like_key] = _TestContext(context)


def pytest_fixture_setup(fixturedef: Any, request: Any) -> None:
context_like: ContextLike = request.session.stash[_test_context_like_key]

def wrapper(*args, anyio_backend, **kwargs): # type: ignore[no-untyped-def]
backend_name, backend_options = extract_backend_and_options(anyio_backend)
if has_backend_arg:
kwargs["anyio_backend"] = anyio_backend

with get_runner(backend_name, backend_options) as runner:
context_wrapped = context_wrap_async(context_like, func)
if isasyncgenfunction(func):
yield from runner.run_asyncgen_fixture(func, kwargs)
yield from runner.run_asyncgen_fixture(context_wrapped, kwargs)
else:
yield runner.run_fixture(func, kwargs)
yield runner.run_fixture(context_wrapped, kwargs)

# Only apply this to coroutine functions and async generator functions in requests
# that involve the anyio_backend fixture
func = fixturedef.func
if isasyncgenfunction(func) or iscoroutinefunction(func):
if "anyio_backend" in request.fixturenames:
has_backend_arg = "anyio_backend" in fixturedef.argnames
setattr(wrapper, "_runs_in_session_context", True)
fixturedef.func = wrapper
if not has_backend_arg:
fixturedef.argnames += ("anyio_backend",)
elif not getattr(func, "_runs_in_session_context", False):
wrapper = context_wrap(context_like, func)
setattr(wrapper, "_runs_in_session_context", True)
fixturedef.func = wrapper


@pytest.hookimpl(tryfirst=True)
Expand All @@ -95,9 +165,12 @@ def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None:

@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None:
context_like: ContextLike = pyfuncitem.session.stash[_test_context_like_key]

def run_with_hypothesis(**kwargs: Any) -> None:
with get_runner(backend_name, backend_options) as runner:
runner.run_test(original_func, kwargs)
context_wrapped = context_wrap_async(context_like, original_func)
runner.run_test(context_wrapped, kwargs)

backend = pyfuncitem.funcargs.get("anyio_backend")
if backend:
Expand All @@ -116,10 +189,14 @@ def run_with_hypothesis(**kwargs: Any) -> None:
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
with get_runner(backend_name, backend_options) as runner:
runner.run_test(pyfuncitem.obj, testargs)
context_wrapped = context_wrap_async(context_like, pyfuncitem.obj)
runner.run_test(context_wrapped, testargs)

return True

if not iscoroutinefunction(pyfuncitem.obj):
pyfuncitem.obj = context_wrap(context_like, pyfuncitem.obj)

return None


Expand Down
Loading