From 236f50323ed9af0de25fe8369415440751d443ed Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 4 Jan 2025 09:52:46 +0000 Subject: [PATCH 1/6] refactor waiting for tasks to propagate from task group on the asyncio backend --- src/anyio/_backends/_asyncio.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 5122d7c5..f37be4af 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -701,26 +701,6 @@ def started(self, value: T_contra | None = None) -> None: _task_states[task].parent_id = self._parent_id -async def _wait(tasks: Iterable[asyncio.Task[object]]) -> None: - tasks = set(tasks) - waiter = get_running_loop().create_future() - - def on_completion(task: asyncio.Task[object]) -> None: - tasks.discard(task) - if not tasks and not waiter.done(): - waiter.set_result(None) - - for task in tasks: - task.add_done_callback(on_completion) - del task - - try: - await waiter - finally: - while tasks: - tasks.pop().remove_done_callback(on_completion) - - if sys.version_info >= (3, 12): _eager_task_factory_code: CodeType | None = asyncio.eager_task_factory.__code__ else: @@ -733,6 +713,7 @@ def __init__(self) -> None: self._active = False self._exceptions: list[BaseException] = [] self._tasks: set[asyncio.Task] = set() + self._on_completed_fut: asyncio.Future[None] | None = None async def __aenter__(self) -> TaskGroup: self.cancel_scope.__enter__() @@ -751,12 +732,15 @@ async def __aexit__( if not isinstance(exc_val, CancelledError): self._exceptions.append(exc_val) + loop = get_running_loop() try: if self._tasks: with CancelScope() as wait_scope: while self._tasks: + if self._on_completed_fut is None: + self._on_completed_fut = loop.create_future() try: - await _wait(self._tasks) + await self._on_completed_fut except CancelledError as exc: # Shield the scope against further cancellation attempts, # as they're not productive (#695) @@ -771,6 +755,7 @@ async def __aexit__( and not is_anyio_cancellation(exc) ): exc_val = exc + self._on_completed_fut = None else: # If there are no child tasks to wait on, run at least one checkpoint # anyway @@ -808,6 +793,10 @@ def task_done(_task: asyncio.Task) -> None: self._tasks.remove(task) del _task_states[_task] + if self._on_completed_fut is not None and not self._tasks: + if not self._on_completed_fut.done(): + self._on_completed_fut.set_result(None) + try: exc = _task.exception() except CancelledError as e: From bac6fc113fea696dc49c3a6ac732f2719514f6de Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 4 Jan 2025 09:59:01 +0000 Subject: [PATCH 2/6] add news --- docs/versionhistory.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 588eebbf..e525aaeb 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -18,6 +18,8 @@ This library adheres to `Semantic Versioning 2.0 `_. (`#840 `_) - Fixed return type annotation of various context managers' ``__exit__`` method (`#847 `_; PR by @Enegg) +- Refactored TaskGroup task waiting on the asyncio backend + (`#854 `_; PR by @graingert) **4.7.0** From 989275017462512ebe00ce965f6731570c97246b Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 4 Jan 2025 13:22:47 +0000 Subject: [PATCH 3/6] remove news --- docs/versionhistory.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index e525aaeb..588eebbf 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -18,8 +18,6 @@ This library adheres to `Semantic Versioning 2.0 `_. (`#840 `_) - Fixed return type annotation of various context managers' ``__exit__`` method (`#847 `_; PR by @Enegg) -- Refactored TaskGroup task waiting on the asyncio backend - (`#854 `_; PR by @graingert) **4.7.0** From 99dd7ef6f4c4bfdcd2754f7b38a73948d949fbcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 4 Jan 2025 15:35:23 +0200 Subject: [PATCH 4/6] Apply suggestions from code review --- src/anyio/_backends/_asyncio.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index f37be4af..7a3afb58 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -739,6 +739,7 @@ async def __aexit__( while self._tasks: if self._on_completed_fut is None: self._on_completed_fut = loop.create_future() + try: await self._on_completed_fut except CancelledError as exc: @@ -755,6 +756,7 @@ async def __aexit__( and not is_anyio_cancellation(exc) ): exc_val = exc + self._on_completed_fut = None else: # If there are no child tasks to wait on, run at least one checkpoint From d30f4d4ecf70c36145e646a84de85f0e92a9d1eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 4 Jan 2025 16:11:44 +0200 Subject: [PATCH 5/6] Update src/anyio/_backends/_asyncio.py Co-authored-by: Thomas Grainger --- src/anyio/_backends/_asyncio.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 7a3afb58..0d8739b8 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -737,8 +737,7 @@ async def __aexit__( if self._tasks: with CancelScope() as wait_scope: while self._tasks: - if self._on_completed_fut is None: - self._on_completed_fut = loop.create_future() + self._on_completed_fut = loop.create_future() try: await self._on_completed_fut From a9737421e2f2fed1d9d08161ffe443e2cf7509cf Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 4 Jan 2025 16:43:45 +0000 Subject: [PATCH 6/6] ask for forgiveness, rather than look before you leap --- src/anyio/_backends/_asyncio.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 0d8739b8..76a400c1 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -795,8 +795,10 @@ def task_done(_task: asyncio.Task) -> None: del _task_states[_task] if self._on_completed_fut is not None and not self._tasks: - if not self._on_completed_fut.done(): + try: self._on_completed_fut.set_result(None) + except asyncio.InvalidStateError: + pass try: exc = _task.exception()