Skip to content

Commit

Permalink
Merge branch 'master' into subinterpreters
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm authored Jan 4, 2025
2 parents ea0d7a7 + 6d612a9 commit 8c13f3a
Showing 1 changed file with 13 additions and 21 deletions.
34 changes: 13 additions & 21 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,26 +701,6 @@ def started(self, value: T_contra | None = None) -> None:
_task_states[task].parent_id = self._parent_id


async def _wait(tasks: Iterable[asyncio.Task[object]]) -> None:
tasks = set(tasks)
waiter = get_running_loop().create_future()

def on_completion(task: asyncio.Task[object]) -> None:
tasks.discard(task)
if not tasks and not waiter.done():
waiter.set_result(None)

for task in tasks:
task.add_done_callback(on_completion)
del task

try:
await waiter
finally:
while tasks:
tasks.pop().remove_done_callback(on_completion)


if sys.version_info >= (3, 12):
_eager_task_factory_code: CodeType | None = asyncio.eager_task_factory.__code__
else:
Expand All @@ -733,6 +713,7 @@ def __init__(self) -> None:
self._active = False
self._exceptions: list[BaseException] = []
self._tasks: set[asyncio.Task] = set()
self._on_completed_fut: asyncio.Future[None] | None = None

async def __aenter__(self) -> TaskGroup:
self.cancel_scope.__enter__()
Expand All @@ -751,12 +732,15 @@ async def __aexit__(
if not isinstance(exc_val, CancelledError):
self._exceptions.append(exc_val)

loop = get_running_loop()
try:
if self._tasks:
with CancelScope() as wait_scope:
while self._tasks:
self._on_completed_fut = loop.create_future()

try:
await _wait(self._tasks)
await self._on_completed_fut
except CancelledError as exc:
# Shield the scope against further cancellation attempts,
# as they're not productive (#695)
Expand All @@ -771,6 +755,8 @@ async def __aexit__(
and not is_anyio_cancellation(exc)
):
exc_val = exc

self._on_completed_fut = None
else:
# If there are no child tasks to wait on, run at least one checkpoint
# anyway
Expand Down Expand Up @@ -808,6 +794,12 @@ def task_done(_task: asyncio.Task) -> None:
self._tasks.remove(task)
del _task_states[_task]

if self._on_completed_fut is not None and not self._tasks:
try:
self._on_completed_fut.set_result(None)
except asyncio.InvalidStateError:
pass

try:
exc = _task.exception()
except CancelledError as e:
Expand Down

0 comments on commit 8c13f3a

Please sign in to comment.