diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 2de24d702..c660367c9 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple from google.protobuf.message import Message -from synchronicity import Interface from modal_proto import api_pb2 @@ -793,7 +792,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client): # Initialize objects on the app. # This is basically only functions and classes - anything else is deprecated and will be unsupported soon if active_app is not None: - app: App = synchronizer._translate_out(active_app, Interface.BLOCKING) + app: App = synchronizer._translate_out(active_app) app._init_container(client, container_app) # Hydrate all function dependencies. diff --git a/modal/_serialization.py b/modal/_serialization.py index 33d22bc28..976725b66 100644 --- a/modal/_serialization.py +++ b/modal/_serialization.py @@ -5,8 +5,6 @@ from dataclasses import dataclass from typing import Any -from synchronicity.synchronizer import Interface - from modal._utils.async_utils import synchronizer from modal_proto import api_pb2 @@ -69,7 +67,7 @@ def persistent_load(self, pid): impl_class, attributes = obj_data impl_instance = impl_class.__new__(impl_class) impl_instance.__dict__.update(attributes) - return synchronizer._translate_out(impl_instance, interface=Interface.BLOCKING) + return synchronizer._translate_out(impl_instance) else: raise ExecutionError("Unknown serialization format") diff --git a/modal/_utils/async_utils.py b/modal/_utils/async_utils.py index 42b7dfa90..3a7747118 100644 --- a/modal/_utils/async_utils.py +++ b/modal/_utils/async_utils.py @@ -24,6 +24,8 @@ ) import synchronicity +from synchronicity.async_utils import Runner +from synchronicity.exceptions import NestedEventLoops from typing_extensions import ParamSpec, assert_type from ..exception import InvalidError @@ -325,6 +327,9 @@ def __del__(self): async def athrow(self, exc): return await self.gen.athrow(exc) + async def aclose(self): + return await self.gen.aclose() + synchronize_api(_WarnIfGeneratorIsNotConsumed) @@ -371,7 +376,7 @@ class AsyncOrSyncIterable: from an already async context, since that would otherwise deadlock the event loop """ - def __init__(self, async_iterable: typing.AsyncIterable[Any], nested_async_message): + def __init__(self, async_iterable: typing.AsyncGenerator[Any, None], nested_async_message): self._async_iterable = async_iterable self.nested_async_message = nested_async_message @@ -380,9 +385,10 @@ def __aiter__(self): def __iter__(self): try: - for output in run_generator_sync(self._async_iterable): # type: ignore - yield output - except NestedAsyncCalls: + with Runner() as runner: + for output in run_async_gen(runner, self._async_iterable): + yield output # type: ignore + except NestedEventLoops: raise InvalidError(self.nested_async_message) @@ -446,21 +452,11 @@ async def asyncnullcontext(*args, **kwargs): SEND_TYPE = typing.TypeVar("SEND_TYPE") -class NestedAsyncCalls(Exception): - pass - - -def run_generator_sync( +def run_async_gen( + runner: Runner, gen: typing.AsyncGenerator[YIELD_TYPE, SEND_TYPE], ) -> typing.Generator[YIELD_TYPE, SEND_TYPE, None]: - try: - asyncio.get_running_loop() - except RuntimeError: - pass # no event loop - this is what we expect! - else: - raise NestedAsyncCalls() - loop = asyncio.new_event_loop() # set up new event loop for the map so we can use async logic - + """Convert an async generator into a sync one""" # more or less copied from synchronicity's implementation: next_send: typing.Union[SEND_TYPE, None] = None next_yield: YIELD_TYPE @@ -468,23 +464,25 @@ def run_generator_sync( while True: try: if exc: - next_yield = loop.run_until_complete(gen.athrow(exc)) + next_yield = runner.run(gen.athrow(exc)) else: - next_yield = loop.run_until_complete(gen.asend(next_send)) # type: ignore[arg-type] + next_yield = runner.run(gen.asend(next_send)) # type: ignore[arg-type] + except KeyboardInterrupt as e: + raise e from None except StopAsyncIteration: - break + break # typically a graceful exit of the async generator try: next_send = yield next_yield exc = None except BaseException as err: exc = err - loop.close() @asynccontextmanager -async def aclosing( - agen: AsyncGenerator[T, None], -) -> AsyncGenerator[AsyncGenerator[T, None], None]: +async def aclosing(agen: AsyncGenerator[T, None]) -> AsyncGenerator[AsyncGenerator[T, None], None]: + # ensure aclose is called asynchronously after context manager is closed + # call to ensure cleanup after stateful generators since they can't + # always be cleaned up by garbage collection try: yield agen finally: diff --git a/modal/functions.py b/modal/functions.py index 8a2807c51..6d614901c 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -27,6 +27,7 @@ from grpclib import GRPCError, Status from synchronicity.combined_types import MethodWithAio +from modal._utils.async_utils import aclosing from modal_proto import api_pb2 from modal_proto.modal_api_grpc import ModalClientModal @@ -37,7 +38,6 @@ from ._serialization import serialize, serialize_proto_params from ._utils.async_utils import ( TaskContext, - aclosing, async_merge, callable_to_agen, synchronize_api, @@ -1252,15 +1252,18 @@ async def _map( else: count_update_callback = None - async for item in _map_invocation( - self, # type: ignore - input_queue, - self._client, - order_outputs, - return_exceptions, - count_update_callback, - ): - yield item + async with aclosing( + _map_invocation( + self, # type: ignore + input_queue, + self._client, + order_outputs, + return_exceptions, + count_update_callback, + ) + ) as stream: + async for item in stream: + yield item async def _call_function(self, args, kwargs) -> ReturnType: invocation = await _Invocation.create( diff --git a/modal/object.py b/modal/object.py index 01aaab9ae..6a1cd28e5 100644 --- a/modal/object.py +++ b/modal/object.py @@ -5,6 +5,8 @@ from google.protobuf.message import Message +from modal._utils.async_utils import aclosing + from ._resolver import Resolver from ._utils.async_utils import synchronize_api from .client import _Client @@ -255,7 +257,8 @@ def live_method_gen(method): @wraps(method) async def wrapped(self, *args, **kwargs): await self.resolve() - async for item in method(self, *args, **kwargs): - yield item + async with aclosing(method(self, *args, **kwargs)) as stream: + async for item in stream: + yield item return wrapped diff --git a/modal/parallel_map.py b/modal/parallel_map.py index 3aa0d9c9f..ce1aa2111 100644 --- a/modal/parallel_map.py +++ b/modal/parallel_map.py @@ -205,8 +205,9 @@ async def get_all_outputs(): async def get_all_outputs_and_clean_up(): assert client.stub try: - async for item in get_all_outputs(): - yield item + async with aclosing(get_all_outputs()) as output_items: + async for item in output_items: + yield item finally: # "ack" that we have all outputs we are interested in and let backend clear results request = api_pb2.FunctionGetOutputsRequest( @@ -350,8 +351,9 @@ async def feed_queue(): # they accept executable code in the form of # iterators that we don't want to run inside the synchronicity thread. # Instead, we delegate to `._map()` with a safer Queue as input - async for output in self._map.aio(raw_input_queue, order_outputs, return_exceptions): # type: ignore[reportFunctionMemberAccess] - yield output + async with aclosing(self._map.aio(raw_input_queue, order_outputs, return_exceptions)) as map_output_stream: + async for output in map_output_stream: + yield output finally: feed_input_task.cancel() # should only be needed in case of exceptions diff --git a/modal/serving.py b/modal/serving.py index ce9ac97a5..e9ef60692 100644 --- a/modal/serving.py +++ b/modal/serving.py @@ -5,7 +5,6 @@ from multiprocessing.synchronize import Event from typing import TYPE_CHECKING, AsyncGenerator, Optional, Set, TypeVar -from synchronicity import Interface from synchronicity.async_wrap import asynccontextmanager from modal._output import OutputManager @@ -29,7 +28,7 @@ def _run_serve(app_ref: str, existing_app_id: str, is_ready: Event, environment_name: str, show_progress: bool): # subprocess entrypoint _app = import_app(app_ref) - blocking_app = synchronizer._translate_out(_app, Interface.BLOCKING) + blocking_app = synchronizer._translate_out(_app) with enable_output(show_progress=show_progress): serve_update(blocking_app, existing_app_id, is_ready, environment_name) diff --git a/requirements.dev.txt b/requirements.dev.txt index 0734632b2..f2efb2785 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -31,3 +31,4 @@ notebook==6.5.1 jupytext==1.14.1 pyright==1.1.351 pdm==2.12.4 # used for testing pdm cache behavior w/ automounts +console-ctrl==0.1.0 diff --git a/setup.cfg b/setup.cfg index 3b698ecac..b488ae5f4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,7 +27,7 @@ install_requires = grpclib==0.4.7 protobuf>=3.19,<5.0,!=4.24.0 rich>=12.0.0 - synchronicity~=0.8.2 + synchronicity~=0.9.1 toml typer>=0.9 types-certifi diff --git a/test/async_utils_test.py b/test/async_utils_test.py index 6b791e9c9..230d550df 100644 --- a/test/async_utils_test.py +++ b/test/async_utils_test.py @@ -5,6 +5,10 @@ import os import platform import pytest +import subprocess +import sys +import textwrap +from test import helpers import pytest_asyncio from synchronicity import Synchronizer @@ -717,3 +721,64 @@ async def foo(): async for item in callable_to_agen(foo): result.append(item) assert result == [await foo()] + + +def test_sigint_run_async_gen_shuts_down_gracefully(): + code = textwrap.dedent( + """ + import asyncio + import time + from itertools import count + from synchronicity.async_utils import Runner + from modal._utils.async_utils import run_async_gen + async def async_gen(): + print("enter") + try: + for i in count(): + yield i + await asyncio.sleep(0.1) + finally: + # this could be either CancelledError or GeneratorExit depending on timing + # CancelledError happens if sigint is during this generator's await + # GeneratorExit is during the yielded block in the sync caller + print("cancel") + await asyncio.sleep(0.1) + print("bye") + try: + with Runner() as runner: + for res in run_async_gen(runner, async_gen()): + print("res", res) + except KeyboardInterrupt: + print("KeyboardInterrupt") + """ + ) + + p = helpers.PopenWithCtrlC( + [sys.executable, "-u", "-c", code], + encoding="utf8", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + def line(): + s = p.stdout.readline().rstrip("\n") + if s == "": + print(p.stderr.read()) + raise Exception("no stdout") + print(s) + return s + + assert line() == "enter" + assert line() == "res 0" + assert line() == "res 1" + + p.send_ctrl_c() + print("sent ctrl-C") + while (nextline := line()).startswith("res"): + pass + assert nextline == "cancel" + assert line() == "bye" + assert line() == "KeyboardInterrupt" + assert p.wait() == 0 + assert p.stdout.read() == "" + assert p.stderr.read() == "" diff --git a/test/cli_test.py b/test/cli_test.py index 6a7065f4f..f8c4d011a 100644 --- a/test/cli_test.py +++ b/test/cli_test.py @@ -6,7 +6,6 @@ import platform import pytest import re -import signal import subprocess import sys import tempfile @@ -25,6 +24,7 @@ from modal.exception import InvalidError from modal_proto import api_pb2 +from . import helpers from .supports.skip import skip_windows dummy_app_file = """ @@ -994,15 +994,14 @@ def test_call_update_environment_suffix(servicer, set_env_client): _run(["environment", "update", "main", "--set-web-suffix", "_"]) -def _run_subprocess(cli_cmd: List[str]) -> subprocess.Popen: - p = subprocess.Popen( +def _run_subprocess(cli_cmd: List[str]) -> helpers.PopenWithCtrlC: + p = helpers.PopenWithCtrlC( [sys.executable, "-m", "modal"] + cli_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf8" ) return p @pytest.mark.timeout(10) -@skip_windows("no sigint on windows") def test_keyboard_interrupt_during_app_load(servicer, server_url_env, token_env, supports_dir): ctx: InterceptionContext creating_function = threading.Event() @@ -1016,15 +1015,14 @@ async def stalling_function_create(servicer, req): p = _run_subprocess(["run", f"{supports_dir / 'hello.py'}::hello"]) creating_function.wait() - p.send_signal(signal.SIGINT) - out, err = p.communicate(timeout=1) + p.send_ctrl_c() + out, err = p.communicate(timeout=5) print(out) assert "Traceback" not in err assert "Aborting app initialization..." in out @pytest.mark.timeout(10) -@skip_windows("no sigint on windows") def test_keyboard_interrupt_during_app_run(servicer, server_url_env, token_env, supports_dir): ctx: InterceptionContext waiting_for_output = threading.Event() @@ -1038,14 +1036,13 @@ async def stalling_function_get_output(servicer, req): p = _run_subprocess(["run", f"{supports_dir / 'hello.py'}::hello"]) waiting_for_output.wait() - p.send_signal(signal.SIGINT) - out, err = p.communicate(timeout=1) + p.send_ctrl_c() + out, err = p.communicate(timeout=5) assert "App aborted. View run at https://modaltest.com/apps/ap-123" in out assert "Traceback" not in err @pytest.mark.timeout(10) -@skip_windows("no sigint on windows") def test_keyboard_interrupt_during_app_run_detach(servicer, server_url_env, token_env, supports_dir): ctx: InterceptionContext waiting_for_output = threading.Event() @@ -1059,8 +1056,8 @@ async def stalling_function_get_output(servicer, req): p = _run_subprocess(["run", "--detach", f"{supports_dir / 'hello.py'}::hello"]) waiting_for_output.wait() - p.send_signal(signal.SIGINT) - out, err = p.communicate(timeout=1) + p.send_ctrl_c() + out, err = p.communicate(timeout=5) print(out) assert "Shutting down Modal client." in out assert "The detached app keeps running. You can track its progress at:" in out diff --git a/test/helpers.py b/test/helpers.py index 00c99a18d..8eedd86fc 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,6 +1,7 @@ # Copyright Modal Labs 2023 import os import pathlib +import signal import subprocess import sys from typing import Optional, Tuple @@ -54,3 +55,25 @@ def deploy_app_externally( print(f"Deploying app failed!\n### stdout ###\n{stdout_s}\n### stderr ###\n{stderr_s}") raise Exception("Test helper failed to deploy app") return stdout_s + + +class PopenWithCtrlC(subprocess.Popen): + def __init__(self, *args, creationflags=0, **kwargs): + if sys.platform == "win32": + # needed on windows to separate ctrl-c lifecycle of subprocess from parent: + creationflags = creationflags | subprocess.CREATE_NEW_CONSOLE # type: ignore + + super().__init__(*args, **kwargs, creationflags=creationflags) + + def send_ctrl_c(self): + # platform independent way to replicate the behavior of Ctrl-C:ing a cli app + if sys.platform == "win32": + # windows doesn't support sigint, and subprocess.CTRL_C_EVENT has a bunch + # of gotchas since it's bound to a console which is the same for the parent + # process by default, and can't be sent using the python standard library + # to a separate process's console + import console_ctrl + + console_ctrl.send_ctrl_c(self.pid) # noqa [E731] + else: + self.send_signal(signal.SIGINT) diff --git a/test/supports/skip.py b/test/supports/skip.py index b1922f599..b05c0a84f 100644 --- a/test/supports/skip.py +++ b/test/supports/skip.py @@ -14,7 +14,6 @@ def skip_macos(msg: str): skip_windows_unix_socket = skip_windows("Windows doesn't have UNIX sockets") -skip_windows_signals = skip_windows("Windows doesn't support UNIX signal handling") def skip_old_py(msg: str, min_version: tuple):