diff --git a/modal/app.py b/modal/app.py index 55252fd48e..56f4af5c98 100644 --- a/modal/app.py +++ b/modal/app.py @@ -508,6 +508,7 @@ def function( bool ] = None, # Set this to True if it's a non-generator function returning a [sync/async] generator object cloud: Optional[str] = None, # Cloud provider to run the function on. Possible values are aws, gcp, oci, auto. + region: Optional[Union[str, Sequence[str]]] = None, # Region or regions to run the function on. enable_memory_snapshot: bool = False, # Enable memory checkpointing for faster cold starts. checkpointing_enabled: Optional[bool] = None, # Deprecated block_network: bool = False, # Whether to block network access @@ -579,6 +580,12 @@ def wrapped( if is_generator is None: is_generator = inspect.isgeneratorfunction(raw_f) or inspect.isasyncgenfunction(raw_f) + scheduler_placement: Optional[SchedulerPlacement] = _experimental_scheduler_placement + if region: + if scheduler_placement: + raise InvalidError("`region` and `_experimental_scheduler_placement` cannot be used together") + scheduler_placement = SchedulerPlacement(region=region) + function = _Function.from_args( info, app=self, @@ -608,9 +615,9 @@ def wrapped( allow_background_volume_commits=_allow_background_volume_commits, block_network=block_network, max_inputs=max_inputs, + scheduler_placement=scheduler_placement, _experimental_boost=_experimental_boost, _experimental_scheduler=_experimental_scheduler, - _experimental_scheduler_placement=_experimental_scheduler_placement, ) self._add_function(function) @@ -646,6 +653,7 @@ def cls( timeout: Optional[int] = None, # Maximum execution time of the function in seconds. keep_warm: Optional[int] = None, # An optional number of containers to always keep warm. cloud: Optional[str] = None, # Cloud provider to run the function on. Possible values are aws, gcp, oci, auto. + region: Optional[Union[str, Sequence[str]]] = None, # Region or regions to run the function on. enable_memory_snapshot: bool = False, # Enable memory checkpointing for faster cold starts. checkpointing_enabled: Optional[bool] = None, # Deprecated block_network: bool = False, # Whether to block network access @@ -687,6 +695,7 @@ def cls( interactive=interactive, keep_warm=keep_warm, cloud=cloud, + region=region, enable_memory_snapshot=enable_memory_snapshot, checkpointing_enabled=checkpointing_enabled, block_network=block_network, @@ -730,6 +739,7 @@ async def spawn_sandbox( workdir: Optional[str] = None, # Working directory of the sandbox. gpu: GPU_T = None, cloud: Optional[str] = None, + region: Optional[Union[str, Sequence[str]]] = None, # Region or regions to run the sandbox on. cpu: Optional[float] = None, # How many CPU cores to request. This is a soft limit. memory: Optional[ Union[int, Tuple[int, int]] @@ -769,6 +779,7 @@ async def spawn_sandbox( workdir=workdir, gpu=gpu, cloud=cloud, + region=region, cpu=cpu, memory=memory, network_file_systems=network_file_systems, diff --git a/modal/cli/run.py b/modal/cli/run.py index 46cda97544..9d04025e8b 100644 --- a/modal/cli/run.py +++ b/modal/cli/run.py @@ -344,7 +344,11 @@ def shell( ), cloud: Optional[str] = typer.Option( default=None, - help="Cloud provider to run the function on. Possible values are `aws`, `gcp`, `oci`, `auto` (if not using FUNC_REF).", + help="Cloud provider to run the shell on. Possible values are `aws`, `gcp`, `oci`, `auto` (if not using FUNC_REF).", + ), + region: Optional[str] = typer.Option( + default=None, + help="Region(s) to run the shell on. Can be a single region or a comma-separated list to choose from (if not using FUNC_REF).", ), ): """Run an interactive shell inside a Modal image. @@ -392,10 +396,19 @@ def shell( cpu=function_spec.cpu, memory=function_spec.memory, volumes=function_spec.volumes, + scheduler_placement=function_spec.scheduler_placement, _allow_background_volume_commits=True, ) else: modal_image = Image.from_registry(image, add_python=add_python) if image else None - start_shell = partial(interactive_shell, image=modal_image, cpu=cpu, memory=memory, gpu=gpu, cloud=cloud) + start_shell = partial( + interactive_shell, + image=modal_image, + cpu=cpu, + memory=memory, + gpu=gpu, + cloud=cloud, + region=region.split(",") if region else [], + ) start_shell(app, cmd=[cmd], environment_name=env, timeout=3600) diff --git a/modal/functions.py b/modal/functions.py index 89289abbed..88fb8cba0e 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -250,6 +250,7 @@ class _FunctionSpec: cloud: Optional[str] cpu: Optional[float] memory: Optional[Union[int, Tuple[int, int]]] + scheduler_placement: Optional[SchedulerPlacement] class _Function(_Object, type_prefix="fu"): @@ -303,7 +304,7 @@ def from_args( cloud: Optional[str] = None, _experimental_boost: bool = False, _experimental_scheduler: bool = False, - _experimental_scheduler_placement: Optional[SchedulerPlacement] = None, + scheduler_placement: Optional[SchedulerPlacement] = None, is_builder_function: bool = False, is_auto_snapshot: bool = False, enable_memory_snapshot: bool = False, @@ -374,6 +375,7 @@ def from_args( cloud=cloud, cpu=cpu, memory=memory, + scheduler_placement=scheduler_placement, ) if info.cls and not is_auto_snapshot: @@ -397,7 +399,7 @@ def from_args( cpu=cpu, is_builder_function=True, is_auto_snapshot=True, - _experimental_scheduler_placement=_experimental_scheduler_placement, + scheduler_placement=scheduler_placement, ) image = _Image._from_args( base_images={"base": image}, @@ -596,9 +598,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts), _experimental_boost=_experimental_boost, _experimental_scheduler=_experimental_scheduler, - _experimental_scheduler_placement=_experimental_scheduler_placement.proto - if _experimental_scheduler_placement - else None, + scheduler_placement=scheduler_placement.proto if scheduler_placement else None, ) request = api_pb2.FunctionCreateRequest( app_id=resolver.app_id, diff --git a/modal/sandbox.py b/modal/sandbox.py index ce33d967af..e6d3b90669 100644 --- a/modal/sandbox.py +++ b/modal/sandbox.py @@ -242,6 +242,7 @@ def _new( workdir: Optional[str] = None, gpu: GPU_T = None, cloud: Optional[str] = None, + region: Optional[Union[str, Sequence[str]]] = None, cpu: Optional[float] = None, memory: Optional[Union[int, Tuple[int, int]]] = None, network_file_systems: Dict[Union[str, os.PathLike], _NetworkFileSystem] = {}, @@ -261,6 +262,12 @@ def _new( raise InvalidError("network_file_systems must be a dict[str, NetworkFileSystem] where the keys are paths") validated_network_file_systems = validate_mount_points("Network file system", network_file_systems) + scheduler_placement: Optional[SchedulerPlacement] = _experimental_scheduler_placement + if region: + if scheduler_placement: + raise InvalidError("`region` and `_experimental_scheduler_placement` cannot be used together") + scheduler_placement = SchedulerPlacement(region=region) + # Validate volumes validated_volumes = validate_volumes(volumes) cloud_bucket_mounts = [(k, v) for k, v in validated_volumes if isinstance(v, _CloudBucketMount)] @@ -304,9 +311,7 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona volume_mounts=volume_mounts, pty_info=pty_info, _experimental_scheduler=_experimental_scheduler, - _experimental_scheduler_placement=_experimental_scheduler_placement.proto - if _experimental_scheduler_placement - else None, + scheduler_placement=scheduler_placement.proto if scheduler_placement else None, ) create_req = api_pb2.SandboxCreateRequest(app_id=resolver.app_id, definition=definition) diff --git a/modal/scheduler_placement.py b/modal/scheduler_placement.py index b84780e0e0..4248cb3020 100644 --- a/modal/scheduler_placement.py +++ b/modal/scheduler_placement.py @@ -1,5 +1,5 @@ # Copyright Modal Labs 2024 -from typing import Optional +from typing import Optional, Sequence, Union from modal_proto import api_pb2 @@ -11,7 +11,7 @@ class SchedulerPlacement: def __init__( self, - region: Optional[str] = None, + region: Optional[Union[str, Sequence[str]]] = None, zone: Optional[str] = None, spot: Optional[bool] = None, ): @@ -20,8 +20,14 @@ def __init__( if spot is not None: _lifecycle = "spot" if spot else "on-demand" + regions = [] + if region: + if isinstance(region, str): + regions = [region] + else: + regions = list(region) self.proto = api_pb2.SchedulerPlacement( - _region=region, + regions=regions, _zone=zone, _lifecycle=_lifecycle, ) diff --git a/modal_proto/api.proto b/modal_proto/api.proto index cbc3280e84..1abaa6726a 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -915,15 +915,17 @@ message Function { // If set, tasks will be scheduled using the new scheduler, which also knows // to look at fine-grained placement constraints. bool _experimental_scheduler = 49; - optional SchedulerPlacement _experimental_scheduler_placement = 50; + optional SchedulerPlacement scheduler_placement = 50; } message SchedulerPlacement { // TODO(irfansharif): // - Fold in cloud, resource needs here too. - // - Allow specifying list of regions, zones, cloud, fallback and alternative + // - Allow specifying list of zones, cloud, fallback and alternative // GPU types. - optional string _region = 1; + optional string _region = 1 [deprecated=true]; + + repeated string regions = 4; optional string _zone = 2; optional string _lifecycle = 3; // "on-demand" or "spot", else ignored } @@ -1007,7 +1009,7 @@ message FunctionGetResponse { FunctionHandleMetadata handle_metadata = 2; } - + message FunctionUpdateSchedulingParamsRequest { string function_id = 1; uint32 warm_pool_size_override = 2; @@ -1572,7 +1574,7 @@ message Sandbox { // If set, tasks will be scheduled using the new scheduler, which also knows // to look at fine-grained placement constraints. bool _experimental_scheduler = 16; - optional SchedulerPlacement _experimental_scheduler_placement = 17; + optional SchedulerPlacement scheduler_placement = 17; } message SandboxCreateRequest { diff --git a/test/scheduler_placement_test.py b/test/scheduler_placement_test.py index 2225e6766e..4e96fe7683 100644 --- a/test/scheduler_placement_test.py +++ b/test/scheduler_placement_test.py @@ -15,21 +15,45 @@ spot=False, ), ) -def f(): +def f1(): + pass + + +@app.function( + region="us-east-1", +) +def f2(): + pass + + +@app.function( + region=["us-east-1", "us-west-2"], +) +def f3(): pass def test_fn_scheduler_placement(servicer, client): with app.run(client=client): - assert len(servicer.app_functions) == 1 - fn = servicer.app_functions["fu-1"] - assert fn._experimental_scheduler - assert fn._experimental_scheduler_placement == api_pb2.SchedulerPlacement( - _region="us-east-1", + assert len(servicer.app_functions) == 3 + fn1 = servicer.app_functions["fu-1"] # f1 + assert fn1._experimental_scheduler + assert fn1.scheduler_placement == api_pb2.SchedulerPlacement( + regions=["us-east-1"], _zone="us-east-1a", _lifecycle="on-demand", ) + fn2 = servicer.app_functions["fu-2"] # f2 + assert fn2.scheduler_placement == api_pb2.SchedulerPlacement( + regions=["us-east-1"], + ) + + fn3 = servicer.app_functions["fu-3"] # f3 + assert fn3.scheduler_placement == api_pb2.SchedulerPlacement( + regions=["us-east-1", "us-west-2"], + ) + @skip_non_linux def test_sandbox_scheduler_placement(client, servicer): @@ -39,19 +63,13 @@ def test_sandbox_scheduler_placement(client, servicer): "-c", "echo bye >&2 && sleep 1 && echo hi && exit 42", timeout=600, + region="us-east-1", _experimental_scheduler=True, - _experimental_scheduler_placement=SchedulerPlacement( - region="us-east-1", - zone="us-east-1a", - spot=False, - ), ) assert len(servicer.sandbox_defs) == 1 sb_def = servicer.sandbox_defs[0] - assert sb_def._experimental_scheduler - assert sb_def._experimental_scheduler_placement == api_pb2.SchedulerPlacement( - _region="us-east-1", - _zone="us-east-1a", - _lifecycle="on-demand", + assert sb_def.scheduler_placement == api_pb2.SchedulerPlacement( + regions=["us-east-1"], ) + assert sb_def._experimental_scheduler