From cd87193e52c9e8c42312dc34eac43ae68760d1c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Fri, 3 Jan 2025 00:28:51 +0200 Subject: [PATCH] Changed TaskGroup to always spawn tasks lazily, even with eager task factories --- docs/versionhistory.rst | 3 ++ src/anyio/_backends/_asyncio.py | 83 +++++++++------------------------ tests/conftest.py | 36 ++++++++------ tests/streams/test_memory.py | 4 +- tests/test_debugging.py | 4 +- tests/test_from_thread.py | 4 +- tests/test_sockets.py | 19 ++++++-- tests/test_synchronization.py | 12 +++-- tests/test_taskgroups.py | 50 ++++++++++++++------ tests/test_to_thread.py | 4 +- 10 files changed, 118 insertions(+), 101 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 35fd7b8a..588eebbf 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -7,6 +7,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_. - Added support for the ``copy()``, ``copy_into()``, ``move()`` and ``move_into()`` methods in ``anyio.Path``, available in Python 3.14 +- Changed ``TaskGroup`` on asyncio to always spawn tasks non-eagerly, even if using a + task factory created via ``asyncio.create_eager_task_factory()``, to preserve expected + Trio-like task scheduling semantics (PR by @agronholm and @graingert) - Configure ``SO_RCVBUF``, ``SO_SNDBUF`` and ``TCP_NODELAY`` on the selector thread waker socket pair. 
This should improve the performance of ``wait_readable()`` and ``wait_writable()`` when using the ``ProactorEventLoop`` diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 11582529..5122d7c5 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -28,8 +28,6 @@ Collection, Coroutine, Iterable, - Iterator, - MutableMapping, Sequence, ) from concurrent.futures import Future @@ -49,7 +47,7 @@ from signal import Signals from socket import AddressFamily, SocketKind from threading import Thread -from types import TracebackType +from types import CodeType, TracebackType from typing import ( IO, TYPE_CHECKING, @@ -677,47 +675,7 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None): self.cancel_scope = cancel_scope -class TaskStateStore( - MutableMapping["Coroutine[Any, Any, Any] | asyncio.Task", TaskState] -): - def __init__(self) -> None: - self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]() - self._preliminary_task_states: dict[Coroutine[Any, Any, Any], TaskState] = {} - - def __getitem__(self, key: Coroutine[Any, Any, Any] | asyncio.Task, /) -> TaskState: - task = cast(asyncio.Task, key) - try: - return self._task_states[task] - except KeyError: - if coro := task.get_coro(): - if state := self._preliminary_task_states.get(coro): - return state - - raise KeyError(key) - - def __setitem__( - self, key: asyncio.Task | Coroutine[Any, Any, Any], value: TaskState, / - ) -> None: - if isinstance(key, Coroutine): - self._preliminary_task_states[key] = value - else: - self._task_states[key] = value - - def __delitem__(self, key: asyncio.Task | Coroutine[Any, Any, Any], /) -> None: - if isinstance(key, Coroutine): - del self._preliminary_task_states[key] - else: - del self._task_states[key] - - def __len__(self) -> int: - return len(self._task_states) + len(self._preliminary_task_states) - - def __iter__(self) -> Iterator[Coroutine[Any, Any, Any] | asyncio.Task]: - yield from 
self._task_states - yield from self._preliminary_task_states - - -_task_states = TaskStateStore() +_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary() # @@ -763,6 +721,12 @@ def on_completion(task: asyncio.Task[object]) -> None: tasks.pop().remove_done_callback(on_completion) +if sys.version_info >= (3, 12): + _eager_task_factory_code: CodeType | None = asyncio.eager_task_factory.__code__ +else: + _eager_task_factory_code = None + + class TaskGroup(abc.TaskGroup): def __init__(self) -> None: self.cancel_scope: CancelScope = CancelScope() @@ -837,7 +801,7 @@ def _spawn( task_status_future: asyncio.Future | None = None, ) -> asyncio.Task: def task_done(_task: asyncio.Task) -> None: - # task_state = _task_states[_task] + task_state = _task_states[_task] assert task_state.cancel_scope is not None assert _task in task_state.cancel_scope._tasks task_state.cancel_scope._tasks.remove(_task) @@ -894,26 +858,25 @@ def task_done(_task: asyncio.Task) -> None: f"the return value ({coro!r}) is not a coroutine object" ) - # Make the spawned task inherit the task group's cancel scope - _task_states[coro] = task_state = TaskState( - parent_id=parent_id, cancel_scope=self.cancel_scope - ) name = get_callable_name(func) if name is None else str(name) - try: + loop = asyncio.get_running_loop() + if ( + (factory := loop.get_task_factory()) + and getattr(factory, "__code__", None) is _eager_task_factory_code + and (closure := getattr(factory, "__closure__", None)) + ): + custom_task_constructor = closure[0].cell_contents + task = custom_task_constructor(coro, loop=loop, name=name) + else: task = create_task(coro, name=name) - finally: - del _task_states[coro] - _task_states[task] = task_state + # Make the spawned task inherit the task group's cancel scope + _task_states[task] = TaskState( + parent_id=parent_id, cancel_scope=self.cancel_scope + ) self.cancel_scope._tasks.add(task) self._tasks.add(task) - - if task.done(): - # This can happen with eager task 
factories - task_done(task) - else: - task.add_done_callback(task_done) - + task.add_done_callback(task_done) return task def start_soon( diff --git a/tests/conftest.py b/tests/conftest.py index 52998044..c29b9e80 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import asyncio import ssl +import sys from collections.abc import Generator from ssl import SSLContext from typing import Any @@ -28,21 +29,30 @@ pytest_plugins = ["pytester"] - -@pytest.fixture( - params=[ - pytest.param( - ("asyncio", {"debug": True, "loop_factory": None}), - id="asyncio", - ), +asyncio_params = [ + pytest.param(("asyncio", {"debug": True}), id="asyncio"), + pytest.param( + ("asyncio", {"debug": True, "loop_factory": uvloop.new_event_loop}), + marks=uvloop_marks, + id="asyncio+uvloop", + ), +] +if sys.version_info >= (3, 12): + + def eager_task_loop_factory() -> asyncio.AbstractEventLoop: + loop = asyncio.new_event_loop() + loop.set_task_factory(asyncio.eager_task_factory) + return loop + + asyncio_params.append( pytest.param( - ("asyncio", {"debug": True, "loop_factory": uvloop.new_event_loop}), - marks=uvloop_marks, - id="asyncio+uvloop", + ("asyncio", {"debug": True, "loop_factory": eager_task_loop_factory}), + id="asyncio+eager", ), - pytest.param("trio"), - ] -) + ) + + +@pytest.fixture(params=[*asyncio_params, pytest.param("trio")]) def anyio_backend(request: SubRequest) -> tuple[str, dict[str, Any]]: return request.param diff --git a/tests/streams/test_memory.py b/tests/streams/test_memory.py index 0e6d022a..4a4adbdd 100644 --- a/tests/streams/test_memory.py +++ b/tests/streams/test_memory.py @@ -24,6 +24,8 @@ MemoryObjectSendStream, ) +from ..conftest import asyncio_params + if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup @@ -486,7 +488,7 @@ async def test_not_closed_warning() -> None: gc.collect() -@pytest.mark.parametrize("anyio_backend", ["asyncio"], indirect=True) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async 
def test_send_to_natively_cancelled_receiver() -> None: """ Test that if a task waiting on receive.receive() is cancelled and then another diff --git a/tests/test_debugging.py b/tests/test_debugging.py index 72843988..7813eaac 100644 --- a/tests/test_debugging.py +++ b/tests/test_debugging.py @@ -19,6 +19,8 @@ ) from anyio.abc import TaskStatus +from .conftest import asyncio_params + pytestmark = pytest.mark.anyio @@ -127,7 +129,7 @@ def generator_part() -> Generator[object, BaseException, None]: asyncio_event_loop.run_until_complete(native_coro_part()) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_wait_all_tasks_blocked_asend(anyio_backend: str) -> None: """Test that wait_all_tasks_blocked() does not crash on an `asend()` object.""" diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index e4c29ce0..009edd0f 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -33,6 +33,8 @@ from anyio.from_thread import BlockingPortal, start_blocking_portal from anyio.lowlevel import checkpoint +from .conftest import asyncio_params + if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup @@ -595,7 +597,7 @@ async def get_var() -> int: assert propagated_value == 6 - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_asyncio_run_sync_called(self, caplog: LogCaptureFixture) -> None: """Regression test for #357.""" diff --git a/tests/test_sockets.py b/tests/test_sockets.py index b5143df0..f07bae50 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -5,6 +5,7 @@ import io import os import platform +import re import socket import sys import tempfile @@ -61,6 +62,8 @@ from anyio.lowlevel import checkpoint from anyio.streams.stapled import MultiListener +from .conftest import asyncio_params + if sys.version_info < (3, 11): from exceptiongroup import 
ExceptionGroup @@ -488,16 +491,19 @@ def serve() -> None: thread.join() assert thread_exception is None - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.fixture + def gc_collect(self) -> None: + gc.collect() + + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_unretrieved_future_exception_server_crash( - self, family: AnyIPAddressFamily, caplog: LogCaptureFixture + self, family: AnyIPAddressFamily, caplog: LogCaptureFixture, gc_collect: None ) -> None: """ Test that there won't be any leftover Futures that don't get their exceptions retrieved. See https://github.com/encode/httpcore/issues/382 for details. - """ def serve() -> None: @@ -523,7 +529,12 @@ def serve() -> None: thread.join() gc.collect() - assert not caplog.text + caplog_text = "\n".join( + msg + for msg in caplog.messages + if not re.search("took [0-9.]+ seconds", msg) + ) + assert not caplog_text @pytest.mark.network diff --git a/tests/test_synchronization.py b/tests/test_synchronization.py index 83758c62..92a7a5a2 100644 --- a/tests/test_synchronization.py +++ b/tests/test_synchronization.py @@ -20,6 +20,8 @@ ) from anyio.abc import CapacityLimiter, TaskStatus +from .conftest import asyncio_params + pytestmark = pytest.mark.anyio @@ -162,7 +164,7 @@ async def waiter() -> None: assert not lock.statistics().locked assert lock.statistics().tasks_waiting == 0 - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_asyncio_deadlock(self) -> None: """Regression test for #398.""" lock = Lock() @@ -178,7 +180,7 @@ async def acquire() -> None: task1.cancel() await asyncio.wait_for(task2, 1) - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_after_release(self) -> None: """ Test that a native asyncio cancellation will not cause a lock ownership @@ -565,7 +567,7 @@ async def test_acquire_race(self) -> None: 
semaphore.release() pytest.raises(WouldBlock, semaphore.acquire_nowait) - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_asyncio_deadlock(self) -> None: """Regression test for #398.""" semaphore = Semaphore(1) @@ -581,7 +583,7 @@ async def acquire() -> None: task1.cancel() await asyncio.wait_for(task2, 1) - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_after_release(self) -> None: """ Test that a native asyncio cancellation will not cause a semaphore ownership @@ -731,7 +733,7 @@ async def waiter() -> None: assert limiter.statistics().tasks_waiting == 0 assert limiter.statistics().borrowed_tokens == 0 - @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + @pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_asyncio_deadlock(self) -> None: """Regression test for #398.""" limiter = CapacityLimiter(1) diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index 1c1a654c..52f96600 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -32,6 +32,8 @@ from anyio.abc import TaskGroup, TaskStatus from anyio.lowlevel import checkpoint +from .conftest import asyncio_params + if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup, ExceptionGroup @@ -200,7 +202,7 @@ async def taskfunc(*, task_status: TaskStatus) -> None: assert not finished -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_start_native_host_cancelled() -> None: started = finished = False @@ -224,7 +226,7 @@ async def start_another() -> None: assert not finished -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_start_native_child_cancelled() -> None: task = None finished = False @@ -248,7 +250,7 @@ async def 
start_another() -> None: assert not finished -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_propagate_native_cancellation_from_taskgroup() -> None: async def taskfunc() -> None: async with create_task_group() as tg: @@ -261,7 +263,7 @@ async def taskfunc() -> None: await task -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_with_nested_task_groups() -> None: """Regression test for #695.""" @@ -703,7 +705,7 @@ async def test_shielded_cleanup_after_cancel() -> None: assert get_current_task().has_pending_cancellation() -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cleanup_after_native_cancel() -> None: """Regression test for #832.""" # See also https://github.com/python/cpython/pull/102815. @@ -803,7 +805,7 @@ async def outer_task() -> None: assert outer_task_ran -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_host_asyncgen() -> None: done = False @@ -1159,7 +1161,7 @@ def generator_part() -> Generator[object, BaseException, None]: @pytest.mark.filterwarnings( 'ignore:"@coroutine" decorator is deprecated:DeprecationWarning' ) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_schedule_old_style_coroutine_func() -> None: """ Test that we give a sensible error when a user tries to spawn a task from a @@ -1182,7 +1184,7 @@ def corofunc() -> Generator[Any, Any, None]: tg.start_soon(corofunc) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_native_future_tasks() -> None: async def wait_native_future() -> None: loop = asyncio.get_running_loop() @@ -1193,7 +1195,7 @@ async 
def wait_native_future() -> None: tg.cancel_scope.cancel() -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_native_future_tasks_cancel_scope() -> None: async def wait_native_future() -> None: with anyio.CancelScope(): @@ -1205,7 +1207,7 @@ async def wait_native_future() -> None: tg.cancel_scope.cancel() -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_cancel_completed_task() -> None: loop = asyncio.get_running_loop() old_exception_handler = loop.get_exception_handler() @@ -1301,7 +1303,7 @@ async def test_cancelscope_exit_before_enter() -> None: @pytest.mark.parametrize( - "anyio_backend", ["asyncio"] + "anyio_backend", asyncio_params ) # trio does not check for this yet async def test_cancelscope_exit_in_wrong_task() -> None: async def enter_scope(scope: CancelScope) -> None: @@ -1416,7 +1418,7 @@ async def starter_task() -> None: sys.version_info < (3, 11), reason="Task uncancelling is only supported on Python 3.11", ) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) class TestUncancel: async def test_uncancel_after_native_cancel(self) -> None: task = cast(asyncio.Task, asyncio.current_task()) @@ -1792,23 +1794,41 @@ async def typetest_optional_status( reason="Eager task factories require Python 3.12", ) @pytest.mark.parametrize("anyio_backend", ["asyncio"]) -async def test_eager_task_factory(request: FixtureRequest) -> None: +@pytest.mark.parametrize("use_custom_eager_factory", [False, True]) +async def test_eager_task_factory( + request: FixtureRequest, use_custom_eager_factory: bool +) -> None: + ran = False + async def sync_coro() -> None: + nonlocal ran + ran = True + # This should trigger fetching the task state with CancelScope(): # noqa: ASYNC100 pass + def create_custom_task( + coro: Coroutine[Any, Any, Any], /, **kwargs: 
Any + ) -> asyncio.Task[Any]: + return asyncio.Task(coro, **kwargs) + loop = asyncio.get_running_loop() old_task_factory = loop.get_task_factory() - loop.set_task_factory(asyncio.eager_task_factory) + if use_custom_eager_factory: + loop.set_task_factory(asyncio.create_eager_task_factory(create_custom_task)) + else: + loop.set_task_factory(asyncio.eager_task_factory) + request.addfinalizer(lambda: loop.set_task_factory(old_task_factory)) async with create_task_group() as tg: tg.start_soon(sync_coro) + assert not ran tg.cancel_scope.cancel() -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_patched_asyncio_task(monkeypatch: MonkeyPatch) -> None: monkeypatch.setattr( asyncio, diff --git a/tests/test_to_thread.py b/tests/test_to_thread.py index 9b80de2d..caffa275 100644 --- a/tests/test_to_thread.py +++ b/tests/test_to_thread.py @@ -23,6 +23,8 @@ ) from anyio.from_thread import BlockingPortalProvider +from .conftest import asyncio_params + pytestmark = pytest.mark.anyio @@ -159,7 +161,7 @@ async def test_asynclib_detection() -> None: await to_thread.run_sync(sniffio.current_async_library) -@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +@pytest.mark.parametrize("anyio_backend", asyncio_params) async def test_asyncio_cancel_native_task() -> None: task: asyncio.Task[None] | None = None