Changed TaskGroup to always spawn tasks lazily with eager task factories #853

Merged 1 commit on Jan 3, 2025
docs/versionhistory.rst (3 additions, 0 deletions)
@@ -7,6 +7,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

- Added support for the ``copy()``, ``copy_into()``, ``move()`` and ``move_into()``
methods in ``anyio.Path``, available in Python 3.14
- Changed ``TaskGroup`` on asyncio to always spawn tasks non-eagerly, even if using a
task factory created via ``asyncio.create_eager_task_factory()``, to preserve expected
Trio-like task scheduling semantics (PR by @agronholm and @graingert)
- Configure ``SO_RCVBUF``, ``SO_SNDBUF`` and ``TCP_NODELAY`` on the selector
thread waker socket pair. This should improve the performance of ``wait_readable()``
and ``wait_writable()`` when using the ``ProactorEventLoop``
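A minimal sketch of the behaviour described by the `TaskGroup` changelog entry above (illustrative code, not part of the PR; the coroutines and the `eager_loop_factory` helper are made up): even when the running loop uses `asyncio.eager_task_factory` (Python 3.12+), `start_soon()` now defers the child task until the parent yields control.

```python
import asyncio

import anyio
from anyio.lowlevel import checkpoint


async def child() -> None:
    print("child started")


async def main() -> None:
    async with anyio.create_task_group() as tg:
        tg.start_soon(child)
        print("child scheduled but not started yet")  # printed first
        await checkpoint()  # the parent yields; only now does child() run


def eager_loop_factory() -> asyncio.AbstractEventLoop:
    # Made-up helper: a loop whose task factory starts tasks eagerly (Python 3.12+)
    loop = asyncio.new_event_loop()
    loop.set_task_factory(asyncio.eager_task_factory)
    return loop


anyio.run(main, backend_options={"loop_factory": eager_loop_factory})
```

Before this change, the same program could print "child started" first, because the eager factory ran the child up to its first suspension point inside `start_soon()`.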
src/anyio/_backends/_asyncio.py (23 additions, 60 deletions)
@@ -28,8 +28,6 @@
Collection,
Coroutine,
Iterable,
Iterator,
MutableMapping,
Sequence,
)
from concurrent.futures import Future
@@ -49,7 +47,7 @@
from signal import Signals
from socket import AddressFamily, SocketKind
from threading import Thread
from types import TracebackType
from types import CodeType, TracebackType
from typing import (
IO,
TYPE_CHECKING,
@@ -677,47 +675,7 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
self.cancel_scope = cancel_scope


class TaskStateStore(
MutableMapping["Coroutine[Any, Any, Any] | asyncio.Task", TaskState]
):
def __init__(self) -> None:
self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]()
self._preliminary_task_states: dict[Coroutine[Any, Any, Any], TaskState] = {}

def __getitem__(self, key: Coroutine[Any, Any, Any] | asyncio.Task, /) -> TaskState:
task = cast(asyncio.Task, key)
try:
return self._task_states[task]
except KeyError:
if coro := task.get_coro():
if state := self._preliminary_task_states.get(coro):
return state

raise KeyError(key)

def __setitem__(
self, key: asyncio.Task | Coroutine[Any, Any, Any], value: TaskState, /
) -> None:
if isinstance(key, Coroutine):
self._preliminary_task_states[key] = value
else:
self._task_states[key] = value

def __delitem__(self, key: asyncio.Task | Coroutine[Any, Any, Any], /) -> None:
if isinstance(key, Coroutine):
del self._preliminary_task_states[key]
else:
del self._task_states[key]

def __len__(self) -> int:
return len(self._task_states) + len(self._preliminary_task_states)

def __iter__(self) -> Iterator[Coroutine[Any, Any, Any] | asyncio.Task]:
yield from self._task_states
yield from self._preliminary_task_states


_task_states = TaskStateStore()
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()


#
@@ -763,6 +721,12 @@ def on_completion(task: asyncio.Task[object]) -> None:
tasks.pop().remove_done_callback(on_completion)


if sys.version_info >= (3, 12):
_eager_task_factory_code: CodeType | None = asyncio.eager_task_factory.__code__
else:
_eager_task_factory_code = None


class TaskGroup(abc.TaskGroup):
def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
@@ -837,7 +801,7 @@ def _spawn(
task_status_future: asyncio.Future | None = None,
) -> asyncio.Task:
def task_done(_task: asyncio.Task) -> None:
# task_state = _task_states[_task]
task_state = _task_states[_task]
assert task_state.cancel_scope is not None
assert _task in task_state.cancel_scope._tasks
task_state.cancel_scope._tasks.remove(_task)
@@ -894,26 +858,25 @@ def task_done(_task: asyncio.Task) -> None:
f"the return value ({coro!r}) is not a coroutine object"
)

# Make the spawned task inherit the task group's cancel scope
_task_states[coro] = task_state = TaskState(
parent_id=parent_id, cancel_scope=self.cancel_scope
)
name = get_callable_name(func) if name is None else str(name)
try:
loop = asyncio.get_running_loop()
if (
(factory := loop.get_task_factory())
and getattr(factory, "__code__", None) is _eager_task_factory_code
and (closure := getattr(factory, "__closure__", None))
):
custom_task_constructor = closure[0].cell_contents
task = custom_task_constructor(coro, loop=loop, name=name)
else:
task = create_task(coro, name=name)
finally:
del _task_states[coro]

_task_states[task] = task_state
# Make the spawned task inherit the task group's cancel scope
_task_states[task] = TaskState(
parent_id=parent_id, cancel_scope=self.cancel_scope
)
self.cancel_scope._tasks.add(task)
self._tasks.add(task)

if task.done():
# This can happen with eager task factories
task_done(task)
else:
task.add_done_callback(task_done)

task.add_done_callback(task_done)
return task

def start_soon(
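The new branch in `_spawn()` above leans on a CPython implementation detail: `asyncio.create_eager_task_factory()` returns a closure over the task constructor it was given, and `asyncio.eager_task_factory` is built the same way, so every such factory shares one code object. A standalone sketch of that detection trick, assuming this closure layout holds (the `spawn_lazily` helper is made up):

```python
import asyncio
from collections.abc import Coroutine
from typing import Any

# Every factory produced by asyncio.create_eager_task_factory() shares this code object.
_EAGER_CODE = asyncio.eager_task_factory.__code__  # Python 3.12+


def spawn_lazily(
    coro: Coroutine[Any, Any, Any], name: str | None = None
) -> asyncio.Task:
    """Create a task that starts lazily even if the loop uses an eager task factory."""
    loop = asyncio.get_running_loop()
    factory = loop.get_task_factory()
    if (
        factory is not None
        and getattr(factory, "__code__", None) is _EAGER_CODE
        and (closure := getattr(factory, "__closure__", None))
    ):
        # The wrapped task constructor is the factory's only free variable; calling it
        # directly, without eager_start=True, produces a lazily started task.
        task_constructor = closure[0].cell_contents
        return task_constructor(coro, loop=loop, name=name)

    # No eager factory installed: create_task() already schedules lazily.
    return loop.create_task(coro, name=name)
```

Calling the unwrapped constructor without `eager_start=True` is what restores lazy, Trio-like scheduling while still honoring a custom task class.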
tests/conftest.py (23 additions, 13 deletions)
@@ -2,6 +2,7 @@

import asyncio
import ssl
import sys
from collections.abc import Generator
from ssl import SSLContext
from typing import Any
@@ -28,21 +29,30 @@

pytest_plugins = ["pytester"]


@pytest.fixture(
params=[
pytest.param(
("asyncio", {"debug": True, "loop_factory": None}),
id="asyncio",
),
asyncio_params = [
pytest.param(("asyncio", {"debug": True}), id="asyncio"),
pytest.param(
("asyncio", {"debug": True, "loop_factory": uvloop.new_event_loop}),
marks=uvloop_marks,
id="asyncio+uvloop",
),
]
if sys.version_info >= (3, 12):

def eager_task_loop_factory() -> asyncio.AbstractEventLoop:
loop = asyncio.new_event_loop()
loop.set_task_factory(asyncio.eager_task_factory)
return loop

asyncio_params.append(
pytest.param(
("asyncio", {"debug": True, "loop_factory": uvloop.new_event_loop}),
marks=uvloop_marks,
id="asyncio+uvloop",
("asyncio", {"debug": True, "loop_factory": eager_task_loop_factory}),
id="asyncio+eager",
),
pytest.param("trio"),
]
)
)


@pytest.fixture(params=[*asyncio_params, pytest.param("trio")])
def anyio_backend(request: SubRequest) -> tuple[str, dict[str, Any]]:
return request.param

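For context on why the new `asyncio+eager` parametrization is worth running separately: under `asyncio.eager_task_factory`, plain `asyncio.create_task()` runs the coroutine synchronously up to its first suspension point, which is exactly the scheduling difference the task group change has to neutralize. A small standard-library-only illustration (the demo coroutines are made up):

```python
import asyncio


def eager_task_loop_factory() -> asyncio.AbstractEventLoop:
    # Same shape as the fixture helper added to conftest.py above
    loop = asyncio.new_event_loop()
    loop.set_task_factory(asyncio.eager_task_factory)  # Python 3.12+
    return loop


async def child() -> None:
    print("child ran up to its first await")
    await asyncio.sleep(0)


async def main() -> None:
    # Under the eager factory, child() executes during create_task(), so its print()
    # comes first; under the default lazy factory, "back in main" would come first.
    task = asyncio.create_task(child())
    print("back in main")
    await task


with asyncio.Runner(loop_factory=eager_task_loop_factory) as runner:
    runner.run(main())
```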
tests/streams/test_memory.py (3 additions, 1 deletion)
@@ -24,6 +24,8 @@
MemoryObjectSendStream,
)

from ..conftest import asyncio_params

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup

@@ -486,7 +488,7 @@ async def test_not_closed_warning() -> None:
gc.collect()


@pytest.mark.parametrize("anyio_backend", ["asyncio"], indirect=True)
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_send_to_natively_cancelled_receiver() -> None:
"""
Test that if a task waiting on receive.receive() is cancelled and then another
tests/test_debugging.py (3 additions, 1 deletion)
@@ -19,6 +19,8 @@
)
from anyio.abc import TaskStatus

from .conftest import asyncio_params

pytestmark = pytest.mark.anyio


@@ -127,7 +129,7 @@ def generator_part() -> Generator[object, BaseException, None]:
asyncio_event_loop.run_until_complete(native_coro_part())


@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_wait_all_tasks_blocked_asend(anyio_backend: str) -> None:
"""Test that wait_all_tasks_blocked() does not crash on an `asend()` object."""

tests/test_from_thread.py (3 additions, 1 deletion)
@@ -33,6 +33,8 @@
from anyio.from_thread import BlockingPortal, start_blocking_portal
from anyio.lowlevel import checkpoint

from .conftest import asyncio_params

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup

@@ -595,7 +597,7 @@ async def get_var() -> int:

assert propagated_value == 6

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_asyncio_run_sync_called(self, caplog: LogCaptureFixture) -> None:
"""Regression test for #357."""

tests/test_sockets.py (15 additions, 4 deletions)
@@ -5,6 +5,7 @@
import io
import os
import platform
import re
import socket
import sys
import tempfile
@@ -61,6 +62,8 @@
from anyio.lowlevel import checkpoint
from anyio.streams.stapled import MultiListener

from .conftest import asyncio_params

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup

@@ -488,16 +491,19 @@ def serve() -> None:
thread.join()
assert thread_exception is None

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.fixture
def gc_collect(self) -> None:
gc.collect()

@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_unretrieved_future_exception_server_crash(
self, family: AnyIPAddressFamily, caplog: LogCaptureFixture
self, family: AnyIPAddressFamily, caplog: LogCaptureFixture, gc_collect: None
) -> None:
"""
Test that there won't be any leftover Futures that don't get their exceptions
retrieved.

See https://github.com/encode/httpcore/issues/382 for details.

"""

def serve() -> None:
@@ -523,7 +529,12 @@ def serve() -> None:

thread.join()
gc.collect()
assert not caplog.text
caplog_text = "\n".join(
msg
for msg in caplog.messages
if not re.search("took [0-9.]+ seconds", msg)
)
assert not caplog_text


@pytest.mark.network
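The messages filtered out above ("took N.NNN seconds") match asyncio's debug-mode warning for callbacks or task steps that run longer than `loop.slow_callback_duration` (0.1 s by default); dropping them keeps the final `assert not caplog_text` focused on unretrieved-exception logs rather than timing noise. A small sketch that provokes such a warning (the lowered threshold is only there to trigger it quickly):

```python
import asyncio
import time


async def blocking_step() -> None:
    # Blocking the loop past slow_callback_duration makes debug mode log e.g.
    # "Executing <Task ...> took 0.050 seconds".
    time.sleep(0.05)


async def main() -> None:
    asyncio.get_running_loop().slow_callback_duration = 0.01  # lowered for the demo
    await blocking_step()


asyncio.run(main(), debug=True)
```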
tests/test_synchronization.py (7 additions, 5 deletions)
@@ -20,6 +20,8 @@
)
from anyio.abc import CapacityLimiter, TaskStatus

from .conftest import asyncio_params

pytestmark = pytest.mark.anyio


@@ -162,7 +164,7 @@ async def waiter() -> None:
assert not lock.statistics().locked
assert lock.statistics().tasks_waiting == 0

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_asyncio_deadlock(self) -> None:
"""Regression test for #398."""
lock = Lock()
@@ -178,7 +180,7 @@ async def acquire() -> None:
task1.cancel()
await asyncio.wait_for(task2, 1)

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_cancel_after_release(self) -> None:
"""
Test that a native asyncio cancellation will not cause a lock ownership
@@ -565,7 +567,7 @@ async def test_acquire_race(self) -> None:
semaphore.release()
pytest.raises(WouldBlock, semaphore.acquire_nowait)

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_asyncio_deadlock(self) -> None:
"""Regression test for #398."""
semaphore = Semaphore(1)
@@ -581,7 +583,7 @@ async def acquire() -> None:
task1.cancel()
await asyncio.wait_for(task2, 1)

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_cancel_after_release(self) -> None:
"""
Test that a native asyncio cancellation will not cause a semaphore ownership
@@ -731,7 +733,7 @@ async def waiter() -> None:
assert limiter.statistics().tasks_waiting == 0
assert limiter.statistics().borrowed_tokens == 0

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_asyncio_deadlock(self) -> None:
"""Regression test for #398."""
limiter = CapacityLimiter(1)