diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 5122d7c5..76a400c1 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -701,26 +701,6 @@ def started(self, value: T_contra | None = None) -> None: _task_states[task].parent_id = self._parent_id -async def _wait(tasks: Iterable[asyncio.Task[object]]) -> None: - tasks = set(tasks) - waiter = get_running_loop().create_future() - - def on_completion(task: asyncio.Task[object]) -> None: - tasks.discard(task) - if not tasks and not waiter.done(): - waiter.set_result(None) - - for task in tasks: - task.add_done_callback(on_completion) - del task - - try: - await waiter - finally: - while tasks: - tasks.pop().remove_done_callback(on_completion) - - if sys.version_info >= (3, 12): _eager_task_factory_code: CodeType | None = asyncio.eager_task_factory.__code__ else: @@ -733,6 +713,7 @@ def __init__(self) -> None: self._active = False self._exceptions: list[BaseException] = [] self._tasks: set[asyncio.Task] = set() + self._on_completed_fut: asyncio.Future[None] | None = None async def __aenter__(self) -> TaskGroup: self.cancel_scope.__enter__() @@ -751,12 +732,15 @@ async def __aexit__( if not isinstance(exc_val, CancelledError): self._exceptions.append(exc_val) + loop = get_running_loop() try: if self._tasks: with CancelScope() as wait_scope: while self._tasks: + self._on_completed_fut = loop.create_future() + try: - await _wait(self._tasks) + await self._on_completed_fut except CancelledError as exc: # Shield the scope against further cancellation attempts, # as they're not productive (#695) @@ -771,6 +755,8 @@ async def __aexit__( and not is_anyio_cancellation(exc) ): exc_val = exc + + self._on_completed_fut = None else: # If there are no child tasks to wait on, run at least one checkpoint # anyway @@ -808,6 +794,12 @@ def task_done(_task: asyncio.Task) -> None: self._tasks.remove(task) del _task_states[_task] + if self._on_completed_fut is not None and not self._tasks: + try: + self._on_completed_fut.set_result(None) + except asyncio.InvalidStateError: + pass + try: exc = _task.exception() except CancelledError as e: