From 0c8ad519e136c4b136b51333863966a445080a97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 5 Sep 2024 09:34:20 +0300 Subject: [PATCH] Delegated the implementations of Lock and Semaphore to the async backend class (#761) Also added the `fast_acquire` parameter for Lock and Semaphore. --- docs/synchronization.rst | 13 ++ docs/versionhistory.rst | 5 + src/anyio/_backends/_asyncio.py | 181 +++++++++++++++++++- src/anyio/_backends/_trio.py | 128 +++++++++++++- src/anyio/_core/_synchronization.py | 247 ++++++++++++++++++---------- src/anyio/abc/_eventloop.py | 18 +- tests/test_synchronization.py | 120 +++++++++++++- 7 files changed, 618 insertions(+), 94 deletions(-) diff --git a/docs/synchronization.rst b/docs/synchronization.rst index 064f99c1..5b4f6d81 100644 --- a/docs/synchronization.rst +++ b/docs/synchronization.rst @@ -66,6 +66,13 @@ Example:: run(main) +.. tip:: If the performance of semaphores is critical for you, you could pass + ``fast_acquire=True`` to :class:`Semaphore`. This has the effect of skipping the + :func:`~.lowlevel.cancel_shielded_checkpoint` call in :meth:`Semaphore.acquire` if + there is no contention (acquisition succeeds immediately). This could, in some cases, + lead to the task never yielding control back to to the event loop if you use the + semaphore in a loop that does not have other yield points. + Locks ----- @@ -92,6 +99,12 @@ Example:: run(main) +.. tip:: If the performance of locks is critical for you, you could pass + ``fast_acquire=True`` to :class:`Lock`. This has the effect of skipping the + :func:`~.lowlevel.cancel_shielded_checkpoint` call in :meth:`Lock.acquire` if there + is no contention (acquisition succeeds immediately). This could, in some cases, lead + to the task never yielding control back to to the event loop if use the lock in a + loop that does not have other yield points. Conditions ---------- diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 2eb72bc7..b2d87857 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -5,6 +5,11 @@ This library adheres to `Semantic Versioning 2.0 `_. **UNRELEASED** +- Improved the performance of ``anyio.Lock`` and ``anyio.Semaphore`` on asyncio (even up + to 50 %) +- Added the ``fast_acquire`` parameter to ``anyio.Lock`` and ``anyio.Semaphore`` to + further boost performance at the expense of safety (``acquire()`` will not yield + control back if there is no contention) - Fixed ``__repr__()`` of ``MemoryObjectItemReceiver``, when ``item`` is not defined (`#767 `_; PR by @Danipulok) - Added support for the ``from_uri()``, ``full_match()``, ``parser`` methods/properties diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index b67f8e22..b88b9cc1 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -58,7 +58,13 @@ import sniffio -from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc +from .. import ( + CapacityLimiterStatistics, + EventStatistics, + LockStatistics, + TaskInfo, + abc, +) from .._core._eventloop import claim_worker_thread, threadlocals from .._core._exceptions import ( BrokenResourceError, @@ -70,9 +76,16 @@ ) from .._core._sockets import convert_ipv6_sockaddr from .._core._streams import create_memory_object_stream -from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter +from .._core._synchronization import ( + CapacityLimiter as BaseCapacityLimiter, +) from .._core._synchronization import Event as BaseEvent -from .._core._synchronization import ResourceGuard +from .._core._synchronization import Lock as BaseLock +from .._core._synchronization import ( + ResourceGuard, + SemaphoreStatistics, +) +from .._core._synchronization import Semaphore as BaseSemaphore from .._core._tasks import CancelScope as BaseCancelScope from ..abc import ( AsyncBackend, @@ -1658,6 +1671,154 @@ def statistics(self) -> EventStatistics: return EventStatistics(len(self._event._waiters)) +class Lock(BaseLock): + def __new__(cls, *, fast_acquire: bool = False) -> Lock: + return object.__new__(cls) + + def __init__(self, *, fast_acquire: bool = False) -> None: + self._fast_acquire = fast_acquire + self._owner_task: asyncio.Task | None = None + self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque() + + async def acquire(self) -> None: + if self._owner_task is None and not self._waiters: + await AsyncIOBackend.checkpoint_if_cancelled() + self._owner_task = current_task() + + # Unless on the "fast path", yield control of the event loop so that other + # tasks can run too + if not self._fast_acquire: + try: + await AsyncIOBackend.cancel_shielded_checkpoint() + except CancelledError: + self.release() + raise + + return + + task = cast(asyncio.Task, current_task()) + fut: asyncio.Future[None] = asyncio.Future() + item = task, fut + self._waiters.append(item) + try: + await fut + except CancelledError: + self._waiters.remove(item) + if self._owner_task is task: + self.release() + + raise + + self._waiters.remove(item) + + def acquire_nowait(self) -> None: + if self._owner_task is None and not self._waiters: + self._owner_task = current_task() + return + + raise WouldBlock + + def locked(self) -> bool: + return self._owner_task is not None + + def release(self) -> None: + if self._owner_task != current_task(): + raise RuntimeError("The current task is not holding this lock") + + for task, fut in self._waiters: + if not fut.cancelled(): + self._owner_task = task + fut.set_result(None) + return + + self._owner_task = None + + def statistics(self) -> LockStatistics: + task_info = AsyncIOTaskInfo(self._owner_task) if self._owner_task else None + return LockStatistics(self.locked(), task_info, len(self._waiters)) + + +class Semaphore(BaseSemaphore): + def __new__( + cls, + initial_value: int, + *, + max_value: int | None = None, + fast_acquire: bool = False, + ) -> Semaphore: + return object.__new__(cls) + + def __init__( + self, + initial_value: int, + *, + max_value: int | None = None, + fast_acquire: bool = False, + ): + super().__init__(initial_value, max_value=max_value) + self._value = initial_value + self._max_value = max_value + self._fast_acquire = fast_acquire + self._waiters: deque[asyncio.Future[None]] = deque() + + async def acquire(self) -> None: + if self._value > 0 and not self._waiters: + await AsyncIOBackend.checkpoint_if_cancelled() + self._value -= 1 + + # Unless on the "fast path", yield control of the event loop so that other + # tasks can run too + if not self._fast_acquire: + try: + await AsyncIOBackend.cancel_shielded_checkpoint() + except CancelledError: + self.release() + raise + + return + + fut: asyncio.Future[None] = asyncio.Future() + self._waiters.append(fut) + try: + await fut + except CancelledError: + try: + self._waiters.remove(fut) + except ValueError: + self.release() + + raise + + def acquire_nowait(self) -> None: + if self._value == 0: + raise WouldBlock + + self._value -= 1 + + def release(self) -> None: + if self._max_value is not None and self._value == self._max_value: + raise ValueError("semaphore released too many times") + + for fut in self._waiters: + if not fut.cancelled(): + fut.set_result(None) + self._waiters.remove(fut) + return + + self._value += 1 + + @property + def value(self) -> int: + return self._value + + @property + def max_value(self) -> int | None: + return self._max_value + + def statistics(self) -> SemaphoreStatistics: + return SemaphoreStatistics(len(self._waiters)) + + class CapacityLimiter(BaseCapacityLimiter): _total_tokens: float = 0 @@ -2108,6 +2269,20 @@ def create_task_group(cls) -> abc.TaskGroup: def create_event(cls) -> abc.Event: return Event() + @classmethod + def create_lock(cls, *, fast_acquire: bool) -> abc.Lock: + return Lock(fast_acquire=fast_acquire) + + @classmethod + def create_semaphore( + cls, + initial_value: int, + *, + max_value: int | None = None, + fast_acquire: bool = False, + ) -> abc.Semaphore: + return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire) + @classmethod def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter: return CapacityLimiter(total_tokens) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 61205009..9b8369d4 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -45,7 +45,14 @@ from trio.socket import SocketType as TrioSocketType from trio.to_thread import run_sync -from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc +from .. import ( + CapacityLimiterStatistics, + EventStatistics, + LockStatistics, + TaskInfo, + WouldBlock, + abc, +) from .._core._eventloop import claim_worker_thread from .._core._exceptions import ( BrokenResourceError, @@ -55,9 +62,16 @@ ) from .._core._sockets import convert_ipv6_sockaddr from .._core._streams import create_memory_object_stream -from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter +from .._core._synchronization import ( + CapacityLimiter as BaseCapacityLimiter, +) from .._core._synchronization import Event as BaseEvent -from .._core._synchronization import ResourceGuard +from .._core._synchronization import Lock as BaseLock +from .._core._synchronization import ( + ResourceGuard, + SemaphoreStatistics, +) +from .._core._synchronization import Semaphore as BaseSemaphore from .._core._tasks import CancelScope as BaseCancelScope from ..abc import IPSockAddrType, UDPPacketType, UNIXDatagramPacketType from ..abc._eventloop import AsyncBackend, StrOrBytesPath @@ -637,6 +651,100 @@ def set(self) -> None: self.__original.set() +class Lock(BaseLock): + def __new__(cls, *, fast_acquire: bool = False) -> Lock: + return object.__new__(cls) + + def __init__(self, *, fast_acquire: bool = False) -> None: + self._fast_acquire = fast_acquire + self.__original = trio.Lock() + + async def acquire(self) -> None: + if not self._fast_acquire: + await self.__original.acquire() + return + + # This is the "fast path" where we don't let other tasks run + await trio.lowlevel.checkpoint_if_cancelled() + try: + self.__original.acquire_nowait() + except trio.WouldBlock: + await self.__original._lot.park() + + def acquire_nowait(self) -> None: + try: + self.__original.acquire_nowait() + except trio.WouldBlock: + raise WouldBlock from None + + def locked(self) -> bool: + return self.__original.locked() + + def release(self) -> None: + self.__original.release() + + def statistics(self) -> LockStatistics: + orig_statistics = self.__original.statistics() + owner = TrioTaskInfo(orig_statistics.owner) if orig_statistics.owner else None + return LockStatistics( + orig_statistics.locked, owner, orig_statistics.tasks_waiting + ) + + +class Semaphore(BaseSemaphore): + def __new__( + cls, + initial_value: int, + *, + max_value: int | None = None, + fast_acquire: bool = False, + ) -> Semaphore: + return object.__new__(cls) + + def __init__( + self, + initial_value: int, + *, + max_value: int | None = None, + fast_acquire: bool = False, + ) -> None: + super().__init__(initial_value, max_value=max_value, fast_acquire=fast_acquire) + self.__original = trio.Semaphore(initial_value, max_value=max_value) + + async def acquire(self) -> None: + if not self._fast_acquire: + await self.__original.acquire() + return + + # This is the "fast path" where we don't let other tasks run + await trio.lowlevel.checkpoint_if_cancelled() + try: + self.__original.acquire_nowait() + except trio.WouldBlock: + await self.__original._lot.park() + + def acquire_nowait(self) -> None: + try: + self.__original.acquire_nowait() + except trio.WouldBlock: + raise WouldBlock from None + + @property + def max_value(self) -> int | None: + return self.__original.max_value + + @property + def value(self) -> int: + return self.__original.value + + def release(self) -> None: + self.__original.release() + + def statistics(self) -> SemaphoreStatistics: + orig_statistics = self.__original.statistics() + return SemaphoreStatistics(orig_statistics.tasks_waiting) + + class CapacityLimiter(BaseCapacityLimiter): def __new__( cls, @@ -915,6 +1023,20 @@ def create_task_group(cls) -> abc.TaskGroup: def create_event(cls) -> abc.Event: return Event() + @classmethod + def create_lock(cls, *, fast_acquire: bool) -> Lock: + return Lock(fast_acquire=fast_acquire) + + @classmethod + def create_semaphore( + cls, + initial_value: int, + *, + max_value: int | None = None, + fast_acquire: bool = False, + ) -> abc.Semaphore: + return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire) + @classmethod def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter: return CapacityLimiter(total_tokens) diff --git a/src/anyio/_core/_synchronization.py b/src/anyio/_core/_synchronization.py index b274a31e..023ab733 100644 --- a/src/anyio/_core/_synchronization.py +++ b/src/anyio/_core/_synchronization.py @@ -7,9 +7,9 @@ from sniffio import AsyncLibraryNotFoundError -from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled +from ..lowlevel import checkpoint from ._eventloop import get_async_backend -from ._exceptions import BusyResourceError, WouldBlock +from ._exceptions import BusyResourceError from ._tasks import CancelScope from ._testing import TaskInfo, get_current_task @@ -137,10 +137,11 @@ def statistics(self) -> EventStatistics: class Lock: - _owner_task: TaskInfo | None = None - - def __init__(self) -> None: - self._waiters: deque[tuple[TaskInfo, Event]] = deque() + def __new__(cls, *, fast_acquire: bool = False) -> Lock: + try: + return get_async_backend().create_lock(fast_acquire=fast_acquire) + except AsyncLibraryNotFoundError: + return LockAdapter(fast_acquire=fast_acquire) async def __aenter__(self) -> None: await self.acquire() @@ -155,31 +156,7 @@ async def __aexit__( async def acquire(self) -> None: """Acquire the lock.""" - await checkpoint_if_cancelled() - try: - self.acquire_nowait() - except WouldBlock: - task = get_current_task() - event = Event() - token = task, event - self._waiters.append(token) - try: - await event.wait() - except BaseException: - if not event.is_set(): - self._waiters.remove(token) - elif self._owner_task == task: - self.release() - - raise - - assert self._owner_task == task - else: - try: - await cancel_shielded_checkpoint() - except BaseException: - self.release() - raise + raise NotImplementedError def acquire_nowait(self) -> None: """ @@ -188,37 +165,87 @@ def acquire_nowait(self) -> None: :raises ~anyio.WouldBlock: if the operation would block """ - task = get_current_task() - if self._owner_task == task: - raise RuntimeError("Attempted to acquire an already held Lock") + raise NotImplementedError + + def release(self) -> None: + """Release the lock.""" + raise NotImplementedError + + def locked(self) -> bool: + """Return True if the lock is currently held.""" + raise NotImplementedError + + def statistics(self) -> LockStatistics: + """ + Return statistics about the current state of this lock. + + .. versionadded:: 3.0 + """ + raise NotImplementedError + + +class LockAdapter(Lock): + _internal_lock: Lock | None = None + + def __new__(cls, *, fast_acquire: bool = False) -> LockAdapter: + return object.__new__(cls) + + def __init__(self, *, fast_acquire: bool = False): + self._fast_acquire = fast_acquire + + @property + def _lock(self) -> Lock: + if self._internal_lock is None: + self._internal_lock = get_async_backend().create_lock( + fast_acquire=self._fast_acquire + ) + + return self._internal_lock + + async def __aenter__(self) -> None: + await self._lock.acquire() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self._internal_lock is not None: + self._internal_lock.release() + + async def acquire(self) -> None: + """Acquire the lock.""" + await self._lock.acquire() + + def acquire_nowait(self) -> None: + """ + Acquire the lock, without blocking. - if self._owner_task is not None: - raise WouldBlock + :raises ~anyio.WouldBlock: if the operation would block - self._owner_task = task + """ + self._lock.acquire_nowait() def release(self) -> None: """Release the lock.""" - if self._owner_task != get_current_task(): - raise RuntimeError("The current task is not holding this lock") - - if self._waiters: - self._owner_task, event = self._waiters.popleft() - event.set() - else: - del self._owner_task + self._lock.release() def locked(self) -> bool: """Return True if the lock is currently held.""" - return self._owner_task is not None + return self._lock.locked() def statistics(self) -> LockStatistics: """ Return statistics about the current state of this lock. .. versionadded:: 3.0 + """ - return LockStatistics(self.locked(), self._owner_task, len(self._waiters)) + if self._internal_lock is None: + return LockStatistics(False, None, 0) + + return self._internal_lock.statistics() class Condition: @@ -312,7 +339,27 @@ def statistics(self) -> ConditionStatistics: class Semaphore: - def __init__(self, initial_value: int, *, max_value: int | None = None): + def __new__( + cls, + initial_value: int, + *, + max_value: int | None = None, + fast_acquire: bool = False, + ) -> Semaphore: + try: + return get_async_backend().create_semaphore( + initial_value, max_value=max_value, fast_acquire=fast_acquire + ) + except AsyncLibraryNotFoundError: + return SemaphoreAdapter(initial_value, max_value=max_value) + + def __init__( + self, + initial_value: int, + *, + max_value: int | None = None, + fast_acquire: bool = False, + ): if not isinstance(initial_value, int): raise TypeError("initial_value must be an integer") if initial_value < 0: @@ -325,9 +372,7 @@ def __init__(self, initial_value: int, *, max_value: int | None = None): "max_value must be equal to or higher than initial_value" ) - self._value = initial_value - self._max_value = max_value - self._waiters: deque[Event] = deque() + self._fast_acquire = fast_acquire async def __aenter__(self) -> Semaphore: await self.acquire() @@ -343,27 +388,7 @@ async def __aexit__( async def acquire(self) -> None: """Decrement the semaphore value, blocking if necessary.""" - await checkpoint_if_cancelled() - try: - self.acquire_nowait() - except WouldBlock: - event = Event() - self._waiters.append(event) - try: - await event.wait() - except BaseException: - if not event.is_set(): - self._waiters.remove(event) - else: - self.release() - - raise - else: - try: - await cancel_shielded_checkpoint() - except BaseException: - self.release() - raise + raise NotImplementedError def acquire_nowait(self) -> None: """ @@ -372,30 +397,21 @@ def acquire_nowait(self) -> None: :raises ~anyio.WouldBlock: if the operation would block """ - if self._value == 0: - raise WouldBlock - - self._value -= 1 + raise NotImplementedError def release(self) -> None: """Increment the semaphore value.""" - if self._max_value is not None and self._value == self._max_value: - raise ValueError("semaphore released too many times") - - if self._waiters: - self._waiters.popleft().set() - else: - self._value += 1 + raise NotImplementedError @property def value(self) -> int: """The current value of the semaphore.""" - return self._value + raise NotImplementedError @property def max_value(self) -> int | None: """The maximum value of the semaphore.""" - return self._max_value + raise NotImplementedError def statistics(self) -> SemaphoreStatistics: """ @@ -403,7 +419,66 @@ def statistics(self) -> SemaphoreStatistics: .. versionadded:: 3.0 """ - return SemaphoreStatistics(len(self._waiters)) + raise NotImplementedError + + +class SemaphoreAdapter(Semaphore): + _internal_semaphore: Semaphore | None = None + + def __new__( + cls, + initial_value: int, + *, + max_value: int | None = None, + fast_acquire: bool = False, + ) -> SemaphoreAdapter: + return object.__new__(cls) + + def __init__( + self, + initial_value: int, + *, + max_value: int | None = None, + fast_acquire: bool = False, + ) -> None: + super().__init__(initial_value, max_value=max_value, fast_acquire=fast_acquire) + self._initial_value = initial_value + self._max_value = max_value + + @property + def _semaphore(self) -> Semaphore: + if self._internal_semaphore is None: + self._internal_semaphore = get_async_backend().create_semaphore( + self._initial_value, max_value=self._max_value + ) + + return self._internal_semaphore + + async def acquire(self) -> None: + await self._semaphore.acquire() + + def acquire_nowait(self) -> None: + self._semaphore.acquire_nowait() + + def release(self) -> None: + self._semaphore.release() + + @property + def value(self) -> int: + if self._internal_semaphore is None: + return self._initial_value + + return self._semaphore.value + + @property + def max_value(self) -> int | None: + return self._max_value + + def statistics(self) -> SemaphoreStatistics: + if self._internal_semaphore is None: + return SemaphoreStatistics(tasks_waiting=0) + + return self._semaphore.statistics() class CapacityLimiter: diff --git a/src/anyio/abc/_eventloop.py b/src/anyio/abc/_eventloop.py index 258d2e1d..2c73bb9f 100644 --- a/src/anyio/abc/_eventloop.py +++ b/src/anyio/abc/_eventloop.py @@ -30,7 +30,7 @@ from typing_extensions import TypeAlias if TYPE_CHECKING: - from .._core._synchronization import CapacityLimiter, Event + from .._core._synchronization import CapacityLimiter, Event, Lock, Semaphore from .._core._tasks import CancelScope from .._core._testing import TaskInfo from ..from_thread import BlockingPortal @@ -172,6 +172,22 @@ def create_task_group(cls) -> TaskGroup: def create_event(cls) -> Event: pass + @classmethod + @abstractmethod + def create_lock(cls, *, fast_acquire: bool) -> Lock: + pass + + @classmethod + @abstractmethod + def create_semaphore( + cls, + initial_value: int, + *, + max_value: int | None = None, + fast_acquire: bool = False, + ) -> Semaphore: + pass + @classmethod @abstractmethod def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter: diff --git a/tests/test_synchronization.py b/tests/test_synchronization.py index d6f68f61..c43dbe5a 100644 --- a/tests/test_synchronization.py +++ b/tests/test_synchronization.py @@ -64,6 +64,24 @@ async def task() -> None: assert not lock.locked() assert results == ["1", "2"] + async def test_fast_acquire(self) -> None: + """ + Test that fast_acquire=True does not yield back control to the event loop when + there is no contention. + + """ + other_task_called = False + + async def other_task() -> None: + nonlocal other_task_called + other_task_called = True + + lock = Lock(fast_acquire=True) + async with create_task_group() as tg: + tg.start_soon(other_task) + async with lock: + assert not other_task_called + async def test_acquire_nowait(self) -> None: lock = Lock() lock.acquire_nowait() @@ -143,6 +161,38 @@ async def acquire() -> None: task1.cancel() await asyncio.wait_for(task2, 1) + @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + async def test_cancel_after_release(self) -> None: + """ + Test that a native asyncio cancellation will not cause a lock ownership + to get lost between a release() and the resumption of acquire(). + + """ + # Create the lock and acquire it right away so that any task acquiring it will + # block + lock = Lock() + lock.acquire_nowait() + + # Start a task that gets blocked on trying to acquire the semaphore + loop = asyncio.get_running_loop() + task1 = loop.create_task(lock.acquire(), name="task1") + await asyncio.sleep(0) + + # Trigger the aqcuiring task to be rescheduled, but also cancel it right away + lock.release() + task1.cancel() + statistics = lock.statistics() + assert statistics.owner + assert statistics.owner.name == "task1" + await asyncio.wait([task1], timeout=1) + + # The acquire() method should've released the semaphore because acquisition + # failed due to cancellation + statistics = lock.statistics() + assert statistics.owner is None + assert statistics.tasks_waiting == 0 + lock.acquire_nowait() + def test_instantiate_outside_event_loop( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: @@ -158,6 +208,27 @@ async def use_lock() -> None: run(use_lock, backend=anyio_backend_name, backend_options=anyio_backend_options) + async def test_owner_after_release(self) -> None: + async def taskfunc1() -> None: + await lock.acquire() + owner = lock.statistics().owner + assert owner + assert owner.name == "task1" + await event.wait() + lock.release() + owner = lock.statistics().owner + assert owner + assert owner.name == "task2" + + event = Event() + lock = Lock() + async with create_task_group() as tg: + tg.start_soon(taskfunc1, name="task1") + await wait_all_tasks_blocked() + tg.start_soon(lock.acquire, name="task2") + await wait_all_tasks_blocked() + event.set() + class TestEvent: async def test_event(self) -> None: @@ -381,6 +452,24 @@ async def acquire() -> None: assert semaphore.value == 2 + async def test_fast_acquire(self) -> None: + """ + Test that fast_acquire=True does not yield back control to the event loop when + there is no contention. + + """ + other_task_called = False + + async def other_task() -> None: + nonlocal other_task_called + other_task_called = True + + semaphore = Semaphore(1, fast_acquire=True) + async with create_task_group() as tg: + tg.start_soon(other_task) + async with semaphore: + assert not other_task_called + async def test_acquire_nowait(self) -> None: semaphore = Semaphore(1) semaphore.acquire_nowait() @@ -474,6 +563,33 @@ async def acquire() -> None: task1.cancel() await asyncio.wait_for(task2, 1) + @pytest.mark.parametrize("anyio_backend", ["asyncio"]) + async def test_cancel_after_release(self) -> None: + """ + Test that a native asyncio cancellation will not cause a semaphore ownership + to get lost between a release() and the resumption of acquire(). + + """ + # Create the semaphore in such a way that any task acquiring it will block + semaphore = Semaphore(0, max_value=1) + + # Start a task that gets blocked on trying to acquire the semaphore + loop = asyncio.get_running_loop() + task1 = loop.create_task(semaphore.acquire()) + await asyncio.sleep(0) + + # Trigger the aqcuiring task to be rescheduled, but also cancel it right away + semaphore.release() + task1.cancel() + assert semaphore.value == 0 + await asyncio.wait([task1], timeout=1) + + # The acquire() method should've released the semaphore because acquisition + # failed due to cancellation + assert semaphore.value == 1 + assert semaphore.statistics().tasks_waiting == 0 + semaphore.acquire_nowait() + def test_instantiate_outside_event_loop( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: @@ -481,7 +597,9 @@ async def use_semaphore() -> None: async with semaphore: pass - semaphore = Semaphore(1) + semaphore = Semaphore(1, max_value=3) + assert semaphore.value == 1 + assert semaphore.max_value == 3 assert semaphore.statistics().tasks_waiting == 0 run(