
Commit 12a72a8

Refactor container entrypoint to use same handler for both sync and async functions (#1804)

* Refactor container entrypoint to use same handler for both sync and async functions

This allows us to mix async and sync functions on the same class,
even though all methods are executed by the same container entrypoint
input loop.
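
In outline, the unified handler keeps one input loop and routes each input by kind: coroutine functions become event-loop tasks, while plain functions run on worker threads. A minimal sketch of that dispatch, using standard-library stand-ins (ThreadPoolExecutor rather than this commit's DaemonizedThreadPool, and hypothetical fn/args names):

import asyncio
import inspect
from concurrent.futures import ThreadPoolExecutor

async def dispatch(inputs):
    # inputs: iterable of (fn, args) pairs, mixing async and sync callables
    loop = asyncio.get_running_loop()
    pending = []
    with ThreadPoolExecutor(max_workers=4) as pool:
        for fn, args in inputs:
            if inspect.iscoroutinefunction(fn):
                pending.append(asyncio.create_task(fn(*args)))  # async: event-loop task
            else:
                pending.append(loop.run_in_executor(pool, fn, *args))  # sync: worker thread
        await asyncio.gather(*pending)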
freider authored May 10, 2024
1 parent f15e521 commit 12a72a8
Showing 2 changed files with 134 additions and 120 deletions.
246 changes: 128 additions & 118 deletions modal/_container_entrypoint.py
@@ -3,6 +3,7 @@
 import base64
 import importlib
 import inspect
+import queue
 import signal
 import sys
 import threading
@@ -60,6 +61,51 @@ class ImportedFunction:
     function: _Function


+class DaemonizedThreadPool:
+    # Used instead of ThreadPoolExecutor, since the latter won't allow
+    # the interpreter to shut down before the currently running tasks
+    # have finished
+    def __init__(self, max_threads):
+        self.max_threads = max_threads
+
+    def __enter__(self):
+        self.spawned_workers = 0
+        self.inputs: queue.Queue[Any] = queue.Queue()
+        self.finished = threading.Event()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.finished.set()
+
+        if exc_type is None:
+            self.inputs.join()
+        else:
+            # special case - allows us to exit the context without waiting for unfinished inputs
+            if self.inputs.unfinished_tasks:
+                logger.info(
+                    f"Exiting DaemonizedThreadPool with {self.inputs.unfinished_tasks} active inputs due to exception: {repr(exc_type)}"
+                )
+
+    def submit(self, func, *args):
+        def worker_thread():
+            while not self.finished.is_set():
+                try:
+                    _func, _args = self.inputs.get(timeout=1)
+                except queue.Empty:
+                    continue
+                try:
+                    _func(*_args)
+                except BaseException:
+                    logger.exception(f"Exception raised by {_func} in DaemonizedThreadPool worker!")
+                self.inputs.task_done()
+
+        if self.spawned_workers < self.max_threads:
+            threading.Thread(target=worker_thread, daemon=True).start()
+            self.spawned_workers += 1
+
+        self.inputs.put((func, args))
+
+
 class UserCodeEventLoop:
     """Run an async event loop as a context manager and handle signals.
@@ -99,6 +145,7 @@ def _sigint_handler():
                 # first sigint is graceful
                 task.cancel()
                 return
+
             raise KeyboardInterrupt()  # this should normally not happen, but the second sigint would "hard kill" the event loop!

         ignore_sigint = signal.getsignal(signal.SIGINT) == signal.SIG_IGN
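
The two-stage SIGINT handling shown above (cancel the task on the first Ctrl-C, hard-kill the loop on the second) can be sketched with plain asyncio; the function and counter names below are illustrative, not Modal's:

import asyncio
import signal

def run_with_graceful_sigint(coro):
    loop = asyncio.new_event_loop()
    task = loop.create_task(coro)
    sigints = 0

    def _sigint_handler():
        nonlocal sigints
        sigints += 1
        if sigints == 1:
            task.cancel()  # first Ctrl-C: give the task a chance to clean up
            return
        raise KeyboardInterrupt()  # second Ctrl-C: propagates out of run_until_complete

    loop.add_signal_handler(signal.SIGINT, _sigint_handler)
    try:
        # raises CancelledError if coro does not handle the cancellation itself
        return loop.run_until_complete(task)
    finally:
        loop.remove_signal_handler(signal.SIGINT)
        loop.close()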
@@ -122,102 +169,12 @@ def _sigint_handler():
         self.loop.remove_signal_handler(signal.SIGINT)


-def call_function_sync(
+def call_function(
+    user_code_event_loop: UserCodeEventLoop,
     container_io_manager, #: ContainerIOManager, TODO: this type is generated at runtime
     imp_fun: ImportedFunction,
 ):
-    def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
-        started_at = time.time()
-        reset_context = _set_current_context_ids(input_id, function_call_id)
-        with container_io_manager.handle_input_exception(input_id, started_at):
-            logger.debug(f"Starting input {input_id} (sync)")
-            res = imp_fun.fun(*args, **kwargs)
-            logger.debug(f"Finished input {input_id} (sync)")
-
-            # TODO(erikbern): any exception below shouldn't be considered a user exception
-            if imp_fun.is_generator:
-                if not inspect.isgenerator(res):
-                    raise InvalidError(f"Generator function returned value of type {type(res)}")
-
-                # Send up to this many outputs at a time.
-                generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024)
-                generator_output_task = container_io_manager.generator_output_task(
-                    function_call_id,
-                    imp_fun.data_format,
-                    generator_queue,
-                    _future=True, # Synchronicity magic to return a future.
-                )
-
-                item_count = 0
-                for value in res:
-                    container_io_manager._queue_put(generator_queue, value)
-                    item_count += 1
-
-                container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
-                generator_output_task.result() # Wait to finish sending generator outputs.
-                message = api_pb2.GeneratorDone(items_total=item_count)
-                container_io_manager.push_output(input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
-            else:
-                if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
-                    raise InvalidError(
-                        f"Sync (non-generator) function return value of type {type(res)}."
-                        " You might need to use @app.function(..., is_generator=True)."
-                    )
-                container_io_manager.push_output(input_id, started_at, res, imp_fun.data_format)
-        reset_context()
-
-    if imp_fun.input_concurrency > 1:
-        # We can't use `concurrent.futures.ThreadPoolExecutor` here because in Python 3.11+, this
-        # class has no workaround that allows us to exit the Python interpreter process without
-        # waiting for the worker threads to finish. We need this behavior on SIGINT.
-
-        import queue
-        import threading
-
-        spawned_workers = 0
-        inputs: queue.Queue[Any] = queue.Queue()
-        finished = threading.Event()
-
-        def worker_thread():
-            while not finished.is_set():
-                try:
-                    args = inputs.get(timeout=1)
-                except queue.Empty:
-                    continue
-                try:
-                    run_input(*args)
-                except BaseException:
-                    # This should basically never happen, since KeyboardInterrupt is the only error that
-                    # can bubble out of handle_input_exception, and it wouldn't be raised outside the main thread
-                    pass
-                inputs.task_done()
-
-        for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs(
-            imp_fun.input_concurrency
-        ):
-            if spawned_workers < imp_fun.input_concurrency:
-                threading.Thread(target=worker_thread, daemon=True).start()
-                spawned_workers += 1
-            inputs.put((input_id, function_call_id, args, kwargs))
-
-        finished.set()
-        inputs.join()
-
-    else:
-        for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs(
-            imp_fun.input_concurrency
-        ):
-            try:
-                run_input(input_id, function_call_id, args, kwargs)
-            except:
-                raise
-
-
-async def call_function_async(
-    container_io_manager, #: ContainerIOManager, TODO: this type is generated at runtime
-    imp_fun: ImportedFunction,
-):
-    async def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
+    async def run_input_async(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
         started_at = time.time()
         reset_context = _set_current_context_ids(input_id, function_call_id)
         async with container_io_manager.handle_input_exception.aio(input_id, started_at):
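
Both input handlers stream generator outputs through a bounded queue terminated by a stop sentinel (see the generator branch in the removed sync handler above, re-added as run_input_sync in the next hunk). A self-contained sketch of that producer/consumer pattern, with hypothetical names in place of the ContainerIOManager plumbing:

import queue
import threading

_STOP = object()  # sentinel marking the end of the stream

def produce(gen, out: queue.Queue):
    for value in gen:
        out.put(value)  # blocks when the queue is full, giving backpressure
    out.put(_STOP)

def consume(out: queue.Queue):
    count = 0
    while (item := out.get()) is not _STOP:
        count += 1  # a real consumer would forward item to the output stream
    return count  # analogous to GeneratorDone(items_total=...)

q = queue.Queue(maxsize=1024)  # bounded, like the 1024-slot queue in the diff
worker = threading.Thread(target=produce, args=((i * i for i in range(5)), q))
worker.start()
assert consume(q) == 5
worker.join()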
@@ -261,24 +218,87 @@ async def run_input(input_id: str, function_call_id: str, args: Any, kwargs: Any
                 await container_io_manager.push_output.aio(input_id, started_at, value, imp_fun.data_format)
         reset_context()

+    def run_input_sync(input_id: str, function_call_id: str, args: Any, kwargs: Any) -> None:
+        started_at = time.time()
+        reset_context = _set_current_context_ids(input_id, function_call_id)
+        with container_io_manager.handle_input_exception(input_id, started_at):
+            logger.debug(f"Starting input {input_id} (sync)")
+            res = imp_fun.fun(*args, **kwargs)
+            logger.debug(f"Finished input {input_id} (sync)")
+
+            # TODO(erikbern): any exception below shouldn't be considered a user exception
+            if imp_fun.is_generator:
+                if not inspect.isgenerator(res):
+                    raise InvalidError(f"Generator function returned value of type {type(res)}")
+
+                # Send up to this many outputs at a time.
+                generator_queue: asyncio.Queue[Any] = container_io_manager._queue_create(1024)
+                generator_output_task = container_io_manager.generator_output_task(
+                    function_call_id,
+                    imp_fun.data_format,
+                    generator_queue,
+                    _future=True, # Synchronicity magic to return a future.
+                )
+
+                item_count = 0
+                for value in res:
+                    container_io_manager._queue_put(generator_queue, value)
+                    item_count += 1
+
+                container_io_manager._queue_put(generator_queue, _ContainerIOManager._GENERATOR_STOP_SENTINEL)
+                generator_output_task.result() # Wait to finish sending generator outputs.
+                message = api_pb2.GeneratorDone(items_total=item_count)
+                container_io_manager.push_output(input_id, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
+            else:
+                if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
+                    raise InvalidError(
+                        f"Sync (non-generator) function return value of type {type(res)}."
+                        " You might need to use @app.function(..., is_generator=True)."
+                    )
+                container_io_manager.push_output(input_id, started_at, res, imp_fun.data_format)
+        reset_context()
+
     if imp_fun.input_concurrency > 1:
-        # all run_input coroutines will have completed by the time we leave the execution context
-        # but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
-        # grace period for them to resolve:
-        async with TaskContext(0.01) as task_context:
-            async for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs.aio(
-                imp_fun.input_concurrency
-            ):
-                # Note that run_inputs_outputs will not return until the concurrency semaphore has
-                # released all its slots so that they can be acquired by the run_inputs_outputs finalizer.
-                # This prevents leaving the task_context before outputs have been created
-                # TODO: refactor to make this a bit easier to follow?
-                task_context.create_task(run_input(input_id, function_call_id, args, kwargs))
+        with DaemonizedThreadPool(max_threads=imp_fun.input_concurrency) as thread_pool:
+
+            async def run_concurrent_inputs():
+                # all run_input coroutines will have completed by the time we leave the execution context
+                # but the wrapping *tasks* may not yet have been resolved, so we add a 0.01s
+                # grace period for them to resolve:
+                async with TaskContext(0.01) as task_context:
+                    async for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs.aio(
+                        imp_fun.input_concurrency
+                    ):
+                        # Note that run_inputs_outputs will not return until the concurrency semaphore has
+                        # released all its slots so that they can be acquired by the run_inputs_outputs finalizer.
+                        # This prevents leaving the task_context before outputs have been created
+                        # TODO: refactor to make this a bit easier to follow?
+                        if imp_fun.is_async:
+                            task_context.create_task(run_input_async(input_id, function_call_id, args, kwargs))
+                        else:
+                            # run sync input in thread
+                            thread_pool.submit(run_input_sync, input_id, function_call_id, args, kwargs)
+
+            user_code_event_loop.run(run_concurrent_inputs())
     else:
-        async for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs.aio(
+        for input_id, function_call_id, args, kwargs in container_io_manager.run_inputs_outputs(
             imp_fun.input_concurrency
         ):
-            await run_input(input_id, function_call_id, args, kwargs)
+            if imp_fun.is_async:
+                user_code_event_loop.run(run_input_async(input_id, function_call_id, args, kwargs))
+            else:
+                # Set up a custom signal handler for `SIGUSR1`, which gets translated to an InputCancellation
+                # during function execution. This is sent to cancel inputs from the user
+                def _cancel_input_signal_handler(signum, stackframe):
+                    raise InputCancellation("Input was cancelled by user")
+
+                usr1_handler = signal.signal(signal.SIGUSR1, _cancel_input_signal_handler)
+                # run this sync code in the main thread, blocking the "userland" event loop
+                # this lets us cancel it using a signal handler that raises an exception
+                try:
+                    run_input_sync(input_id, function_call_id, args, kwargs)
+                finally:
+                    signal.signal(signal.SIGUSR1, usr1_handler) # reset signal handler


 def import_function(
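
The SIGUSR1 handling above relies on CPython running signal handlers on the main thread: raising an exception from the handler interrupts whatever sync user code is blocking that thread. A standalone sketch, with InputCancellation defined locally to stand in for Modal's exception type:

import os
import signal
import threading
import time

class InputCancellation(Exception):  # stand-in for Modal's InputCancellation
    pass

def _cancel_input_signal_handler(signum, stackframe):
    raise InputCancellation("Input was cancelled by user")

previous = signal.signal(signal.SIGUSR1, _cancel_input_signal_handler)
# Simulate the server cancelling the input half a second from now.
threading.Timer(0.5, os.kill, args=(os.getpid(), signal.SIGUSR1)).start()
try:
    time.sleep(60)  # stands in for the blocking user function
except InputCancellation:
    print("input cancelled; the container can move on to the next input")
finally:
    signal.signal(signal.SIGUSR1, previous)  # restore, as the diff does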
@@ -551,17 +571,7 @@ def breakpoint_wrapper():

     # Execute the function.
     try:
-        if imp_fun.is_async:
-            event_loop.run(call_function_async(container_io_manager, imp_fun))
-        else:
-            # Set up a signal handler for `SIGUSR1`, which gets translated to an InputCancellation
-            # during function execution. This is sent to cancel inputs from the user.
-            def _cancel_input_signal_handler(signum, stackframe):
-                raise InputCancellation("Input was cancelled by user")
-
-            signal.signal(signal.SIGUSR1, _cancel_input_signal_handler)
-
-            call_function_sync(container_io_manager, imp_fun)
+        call_function(event_loop, container_io_manager, imp_fun)
     finally:
         # Run exit handlers. From this point onward, ignore all SIGINT signals that come from
         # graceful shutdowns originating on the worker, as well as stray SIGUSR1 signals that
8 changes: 6 additions & 2 deletions test/container_test.py
@@ -835,6 +835,7 @@ def _unwrap_concurrent_input_outputs(n_inputs: int, n_parallel: int, ret: Contai


 @skip_github_non_linux
+@pytest.mark.timeout(5)
 def test_concurrent_inputs_sync_function(unix_servicer):
     n_inputs = 18
     n_parallel = 6
@@ -1197,10 +1198,13 @@ def test_cancellation_aborts_current_input_on_match(
     [("delay",), ("delay_async",)],
 )
 def test_cancellation_stops_task_with_concurrent_inputs(servicer, function_name):
-    # send three inputs in container: in-100, in-101, in-102
     with servicer.input_lockstep() as input_lock:
         container_process = _run_container_process(
-            servicer, "test.supports.functions", function_name, inputs=[((20,), {})], allow_concurrent_inputs=2
+            servicer,
+            "test.supports.functions",
+            function_name,
+            inputs=[((20,), {})],
+            allow_concurrent_inputs=2,
         )
         input_lock.wait()

