Skip to content

Commit

Permalink
asyncio cancellation handling + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kramstrom committed Oct 29, 2024
1 parent ca01d42 commit 9b7ac42
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 22 deletions.
69 changes: 50 additions & 19 deletions modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

import synchronicity
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, assert_type

from ..exception import InvalidError
from .logger import logger
Expand Down Expand Up @@ -561,34 +561,65 @@ class StopSentinelType:
STOP_SENTINEL = StopSentinelType()


async def async_merge(*inputs: Union[AsyncIterable[T], Iterable[T]]) -> AsyncGenerator[T, None]:
queue: asyncio.Queue[Tuple[int, Union[ValueWrapper[T], ExceptionWrapper, StopSentinelType]]] = asyncio.Queue()
async def async_merge(*iterables: Union[AsyncIterable[T], Iterable[T]]) -> AsyncGenerator[T, None]:
queue: asyncio.Queue[Union[ValueWrapper[T], ExceptionWrapper, StopSentinelType]] = asyncio.Queue(
maxsize=len(iterables) * 10
)

async def producer(producer_id: int, iterable: Union[AsyncIterable[T], Iterable[T]]):
async def producer(iterable: Union[AsyncIterable[T], Iterable[T]]):
try:
async for item in sync_or_async_iter(iterable):
await queue.put((producer_id, ValueWrapper(item)))
await queue.put(ValueWrapper(item))
except Exception as e:
await queue.put((producer_id, ExceptionWrapper(e)))
finally:
await queue.put((producer_id, STOP_SENTINEL))
await queue.put(ExceptionWrapper(e))

tasks = [asyncio.create_task(producer(i, it)) for i, it in enumerate(inputs)]
active_producers = set(range(len(inputs)))
tasks = set([asyncio.create_task(producer(it)) for it in iterables])
new_output_task = asyncio.create_task(queue.get())

try:
while active_producers:
producer_id, item = await queue.get()
if isinstance(item, ExceptionWrapper):
raise item.value
elif isinstance(item, StopSentinelType):
active_producers.remove(producer_id)
else:
while tasks:
done, _ = await asyncio.wait(
[*tasks, new_output_task],
return_when=asyncio.FIRST_COMPLETED,
)

if new_output_task in done:
item = new_output_task.result()
if isinstance(item, ValueWrapper):
yield item.value
else:
assert_type(item, ExceptionWrapper)
raise item.value

new_output_task = asyncio.create_task(queue.get())

finished_producers = done & tasks
tasks -= finished_producers
for finished_producer in finished_producers:
# this is done in order to catch potential raised errors/cancellations
# from within worker tasks as soon as they happen.
await finished_producer

while not queue.empty():
item = await new_output_task
if isinstance(item, ValueWrapper):
yield item.value
else:
assert_type(item, ExceptionWrapper)
raise item.value

new_output_task = asyncio.create_task(queue.get())

finally:
if not new_output_task.done():
new_output_task.cancel()
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
if not task.done():
try:
task.cancel()
await task
except asyncio.CancelledError:
pass


async def callable_to_agen(awaitable: Callable[[], Awaitable[T]]) -> AsyncGenerator[T, None]:
Expand Down
97 changes: 94 additions & 3 deletions test/async_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,52 @@ async def gen2():
assert result == [(1, 3), (2, 4)]


@pytest.mark.asyncio
async def test_async_zip_cancellation():
ev = asyncio.Event()

async def gen1():
await asyncio.sleep(0.1)
yield 1
await ev.wait()
raise asyncio.CancelledError()
yield 2

async def gen2():
yield 3
await asyncio.sleep(0.1)
yield 4

async def zip_coro():
async for _ in async_zip(gen1(), gen2()):
pass

zip_task = asyncio.create_task(zip_coro())
await asyncio.sleep(0.1)
zip_task.cancel()
with pytest.raises(asyncio.CancelledError):
await zip_task


@pytest.mark.asyncio
async def test_async_zip_producer_cancellation():
async def gen1():
await asyncio.sleep(0.1)
yield 1
raise asyncio.CancelledError()
yield 2

async def gen2():
yield 3
await asyncio.sleep(0.1)
yield 4

await asyncio.sleep(0.1)
with pytest.raises(asyncio.CancelledError):
async for _ in async_zip(gen1(), gen2()):
pass


@pytest.mark.asyncio
async def test_async_merge():
result = []
Expand Down Expand Up @@ -564,15 +610,15 @@ async def gen2():
await asyncio.sleep(0)
states.append("gen2 exit")

async with aclosing(gen1()) as g1, aclosing(gen2()) as g2, aclosing(async_merge(g1, g2)) as stream:
async with aclosing(async_merge(gen1(), gen2())) as stream:
async for _ in stream:
break

assert states == [
assert sorted(states) == [
"gen1 enter",
"gen1 exit",
"gen2 enter",
"gen2 exit",
"gen1 exit",
]


Expand Down Expand Up @@ -614,6 +660,51 @@ async def gen2():
]


@pytest.mark.asyncio
async def test_async_merge_cancellation():
ev = asyncio.Event()

async def gen1():
await asyncio.sleep(0.1)
yield 1
await ev.wait()
yield 2

async def gen2():
yield 3
await asyncio.sleep(0.1)
yield 4

async def merge_coro():
async for _ in async_merge(gen1(), gen2()):
pass

merge_task = asyncio.create_task(merge_coro())
await asyncio.sleep(0.1)
merge_task.cancel()
with pytest.raises(asyncio.CancelledError):
await merge_task


@pytest.mark.asyncio
async def test_async_merge_producer_cancellation():
async def gen1():
await asyncio.sleep(0.1)
yield 1
raise asyncio.CancelledError()
yield 2

async def gen2():
yield 3
await asyncio.sleep(0.1)
yield 4

await asyncio.sleep(0.1)
with pytest.raises(asyncio.CancelledError):
async for _ in async_merge(gen1(), gen2()):
pass


@pytest.mark.asyncio
async def test_callable_to_agen():
async def foo():
Expand Down

0 comments on commit 9b7ac42

Please sign in to comment.