Skip to content

Commit

Permalink
Fix excessive tracebacks when sigint:ing a .map call [CLI-33] (#2322)
Browse files Browse the repository at this point in the history
## Changelog

* Fixes Ctrl-C on Windows clients not interrupting ongoing calls (`modal run` etc)
* Prevents excessive/internal tracebacks from appearing when SIGINT/Ctrl-C:ing an ongoing .map call
  • Loading branch information
freider authored Oct 31, 2024
1 parent f99196b commit 1a2e341
Show file tree
Hide file tree
Showing 13 changed files with 148 additions and 61 deletions.
3 changes: 1 addition & 2 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple

from google.protobuf.message import Message
from synchronicity import Interface

from modal_proto import api_pb2

Expand Down Expand Up @@ -793,7 +792,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
# Initialize objects on the app.
# This is basically only functions and classes - anything else is deprecated and will be unsupported soon
if active_app is not None:
app: App = synchronizer._translate_out(active_app, Interface.BLOCKING)
app: App = synchronizer._translate_out(active_app)
app._init_container(client, container_app)

# Hydrate all function dependencies.
Expand Down
4 changes: 1 addition & 3 deletions modal/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from dataclasses import dataclass
from typing import Any

from synchronicity.synchronizer import Interface

from modal._utils.async_utils import synchronizer
from modal_proto import api_pb2

Expand Down Expand Up @@ -69,7 +67,7 @@ def persistent_load(self, pid):
impl_class, attributes = obj_data
impl_instance = impl_class.__new__(impl_class)
impl_instance.__dict__.update(attributes)
return synchronizer._translate_out(impl_instance, interface=Interface.BLOCKING)
return synchronizer._translate_out(impl_instance)
else:
raise ExecutionError("Unknown serialization format")

Expand Down
46 changes: 22 additions & 24 deletions modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
)

import synchronicity
from synchronicity.async_utils import Runner
from synchronicity.exceptions import NestedEventLoops
from typing_extensions import ParamSpec, assert_type

from ..exception import InvalidError
Expand Down Expand Up @@ -325,6 +327,9 @@ def __del__(self):
async def athrow(self, exc):
return await self.gen.athrow(exc)

async def aclose(self):
return await self.gen.aclose()


synchronize_api(_WarnIfGeneratorIsNotConsumed)

Expand Down Expand Up @@ -371,7 +376,7 @@ class AsyncOrSyncIterable:
from an already async context, since that would otherwise deadlock the event loop
"""

def __init__(self, async_iterable: typing.AsyncIterable[Any], nested_async_message):
def __init__(self, async_iterable: typing.AsyncGenerator[Any, None], nested_async_message):
self._async_iterable = async_iterable
self.nested_async_message = nested_async_message

Expand All @@ -380,9 +385,10 @@ def __aiter__(self):

def __iter__(self):
try:
for output in run_generator_sync(self._async_iterable): # type: ignore
yield output
except NestedAsyncCalls:
with Runner() as runner:
for output in run_async_gen(runner, self._async_iterable):
yield output # type: ignore
except NestedEventLoops:
raise InvalidError(self.nested_async_message)


Expand Down Expand Up @@ -446,45 +452,37 @@ async def asyncnullcontext(*args, **kwargs):
SEND_TYPE = typing.TypeVar("SEND_TYPE")


class NestedAsyncCalls(Exception):
pass


def run_generator_sync(
def run_async_gen(
runner: Runner,
gen: typing.AsyncGenerator[YIELD_TYPE, SEND_TYPE],
) -> typing.Generator[YIELD_TYPE, SEND_TYPE, None]:
try:
asyncio.get_running_loop()
except RuntimeError:
pass # no event loop - this is what we expect!
else:
raise NestedAsyncCalls()
loop = asyncio.new_event_loop() # set up new event loop for the map so we can use async logic

"""Convert an async generator into a sync one"""
# more or less copied from synchronicity's implementation:
next_send: typing.Union[SEND_TYPE, None] = None
next_yield: YIELD_TYPE
exc: Optional[BaseException] = None
while True:
try:
if exc:
next_yield = loop.run_until_complete(gen.athrow(exc))
next_yield = runner.run(gen.athrow(exc))
else:
next_yield = loop.run_until_complete(gen.asend(next_send)) # type: ignore[arg-type]
next_yield = runner.run(gen.asend(next_send)) # type: ignore[arg-type]
except KeyboardInterrupt as e:
raise e from None
except StopAsyncIteration:
break
break # typically a graceful exit of the async generator
try:
next_send = yield next_yield
exc = None
except BaseException as err:
exc = err
loop.close()


@asynccontextmanager
async def aclosing(
agen: AsyncGenerator[T, None],
) -> AsyncGenerator[AsyncGenerator[T, None], None]:
async def aclosing(agen: AsyncGenerator[T, None]) -> AsyncGenerator[AsyncGenerator[T, None], None]:
# ensure aclose is called asynchronously after context manager is closed
# call to ensure cleanup after stateful generators since they can't
# always be cleaned up by garbage collection
try:
yield agen
finally:
Expand Down
23 changes: 13 additions & 10 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from grpclib import GRPCError, Status
from synchronicity.combined_types import MethodWithAio

from modal._utils.async_utils import aclosing
from modal_proto import api_pb2
from modal_proto.modal_api_grpc import ModalClientModal

Expand All @@ -37,7 +38,6 @@
from ._serialization import serialize, serialize_proto_params
from ._utils.async_utils import (
TaskContext,
aclosing,
async_merge,
callable_to_agen,
synchronize_api,
Expand Down Expand Up @@ -1252,15 +1252,18 @@ async def _map(
else:
count_update_callback = None

async for item in _map_invocation(
self, # type: ignore
input_queue,
self._client,
order_outputs,
return_exceptions,
count_update_callback,
):
yield item
async with aclosing(
_map_invocation(
self, # type: ignore
input_queue,
self._client,
order_outputs,
return_exceptions,
count_update_callback,
)
) as stream:
async for item in stream:
yield item

async def _call_function(self, args, kwargs) -> ReturnType:
invocation = await _Invocation.create(
Expand Down
7 changes: 5 additions & 2 deletions modal/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from google.protobuf.message import Message

from modal._utils.async_utils import aclosing

from ._resolver import Resolver
from ._utils.async_utils import synchronize_api
from .client import _Client
Expand Down Expand Up @@ -255,7 +257,8 @@ def live_method_gen(method):
@wraps(method)
async def wrapped(self, *args, **kwargs):
await self.resolve()
async for item in method(self, *args, **kwargs):
yield item
async with aclosing(method(self, *args, **kwargs)) as stream:
async for item in stream:
yield item

return wrapped
10 changes: 6 additions & 4 deletions modal/parallel_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,9 @@ async def get_all_outputs():
async def get_all_outputs_and_clean_up():
assert client.stub
try:
async for item in get_all_outputs():
yield item
async with aclosing(get_all_outputs()) as output_items:
async for item in output_items:
yield item
finally:
# "ack" that we have all outputs we are interested in and let backend clear results
request = api_pb2.FunctionGetOutputsRequest(
Expand Down Expand Up @@ -350,8 +351,9 @@ async def feed_queue():
# they accept executable code in the form of
# iterators that we don't want to run inside the synchronicity thread.
# Instead, we delegate to `._map()` with a safer Queue as input
async for output in self._map.aio(raw_input_queue, order_outputs, return_exceptions): # type: ignore[reportFunctionMemberAccess]
yield output
async with aclosing(self._map.aio(raw_input_queue, order_outputs, return_exceptions)) as map_output_stream:
async for output in map_output_stream:
yield output
finally:
feed_input_task.cancel() # should only be needed in case of exceptions

Expand Down
3 changes: 1 addition & 2 deletions modal/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from multiprocessing.synchronize import Event
from typing import TYPE_CHECKING, AsyncGenerator, Optional, Set, TypeVar

from synchronicity import Interface
from synchronicity.async_wrap import asynccontextmanager

from modal._output import OutputManager
Expand All @@ -29,7 +28,7 @@
def _run_serve(app_ref: str, existing_app_id: str, is_ready: Event, environment_name: str, show_progress: bool):
# subprocess entrypoint
_app = import_app(app_ref)
blocking_app = synchronizer._translate_out(_app, Interface.BLOCKING)
blocking_app = synchronizer._translate_out(_app)

with enable_output(show_progress=show_progress):
serve_update(blocking_app, existing_app_id, is_ready, environment_name)
Expand Down
1 change: 1 addition & 0 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ notebook==6.5.1
jupytext==1.14.1
pyright==1.1.351
pdm==2.12.4 # used for testing pdm cache behavior w/ automounts
console-ctrl==0.1.0
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ install_requires =
grpclib==0.4.7
protobuf>=3.19,<5.0,!=4.24.0
rich>=12.0.0
synchronicity~=0.8.2
synchronicity~=0.9.1
toml
typer>=0.9
types-certifi
Expand Down
65 changes: 65 additions & 0 deletions test/async_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import os
import platform
import pytest
import subprocess
import sys
import textwrap
from test import helpers

import pytest_asyncio
from synchronicity import Synchronizer
Expand Down Expand Up @@ -717,3 +721,64 @@ async def foo():
async for item in callable_to_agen(foo):
result.append(item)
assert result == [await foo()]


def test_sigint_run_async_gen_shuts_down_gracefully():
code = textwrap.dedent(
"""
import asyncio
import time
from itertools import count
from synchronicity.async_utils import Runner
from modal._utils.async_utils import run_async_gen
async def async_gen():
print("enter")
try:
for i in count():
yield i
await asyncio.sleep(0.1)
finally:
# this could be either CancelledError or GeneratorExit depending on timing
# CancelledError happens if sigint is during this generator's await
# GeneratorExit is during the yielded block in the sync caller
print("cancel")
await asyncio.sleep(0.1)
print("bye")
try:
with Runner() as runner:
for res in run_async_gen(runner, async_gen()):
print("res", res)
except KeyboardInterrupt:
print("KeyboardInterrupt")
"""
)

p = helpers.PopenWithCtrlC(
[sys.executable, "-u", "-c", code],
encoding="utf8",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)

def line():
s = p.stdout.readline().rstrip("\n")
if s == "":
print(p.stderr.read())
raise Exception("no stdout")
print(s)
return s

assert line() == "enter"
assert line() == "res 0"
assert line() == "res 1"

p.send_ctrl_c()
print("sent ctrl-C")
while (nextline := line()).startswith("res"):
pass
assert nextline == "cancel"
assert line() == "bye"
assert line() == "KeyboardInterrupt"
assert p.wait() == 0
assert p.stdout.read() == ""
assert p.stderr.read() == ""
Loading

0 comments on commit 1a2e341

Please sign in to comment.