diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py
index 8a45e0591..538122a58 100644
--- a/modal/_container_entrypoint.py
+++ b/modal/_container_entrypoint.py
@@ -3,6 +3,7 @@
 import base64
 import importlib
 import inspect
+import queue
 import signal
 import sys
 import threading
@@ -60,6 +61,51 @@ class ImportedFunction:
     function: _Function


+class DaemonizedThreadPool:
+    # Used instead of ThreadPoolExecutor, since the latter won't allow
+    # the interpreter to shut down before the currently running tasks
+    # have finished
+    def __init__(self, max_threads):
+        self.max_threads = max_threads
+
+    def __enter__(self):
+        self.spawned_workers = 0
+        self.inputs: queue.Queue[Any] = queue.Queue()
+        self.finished = threading.Event()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.finished.set()
+
+        if exc_type is None:
+            self.inputs.join()
+        else:
+            # special case - allows us to exit the pool without joining unfinished tasks
+            if self.inputs.unfinished_tasks:
+                logger.info(
+                    f"Exiting DaemonizedThreadPool with {self.inputs.unfinished_tasks} active inputs due to exception: {repr(exc_type)}"
+                )
+
+    def submit(self, func, *args):
+        def worker_thread():
+            while not self.finished.is_set():
+                try:
+                    _func, _args = self.inputs.get(timeout=1)
+                except queue.Empty:
+                    continue
+                try:
+                    _func(*_args)
+                except BaseException:
+                    logger.exception(f"Exception raised by {_func} in DaemonizedThreadPool worker!")
+                self.inputs.task_done()
+
+        if self.spawned_workers < self.max_threads:
+            threading.Thread(target=worker_thread, daemon=True).start()
+            self.spawned_workers += 1
+
+        self.inputs.put((func, args))
+
+
 class UserCodeEventLoop:
     """Run an async event loop as a context manager and handle signals.

@@ -99,6 +145,7 @@ def _sigint_handler():
                 # first sigint is graceful
                 task.cancel()
                 return
+            raise KeyboardInterrupt()  # this should normally not happen, but the second sigint would "hard kill" the event loop!

         ignore_sigint = signal.getsignal(signal.SIGINT) == signal.SIG_IGN
@@ -122,102 +169,12 @@ def _sigint_handler():
         self.loop.remove_signal_handler(signal.SIGINT)


-def call_function_sync(
+def call_function(
+    user_code_event_loop: UserCodeEventLoop,
     container_io_manager,  #: ContainerIOManager, TODO: this type is generated at runtime
     imp_fun: ImportedFunction,
 ):
-    def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
-        started_at = time.time()
-        reset_context = _set_current_context_ids(input_id, function_call_id)
-        with container_io_manager.handle_input_exception(input_id, started_at):
-            logger.debug(f"Starting input {input_id} (sync)")
-            res = imp_fun.fun(*args, **kwargs)
-            logger.debug(f"Finished input {input_id} (sync)")
-
-            # TODO(erikbern): any exception below shouldn't be considered a user exception
-            if imp_fun.is_generator:
-                if not inspect.isgenerator(res):
-                    raise InvalidError(f"Generator function returned value of type {type(res)}")
-
-                # Send up to this many outputs at a time.
-                generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024)
-                generator_output_task = container_io_manager.generator_output_task(
-                    function_call_id,
-                    imp_fun.data_format,
-                    generator_queue,
-                    _future=True,  # Synchronicity magic to return a future.
-                )
-
-                item_count = 0
-                for value in res:
-                    container_io_manager._queue_put(generator_queue, value)
-                    item_count += 1
-
-                container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
-                generator_output_task.result()  # Wait to finish sending generator outputs.
-                message = api_pb2.GeneratorDone(items_total=item_count)
-                container_io_manager.push_output(input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
-            else:
-                if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
-                    raise InvalidError(
-                        f"Sync (non-generator) function return value of type {type(res)}."
-                        " You might need to use @app.function(..., is_generator=True)."
-                    )
-                container_io_manager.push_output(input_id, started_at, res, imp_fun.data_format)
-        reset_context()
-
-    if imp_fun.input_concurrency > 1:
-        # We can't use `concurrent.futures.ThreadPoolExecutor` here because in Python 3.11+, this
-        # class has no workaround that allows us to exit the Python interpreter process without
-        # waiting for the worker threads to finish. We need this behavior on SIGINT.
-
-        import queue
-        import threading
-
-        spawned_workers = 0
-        inputs: queue.Queue[Any] = queue.Queue()
-        finished = threading.Event()
-
-        def worker_thread():
-            while not finished.is_set():
-                try:
-                    args = inputs.get(timeout=1)
-                except queue.Empty:
-                    continue
-                try:
-                    run_input(*args)
-                except BaseException:
-                    # This should basically never happen, since only KeyboardInterrupt is the only error that can
-                    # bubble out of from handle_input_exception and those wouldn't be raised outside the main thread
-                    pass
-                inputs.task_done()
-
-        for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs(
-            imp_fun.input_concurrency
-        ):
-            if spawned_workers < imp_fun.input_concurrency:
-                threading.Thread(target=worker_thread, daemon=True).start()
-                spawned_workers += 1
-            inputs.put((input_id, function_call_id, args, kwargs))
-
-        finished.set()
-        inputs.join()
-
-    else:
-        for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs(
-            imp_fun.input_concurrency
-        ):
-            try:
-                run_input(input_id, function_call_id, args, kwargs)
-            except:
-                raise
-
-
-async def call_function_async(
-    container_io_manager,  #: ContainerIOManager, TODO: this type is generated at runtime
-    imp_fun: ImportedFunction,
-):
-    async def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
+    async def run_input_async(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
         started_at = time.time()
         reset_context = _set_current_context_ids(input_id, function_call_id)
         async with container_io_manager.handle_input_exception.aio(input_id, started_at):
@@ -261,24 +218,87 @@ async def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any
             await container_io_manager.push_output.aio(input_id, started_at, value, imp_fun.data_format)
         reset_context()

+    def run_input_sync(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
+        started_at = time.time()
+        reset_context = _set_current_context_ids(input_id, function_call_id)
+        with container_io_manager.handle_input_exception(input_id, started_at):
+            logger.debug(f"Starting input {input_id} (sync)")
+            res = imp_fun.fun(*args, **kwargs)
+            logger.debug(f"Finished input {input_id} (sync)")
+
+            # TODO(erikbern): any exception below shouldn't be considered a user exception
+            if imp_fun.is_generator:
+                if not inspect.isgenerator(res):
+                    raise InvalidError(f"Generator function returned value of type {type(res)}")
+
+                # Send up to this many outputs at a time.
+                generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024)
+                generator_output_task = container_io_manager.generator_output_task(
+                    function_call_id,
+                    imp_fun.data_format,
+                    generator_queue,
+                    _future=True,  # Synchronicity magic to return a future.
+                )
+
+                item_count = 0
+                for value in res:
+                    container_io_manager._queue_put(generator_queue, value)
+                    item_count += 1
+
+                container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
+                generator_output_task.result()  # Wait to finish sending generator outputs.
+                message = api_pb2.GeneratorDone(items_total=item_count)
+                container_io_manager.push_output(input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
+            else:
+                if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
+                    raise InvalidError(
+                        f"Sync (non-generator) function return value of type {type(res)}."
+                        " You might need to use @app.function(..., is_generator=True)."
+                    )
+                container_io_manager.push_output(input_id, started_at, res, imp_fun.data_format)
+        reset_context()
+
     if imp_fun.input_concurrency > 1:
-        # all run_input coroutines will have completed by the time we leave the execution context
-        # but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
-        # for them to resolve gracefully:
-        async with TaskContext(0.01) as task_context:
-            async for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs.aio(
-                imp_fun.input_concurrency
-            ):
-                # Note that run_inputs_outputs will not return until the concurrency semaphore has
-                # released all its slots so that they can be acquired by the run_inputs_outputs finalizer
-                # This prevents leaving the task_context before outputs have been created
-                # TODO: refactor to make this a bit more easy to follow?
-                task_context.create_task(run_input(input_id, function_call_id, args, kwargs))
+        with DaemonizedThreadPool(max_threads=imp_fun.input_concurrency) as thread_pool:
+
+            async def run_concurrent_inputs():
+                # all run_input coroutines will have completed by the time we leave the execution context,
+                # but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
+                # grace period for them to resolve:
+                async with TaskContext(0.01) as task_context:
+                    async for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs.aio(
+                        imp_fun.input_concurrency
+                    ):
+                        # Note that run_inputs_outputs will not return until the concurrency semaphore has
+                        # released all its slots so that they can be acquired by the run_inputs_outputs finalizer.
+                        # This prevents leaving the task_context before outputs have been created.
+                        # TODO: refactor to make this a bit easier to follow?
+                        if imp_fun.is_async:
+                            task_context.create_task(run_input_async(input_id, function_call_id, args, kwargs))
+                        else:
+                            # run sync input in thread
+                            thread_pool.submit(run_input_sync, input_id, function_call_id, args, kwargs)
+
+            user_code_event_loop.run(run_concurrent_inputs())
     else:
-        async for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs.aio(
+        for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs(
             imp_fun.input_concurrency
         ):
-            await run_input(input_id, function_call_id, args, kwargs)
+            if imp_fun.is_async:
+                user_code_event_loop.run(run_input_async(input_id, function_call_id, args, kwargs))
+            else:
+                # Set up a custom signal handler for `SIGUSR1`, which gets translated to an InputCancellation
+                # during function execution. This is sent to cancel inputs from the user.
+                def _cancel_input_signal_handler(signum, stackframe):
+                    raise InputCancellation("Input was cancelled by user")
+
+                usr1_handler = signal.signal(signal.SIGUSR1, _cancel_input_signal_handler)
+                # run this sync code in the main thread, blocking the "userland" event loop;
+                # this lets us cancel it using a signal handler that raises an exception
+                try:
+                    run_input_sync(input_id, function_call_id, args, kwargs)
+                finally:
+                    signal.signal(signal.SIGUSR1, usr1_handler)  # restore the previous signal handler


 def import_function(
@@ -551,17 +571,7 @@ def breakpoint_wrapper():

     # Execute the function.
     try:
-        if imp_fun.is_async:
-            event_loop.run(call_function_async(container_io_manager, imp_fun))
-        else:
-            # Set up a signal handler for `SIGUSR1`, which gets translated to an InputCancellation
-            # during function execution. This is sent to cancel inputs from the user.
-            def _cancel_input_signal_handler(signum, stackframe):
-                raise InputCancellation("Input was cancelled by user")
-
-            signal.signal(signal.SIGUSR1, _cancel_input_signal_handler)
-
-            call_function_sync(container_io_manager, imp_fun)
+        call_function(event_loop, container_io_manager, imp_fun)
     finally:
         # Run exit handlers. From this point onward, ignore all SIGINT signals that come from
         # graceful shutdowns originating on the worker, as well as stray SIGUSR1 signals that
diff --git a/test/container_test.py b/test/container_test.py
index 2d3482134..d7a6967ef 100644
--- a/test/container_test.py
+++ b/test/container_test.py
@@ -835,6 +835,7 @@ def _unwrap_concurrent_input_outputs(n_inputs: int, n_parallel: int, ret: Contai


 @skip_github_non_linux
+@pytest.mark.timeout(5)
 def test_concurrent_inputs_sync_function(unix_servicer):
     n_inputs = 18
     n_parallel = 6
@@ -1197,10 +1198,13 @@ def test_cancellation_aborts_current_input_on_match(
     [("delay",), ("delay_async",)],
 )
 def test_cancellation_stops_task_with_concurrent_inputs(servicer, function_name):
-    # send three inputs in container: in-100, in-101, in-102
     with servicer.input_lockstep() as input_lock:
         container_process = _run_container_process(
-            servicer, "test.supports.functions", function_name, inputs=[((20,), {})], allow_concurrent_inputs=2
+            servicer,
+            "test.supports.functions",
+            function_name,
+            inputs=[((20,), {})],
+            allow_concurrent_inputs=2,
         )
         input_lock.wait()
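
Reviewer note on DaemonizedThreadPool: the pool works because its workers are daemon threads pulling from a queue.Queue, so a hard exit (e.g. a second SIGINT) can terminate the interpreter without joining them, whereas concurrent.futures.ThreadPoolExecutor joins its workers at interpreter shutdown. Below is a minimal standalone sketch of that pattern, assuming nothing from the Modal codebase; TinyDaemonPool and every other name in it are illustrative.

    import queue
    import threading
    import time

    class TinyDaemonPool:
        # minimal daemon-thread pool: queued work is processed by up to
        # max_threads workers, none of which block interpreter exit
        def __init__(self, max_threads: int):
            self.max_threads = max_threads
            self.spawned = 0
            self.tasks: queue.Queue = queue.Queue()
            self.stopped = threading.Event()

        def submit(self, func, *args):
            # lazily spawn up to max_threads daemon workers, one per submit,
            # mirroring DaemonizedThreadPool.submit in the patch
            if self.spawned < self.max_threads:
                threading.Thread(target=self._worker, daemon=True).start()
                self.spawned += 1
            self.tasks.put((func, args))

        def _worker(self):
            while not self.stopped.is_set():
                try:
                    func, args = self.tasks.get(timeout=1)
                except queue.Empty:
                    continue  # re-check the stop flag about once per second
                try:
                    func(*args)  # the real pool also logs exceptions here
                finally:
                    self.tasks.task_done()

        def shutdown(self):
            self.tasks.join()   # graceful path: wait for all queued work,
            self.stopped.set()  # then tell the now-idle workers to exit

    if __name__ == "__main__":
        pool = TinyDaemonPool(max_threads=2)
        for i in range(4):
            pool.submit(lambda n: (time.sleep(0.1), print(f"done {n}")), i)
        pool.shutdown()
        # because the workers are daemon threads, the process could instead
        # exit here immediately, even with tasks still running; that is
        # exactly the behavior the patch needs on a second SIGINT

Note the ordering in shutdown(): joining the queue before setting the stop flag guarantees no queued task is stranded on the graceful path. On the exception path, DaemonizedThreadPool.__exit__ skips the join entirely, which is what lets a KeyboardInterrupt tear the process down while inputs are still in flight.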
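Reviewer note on the SIGUSR1 cancellation path: CPython runs Python-level signal handlers only on the main thread, which is why the non-concurrent branch executes run_input_sync on the main thread (where a handler can raise into the blocking call) and why the concurrent branch cannot cancel pool workers the same way. A minimal sketch of the install/raise/restore dance on POSIX, using a stand-in exception class rather than modal's real InputCancellation:

    import os
    import signal
    import threading
    import time

    class InputCancellation(Exception):  # stand-in for modal's InputCancellation
        pass

    def _cancel_input_signal_handler(signum, stackframe):
        raise InputCancellation("Input was cancelled by user")

    usr1_handler = signal.signal(signal.SIGUSR1, _cancel_input_signal_handler)
    # deliver SIGUSR1 in 0.5s, standing in for the worker cancelling an input
    threading.Timer(0.5, os.kill, (os.getpid(), signal.SIGUSR1)).start()
    try:
        time.sleep(60)  # stands in for the blocking run_input_sync(...) call
    except InputCancellation:
        print("input cancelled")  # the sleep was interrupted after ~0.5s
    finally:
        signal.signal(signal.SIGUSR1, usr1_handler)  # restore previous handler

Because the handler raises, a cancelled input unwinds through handle_input_exception like any other exception, and restoring the previous handler in the finally block gives the handler the per-input scope this patch introduces.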