Skip to content

Commit

Permalink
Merge branch 'master' into subinterpreters
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm authored Jan 3, 2025
2 parents 1dc2499 + 8b7a535 commit 725d93b
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 100 deletions.
3 changes: 3 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
3.13 and later
- Added support for the ``copy()``, ``copy_into()``, ``move()`` and ``move_into()``
methods in ``anyio.Path``, available in Python 3.14
- Changed ``TaskGroup`` on asyncio to always spawn tasks non-eagerly, even if using a
task factory created via ``asyncio.create_eager_task_factory()``, to preserve expected
Trio-like task scheduling semantics (PR by @agronholm and @graingert)
- Configure ``SO_RCVBUF``, ``SO_SNDBUF`` and ``TCP_NODELAY`` on the selector
thread waker socket pair. This should improve the performance of ``wait_readable()``
and ``wait_writable()`` when using the ``ProactorEventLoop``
Expand Down
83 changes: 23 additions & 60 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
Collection,
Coroutine,
Iterable,
Iterator,
MutableMapping,
Sequence,
)
from concurrent.futures import Future
Expand All @@ -49,7 +47,7 @@
from signal import Signals
from socket import AddressFamily, SocketKind
from threading import Thread
from types import TracebackType
from types import CodeType, TracebackType
from typing import (
IO,
TYPE_CHECKING,
Expand Down Expand Up @@ -677,47 +675,7 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
self.cancel_scope = cancel_scope


class TaskStateStore(
MutableMapping["Coroutine[Any, Any, Any] | asyncio.Task", TaskState]
):
def __init__(self) -> None:
self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]()
self._preliminary_task_states: dict[Coroutine[Any, Any, Any], TaskState] = {}

def __getitem__(self, key: Coroutine[Any, Any, Any] | asyncio.Task, /) -> TaskState:
task = cast(asyncio.Task, key)
try:
return self._task_states[task]
except KeyError:
if coro := task.get_coro():
if state := self._preliminary_task_states.get(coro):
return state

raise KeyError(key)

def __setitem__(
self, key: asyncio.Task | Coroutine[Any, Any, Any], value: TaskState, /
) -> None:
if isinstance(key, Coroutine):
self._preliminary_task_states[key] = value
else:
self._task_states[key] = value

def __delitem__(self, key: asyncio.Task | Coroutine[Any, Any, Any], /) -> None:
if isinstance(key, Coroutine):
del self._preliminary_task_states[key]
else:
del self._task_states[key]

def __len__(self) -> int:
return len(self._task_states) + len(self._preliminary_task_states)

def __iter__(self) -> Iterator[Coroutine[Any, Any, Any] | asyncio.Task]:
yield from self._task_states
yield from self._preliminary_task_states


_task_states = TaskStateStore()
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()


#
Expand Down Expand Up @@ -763,6 +721,12 @@ def on_completion(task: asyncio.Task[object]) -> None:
tasks.pop().remove_done_callback(on_completion)


if sys.version_info >= (3, 12):
_eager_task_factory_code: CodeType | None = asyncio.eager_task_factory.__code__
else:
_eager_task_factory_code = None


class TaskGroup(abc.TaskGroup):
def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
Expand Down Expand Up @@ -837,7 +801,7 @@ def _spawn(
task_status_future: asyncio.Future | None = None,
) -> asyncio.Task:
def task_done(_task: asyncio.Task) -> None:
# task_state = _task_states[_task]
task_state = _task_states[_task]
assert task_state.cancel_scope is not None
assert _task in task_state.cancel_scope._tasks
task_state.cancel_scope._tasks.remove(_task)
Expand Down Expand Up @@ -894,26 +858,25 @@ def task_done(_task: asyncio.Task) -> None:
f"the return value ({coro!r}) is not a coroutine object"
)

# Make the spawned task inherit the task group's cancel scope
_task_states[coro] = task_state = TaskState(
parent_id=parent_id, cancel_scope=self.cancel_scope
)
name = get_callable_name(func) if name is None else str(name)
try:
loop = asyncio.get_running_loop()
if (
(factory := loop.get_task_factory())
and getattr(factory, "__code__", None) is _eager_task_factory_code
and (closure := getattr(factory, "__closure__", None))
):
custom_task_constructor = closure[0].cell_contents
task = custom_task_constructor(coro, loop=loop, name=name)
else:
task = create_task(coro, name=name)
finally:
del _task_states[coro]

_task_states[task] = task_state
# Make the spawned task inherit the task group's cancel scope
_task_states[task] = TaskState(
parent_id=parent_id, cancel_scope=self.cancel_scope
)
self.cancel_scope._tasks.add(task)
self._tasks.add(task)

if task.done():
# This can happen with eager task factories
task_done(task)
else:
task.add_done_callback(task_done)

task.add_done_callback(task_done)
return task

def start_soon(
Expand Down
36 changes: 23 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import ssl
import sys
from collections.abc import Generator
from ssl import SSLContext
from typing import Any
Expand All @@ -28,21 +29,30 @@

pytest_plugins = ["pytester"]


@pytest.fixture(
params=[
pytest.param(
("asyncio", {"debug": True, "loop_factory": None}),
id="asyncio",
),
asyncio_params = [
pytest.param(("asyncio", {"debug": True}), id="asyncio"),
pytest.param(
("asyncio", {"debug": True, "loop_factory": uvloop.new_event_loop}),
marks=uvloop_marks,
id="asyncio+uvloop",
),
]
if sys.version_info >= (3, 12):

def eager_task_loop_factory() -> asyncio.AbstractEventLoop:
loop = asyncio.new_event_loop()
loop.set_task_factory(asyncio.eager_task_factory)
return loop

asyncio_params.append(
pytest.param(
("asyncio", {"debug": True, "loop_factory": uvloop.new_event_loop}),
marks=uvloop_marks,
id="asyncio+uvloop",
("asyncio", {"debug": True, "loop_factory": eager_task_loop_factory}),
id="asyncio+eager",
),
pytest.param("trio"),
]
)
)


@pytest.fixture(params=[*asyncio_params, pytest.param("trio")])
def anyio_backend(request: SubRequest) -> tuple[str, dict[str, Any]]:
return request.param

Expand Down
4 changes: 3 additions & 1 deletion tests/streams/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
MemoryObjectSendStream,
)

from ..conftest import asyncio_params

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup

Expand Down Expand Up @@ -486,7 +488,7 @@ async def test_not_closed_warning() -> None:
gc.collect()


@pytest.mark.parametrize("anyio_backend", ["asyncio"], indirect=True)
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_send_to_natively_cancelled_receiver() -> None:
"""
Test that if a task waiting on receive.receive() is cancelled and then another
Expand Down
4 changes: 3 additions & 1 deletion tests/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
)
from anyio.abc import TaskStatus

from .conftest import asyncio_params

pytestmark = pytest.mark.anyio


Expand Down Expand Up @@ -127,7 +129,7 @@ def generator_part() -> Generator[object, BaseException, None]:
asyncio_event_loop.run_until_complete(native_coro_part())


@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_wait_all_tasks_blocked_asend(anyio_backend: str) -> None:
"""Test that wait_all_tasks_blocked() does not crash on an `asend()` object."""

Expand Down
4 changes: 3 additions & 1 deletion tests/test_from_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from anyio.from_thread import BlockingPortal, start_blocking_portal
from anyio.lowlevel import checkpoint

from .conftest import asyncio_params

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup

Expand Down Expand Up @@ -595,7 +597,7 @@ async def get_var() -> int:

assert propagated_value == 6

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_asyncio_run_sync_called(self, caplog: LogCaptureFixture) -> None:
"""Regression test for #357."""

Expand Down
13 changes: 10 additions & 3 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io
import os
import platform
import re
import socket
import sys
import tempfile
Expand Down Expand Up @@ -61,6 +62,8 @@
from anyio.lowlevel import checkpoint
from anyio.streams.stapled import MultiListener

from .conftest import asyncio_params

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup

Expand Down Expand Up @@ -488,7 +491,7 @@ def serve() -> None:
thread.join()
assert thread_exception is None

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_unretrieved_future_exception_server_crash(
self, family: AnyIPAddressFamily, caplog: LogCaptureFixture
) -> None:
Expand All @@ -497,7 +500,6 @@ async def test_unretrieved_future_exception_server_crash(
retrieved.
See https://github.com/encode/httpcore/issues/382 for details.
"""

def serve() -> None:
Expand All @@ -523,7 +525,12 @@ def serve() -> None:

thread.join()
gc.collect()
assert not caplog.text
caplog_text = "\n".join(
msg
for msg in caplog.messages
if not re.search("took [0-9.]+ seconds", msg)
)
assert not caplog_text


@pytest.mark.network
Expand Down
12 changes: 7 additions & 5 deletions tests/test_synchronization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
)
from anyio.abc import CapacityLimiter, TaskStatus

from .conftest import asyncio_params

pytestmark = pytest.mark.anyio


Expand Down Expand Up @@ -162,7 +164,7 @@ async def waiter() -> None:
assert not lock.statistics().locked
assert lock.statistics().tasks_waiting == 0

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_asyncio_deadlock(self) -> None:
"""Regression test for #398."""
lock = Lock()
Expand All @@ -178,7 +180,7 @@ async def acquire() -> None:
task1.cancel()
await asyncio.wait_for(task2, 1)

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_cancel_after_release(self) -> None:
"""
Test that a native asyncio cancellation will not cause a lock ownership
Expand Down Expand Up @@ -565,7 +567,7 @@ async def test_acquire_race(self) -> None:
semaphore.release()
pytest.raises(WouldBlock, semaphore.acquire_nowait)

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_asyncio_deadlock(self) -> None:
"""Regression test for #398."""
semaphore = Semaphore(1)
Expand All @@ -581,7 +583,7 @@ async def acquire() -> None:
task1.cancel()
await asyncio.wait_for(task2, 1)

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_cancel_after_release(self) -> None:
"""
Test that a native asyncio cancellation will not cause a semaphore ownership
Expand Down Expand Up @@ -731,7 +733,7 @@ async def waiter() -> None:
assert limiter.statistics().tasks_waiting == 0
assert limiter.statistics().borrowed_tokens == 0

@pytest.mark.parametrize("anyio_backend", ["asyncio"])
@pytest.mark.parametrize("anyio_backend", asyncio_params)
async def test_asyncio_deadlock(self) -> None:
"""Regression test for #398."""
limiter = CapacityLimiter(1)
Expand Down
Loading

0 comments on commit 725d93b

Please sign in to comment.