Skip to content

Commit

Permalink
Fix Ctrl-C of modal run etc. on Windows
Browse files Browse the repository at this point in the history
* Uses updated synchronicity version that properly supports windows Ctrl-C interruption (and has some other fixes)
* Enable Ctrl-C tests for windows
  • Loading branch information
freider authored Oct 31, 2024
1 parent b3be514 commit 9d1fc38
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 154 deletions.
3 changes: 1 addition & 2 deletions modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple

from google.protobuf.message import Message
from synchronicity import Interface

from modal_proto import api_pb2

Expand Down Expand Up @@ -793,7 +792,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
# Initialize objects on the app.
# This is basically only functions and classes - anything else is deprecated and will be unsupported soon
if active_app is not None:
app: App = synchronizer._translate_out(active_app, Interface.BLOCKING)
app: App = synchronizer._translate_out(active_app)
app._init_container(client, container_app)

# Hydrate all function dependencies.
Expand Down
4 changes: 1 addition & 3 deletions modal/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from dataclasses import dataclass
from typing import Any

from synchronicity.synchronizer import Interface

from modal._utils.async_utils import synchronizer
from modal_proto import api_pb2

Expand Down Expand Up @@ -69,7 +67,7 @@ def persistent_load(self, pid):
impl_class, attributes = obj_data
impl_instance = impl_class.__new__(impl_class)
impl_instance.__dict__.update(attributes)
return synchronizer._translate_out(impl_instance, interface=Interface.BLOCKING)
return synchronizer._translate_out(impl_instance)
else:
raise ExecutionError("Unknown serialization format")

Expand Down
144 changes: 26 additions & 118 deletions modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
import concurrent.futures
import functools
import inspect
import signal
import sys
import threading
import time
import typing
from contextlib import asynccontextmanager
Expand All @@ -28,6 +25,8 @@
)

import synchronicity
from synchronicity.async_utils import Runner
from synchronicity.exceptions import NestedEventLoops
from typing_extensions import ParamSpec

from ..exception import InvalidError
Expand Down Expand Up @@ -388,9 +387,9 @@ def __aiter__(self):
def __iter__(self):
try:
with Runner() as runner:
for output in runner.run_async_gen(self._async_iterable):
for output in run_async_gen(runner, self._async_iterable):
yield output # type: ignore
except NestedAsyncCalls:
except NestedEventLoops:
raise InvalidError(self.nested_async_message)


Expand Down Expand Up @@ -454,121 +453,30 @@ async def asyncnullcontext(*args, **kwargs):
SEND_TYPE = typing.TypeVar("SEND_TYPE")


class NestedAsyncCalls(Exception):
pass


class Runner:
"""Simplified backport of asyncio.Runner from Python 3.11
Like asyncio.run() but allows multiple calls to the same event loop
before teardown.
Difference from running new_event_loop().run_until_complete is that
this catches SIGINTs and propagates it as task cancellations rather
than raising KeyboardInterrupt inside of the event loop code.
"""

# TODO: unify this with modal._container_entrypoint.UserCodeEventLoop
# which does very similar things but has some additional SIGUSR1
# logic

def __enter__(self) -> "Runner":
try:
asyncio.get_running_loop()
except RuntimeError:
pass # no event loop - this is what we expect!
else:
raise NestedAsyncCalls()

self._loop = asyncio.new_event_loop()
return self

def __exit__(self, exc_type, exc_value, traceback):
self._loop.run_until_complete(self._loop.shutdown_asyncgens())
if sys.version_info[:2] >= (3, 9):
# Introduced in Python 3.9
self._loop.run_until_complete(self._loop.shutdown_default_executor())

self._loop.close()
return False

def run(self, coro: typing.Awaitable[T]) -> T:
is_main_thread = threading.current_thread() == threading.main_thread()
self._num_sigints = 0

coro_task = asyncio.ensure_future(coro, loop=self._loop)

async def wrapper_coro():
# this wrapper is needed since run_coroutine_threadsafe *only* accepts coroutines
return await coro_task

def _sigint_handler(signum, frame):
# cancel the task in order to have run_until_complete return soon and
# prevent a bunch of unwanted tracebacks when shutting down the
# event loop.

# this basically replicates the sigint handler installed by asyncio.run()
self._num_sigints += 1
if self._num_sigints == 1:
# first sigint is graceful
self._loop.call_soon_threadsafe(coro_task.cancel)
return

# this should normally not happen, but the second sigint would "hard kill" the event loop
# by raising KeyboardInterrupt inside of it
raise KeyboardInterrupt()

original_sigint_handler = None
def run_async_gen(
runner: Runner,
gen: typing.AsyncGenerator[YIELD_TYPE, SEND_TYPE],
) -> typing.Generator[YIELD_TYPE, SEND_TYPE, None]:
"""Convert an async generator into a sync one"""
# more or less copied from synchronicity's implementation:
next_send: typing.Union[SEND_TYPE, None] = None
next_yield: YIELD_TYPE
exc: Optional[BaseException] = None
while True:
try:
# only install signal handler if running from main thread and we haven't disabled sigint
handle_sigint = is_main_thread and signal.getsignal(signal.SIGINT) != signal.SIG_IGN

if handle_sigint:
# intentionally not using _loop.add_signal_handler since it's slow (?)
# and not available on Windows. We just don't want the sigint to
# mess with the event loop anyways
original_sigint_handler = signal.signal(signal.SIGINT, _sigint_handler)
except KeyboardInterrupt:
# this is quite unlikely, but with bad timing we could get interrupted before
# installing the sigint handler and this has happened repeatedly in unit tests
_sigint_handler(signal.SIGINT, None)

if exc:
next_yield = runner.run(gen.athrow(exc))
else:
next_yield = runner.run(gen.asend(next_send)) # type: ignore[arg-type]
except KeyboardInterrupt as e:
raise e from None
except StopAsyncIteration:
break # typically a graceful exit of the async generator
try:
return self._loop.run_until_complete(wrapper_coro())
except asyncio.CancelledError:
if self._num_sigints > 0:
raise KeyboardInterrupt() # might want to use original_sigint_handler here instead?
raise # "internal" cancellations, not triggered by KeyboardInterrupt
finally:
if original_sigint_handler:
# reset signal handler
signal.signal(signal.SIGINT, original_sigint_handler)

def run_async_gen(
self,
gen: typing.AsyncGenerator[YIELD_TYPE, SEND_TYPE],
) -> typing.Generator[YIELD_TYPE, SEND_TYPE, None]:
"""Convert an async generator into a sync one"""
# more or less copied from synchronicity's implementation:
next_send: typing.Union[SEND_TYPE, None] = None
next_yield: YIELD_TYPE
exc: Optional[BaseException] = None
while True:
try:
if exc:
next_yield = self.run(gen.athrow(exc))
else:
next_yield = self.run(gen.asend(next_send)) # type: ignore[arg-type]
except KeyboardInterrupt as e:
raise e from None
except StopAsyncIteration:
break # typically a graceful exit of the async generator
try:
next_send = yield next_yield
exc = None
except BaseException as err:
exc = err
next_send = yield next_yield
exc = None
except BaseException as err:
exc = err


@asynccontextmanager
Expand Down
3 changes: 1 addition & 2 deletions modal/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from multiprocessing.synchronize import Event
from typing import TYPE_CHECKING, AsyncGenerator, Optional, Set, TypeVar

from synchronicity import Interface
from synchronicity.async_wrap import asynccontextmanager

from modal._output import OutputManager
Expand All @@ -29,7 +28,7 @@
def _run_serve(app_ref: str, existing_app_id: str, is_ready: Event, environment_name: str, show_progress: bool):
# subprocess entrypoint
_app = import_app(app_ref)
blocking_app = synchronizer._translate_out(_app, Interface.BLOCKING)
blocking_app = synchronizer._translate_out(_app)

with enable_output(show_progress=show_progress):
serve_update(blocking_app, existing_app_id, is_ready, environment_name)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ install_requires =
grpclib==0.4.7
protobuf>=3.19,<5.0,!=4.24.0
rich>=12.0.0
synchronicity~=0.8.3
synchronicity~=0.9.1
toml
typer>=0.9
types-certifi
Expand Down
24 changes: 9 additions & 15 deletions test/async_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import os
import platform
import pytest
import signal
import subprocess
import sys
import textwrap
from test import helpers

import pytest_asyncio
from synchronicity import Synchronizer
Expand Down Expand Up @@ -636,7 +636,8 @@ def test_sigint_run_async_gen_shuts_down_gracefully():
import asyncio
import time
from itertools import count
from modal._utils.async_utils import Runner, aclosing
from synchronicity.async_utils import Runner
from modal._utils.async_utils import run_async_gen
async def async_gen():
print("enter")
try:
Expand All @@ -652,40 +653,33 @@ async def async_gen():
print("bye")
try:
with Runner() as runner:
for res in runner.run_async_gen(async_gen()):
for res in run_async_gen(runner, async_gen()):
print("res", res)
except KeyboardInterrupt:
print("KeyboardInterrupt")
"""
)
if sys.platform == "win32":
# workaround to be able to _test_ Ctrl-C response on windows
import console_ctrl

creationflags = subprocess.CREATE_NEW_CONSOLE # type: ignore
platform_sigint = lambda p: console_ctrl.send_ctrl_c(p.pid) # noqa [E731]
else:
creationflags = 0
platform_sigint = lambda p: p.send_signal(signal.SIGINT) # noqa [E731]

p = subprocess.Popen(
p = helpers.PopenWithCtrlC(
[sys.executable, "-u", "-c", code],
encoding="utf8",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
creationflags=creationflags,
)

def line():
s = p.stdout.readline().rstrip("\n")
if s == "":
print(p.stderr.read())
raise Exception("no stdout")
print(s)
return s

assert line() == "enter"
assert line() == "res 0"
assert line() == "res 1"

platform_sigint(p)
p.send_ctrl_c()
print("sent ctrl-C")
while (nextline := line()).startswith("res"):
pass
Expand Down
21 changes: 9 additions & 12 deletions test/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import platform
import pytest
import re
import signal
import subprocess
import sys
import tempfile
Expand All @@ -25,6 +24,7 @@
from modal.exception import InvalidError
from modal_proto import api_pb2

from . import helpers
from .supports.skip import skip_windows

dummy_app_file = """
Expand Down Expand Up @@ -981,15 +981,14 @@ def test_call_update_environment_suffix(servicer, set_env_client):
_run(["environment", "update", "main", "--set-web-suffix", "_"])


def _run_subprocess(cli_cmd: List[str]) -> subprocess.Popen:
p = subprocess.Popen(
def _run_subprocess(cli_cmd: List[str]) -> helpers.PopenWithCtrlC:
p = helpers.PopenWithCtrlC(
[sys.executable, "-m", "modal"] + cli_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf8"
)
return p


@pytest.mark.timeout(10)
@skip_windows("no sigint on windows")
def test_keyboard_interrupt_during_app_load(servicer, server_url_env, token_env, supports_dir):
ctx: InterceptionContext
creating_function = threading.Event()
Expand All @@ -1003,15 +1002,14 @@ async def stalling_function_create(servicer, req):

p = _run_subprocess(["run", f"{supports_dir / 'hello.py'}::hello"])
creating_function.wait()
p.send_signal(signal.SIGINT)
out, err = p.communicate(timeout=1)
p.send_ctrl_c()
out, err = p.communicate(timeout=5)
print(out)
assert "Traceback" not in err
assert "Aborting app initialization..." in out


@pytest.mark.timeout(10)
@skip_windows("no sigint on windows")
def test_keyboard_interrupt_during_app_run(servicer, server_url_env, token_env, supports_dir):
ctx: InterceptionContext
waiting_for_output = threading.Event()
Expand All @@ -1025,14 +1023,13 @@ async def stalling_function_get_output(servicer, req):

p = _run_subprocess(["run", f"{supports_dir / 'hello.py'}::hello"])
waiting_for_output.wait()
p.send_signal(signal.SIGINT)
out, err = p.communicate(timeout=1)
p.send_ctrl_c()
out, err = p.communicate(timeout=5)
assert "App aborted. View run at https://modaltest.com/apps/ap-123" in out
assert "Traceback" not in err


@pytest.mark.timeout(10)
@skip_windows("no sigint on windows")
def test_keyboard_interrupt_during_app_run_detach(servicer, server_url_env, token_env, supports_dir):
ctx: InterceptionContext
waiting_for_output = threading.Event()
Expand All @@ -1046,8 +1043,8 @@ async def stalling_function_get_output(servicer, req):

p = _run_subprocess(["run", "--detach", f"{supports_dir / 'hello.py'}::hello"])
waiting_for_output.wait()
p.send_signal(signal.SIGINT)
out, err = p.communicate(timeout=1)
p.send_ctrl_c()
out, err = p.communicate(timeout=5)
print(out)
assert "Shutting down Modal client." in out
assert "The detached app keeps running. You can track its progress at:" in out
Expand Down
Loading

0 comments on commit 9d1fc38

Please sign in to comment.