Skip to content

Commit

Permalink
Support specifying a list of regions to run on
Browse files Browse the repository at this point in the history
- Rename the '_experimental_scheduler_placement' proto field to just
  'scheduler_placement'.
- Disallow specifying both 'region=' and
  '_experimental_scheduler_placement=' in the same call; doing so now
  raises an InvalidError.
- Keep '_experimental_scheduler_placement=...' and
  'SchedulerPlacement(region=...)' working for backwards compatibility,
  both at the API level and at the proto level. At the proto level, the
  old single-valued '_region' field is marked as deprecated (superseded by
  the repeated 'regions' field) and will be normalized away on the server.
- Support specifying regions for 'modal shell'.
  • Loading branch information
irfansharif committed May 7, 2024
1 parent edfb05c commit e0197f2
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 35 deletions.
13 changes: 12 additions & 1 deletion modal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ def function(
bool
] = None, # Set this to True if it's a non-generator function returning a [sync/async] generator object
cloud: Optional[str] = None, # Cloud provider to run the function on. Possible values are aws, gcp, oci, auto.
region: Optional[Union[str, Sequence[str]]] = None, # Region or regions to run the function on.
enable_memory_snapshot: bool = False, # Enable memory checkpointing for faster cold starts.
checkpointing_enabled: Optional[bool] = None, # Deprecated
block_network: bool = False, # Whether to block network access
Expand Down Expand Up @@ -579,6 +580,12 @@ def wrapped(
if is_generator is None:
is_generator = inspect.isgeneratorfunction(raw_f) or inspect.isasyncgenfunction(raw_f)

scheduler_placement: Optional[SchedulerPlacement] = _experimental_scheduler_placement
if region:
if scheduler_placement:
raise InvalidError("`region` and `_experimental_scheduler_placement` cannot be used together")
scheduler_placement = SchedulerPlacement(region=region)

function = _Function.from_args(
info,
app=self,
Expand Down Expand Up @@ -608,9 +615,9 @@ def wrapped(
allow_background_volume_commits=_allow_background_volume_commits,
block_network=block_network,
max_inputs=max_inputs,
scheduler_placement=scheduler_placement,
_experimental_boost=_experimental_boost,
_experimental_scheduler=_experimental_scheduler,
_experimental_scheduler_placement=_experimental_scheduler_placement,
)

self._add_function(function)
Expand Down Expand Up @@ -646,6 +653,7 @@ def cls(
timeout: Optional[int] = None, # Maximum execution time of the function in seconds.
keep_warm: Optional[int] = None, # An optional number of containers to always keep warm.
cloud: Optional[str] = None, # Cloud provider to run the function on. Possible values are aws, gcp, oci, auto.
region: Optional[Union[str, Sequence[str]]] = None, # Region or regions to run the function on.
enable_memory_snapshot: bool = False, # Enable memory checkpointing for faster cold starts.
checkpointing_enabled: Optional[bool] = None, # Deprecated
block_network: bool = False, # Whether to block network access
Expand Down Expand Up @@ -687,6 +695,7 @@ def cls(
interactive=interactive,
keep_warm=keep_warm,
cloud=cloud,
region=region,
enable_memory_snapshot=enable_memory_snapshot,
checkpointing_enabled=checkpointing_enabled,
block_network=block_network,
Expand Down Expand Up @@ -730,6 +739,7 @@ async def spawn_sandbox(
workdir: Optional[str] = None, # Working directory of the sandbox.
gpu: GPU_T = None,
cloud: Optional[str] = None,
region: Optional[Union[str, Sequence[str]]] = None, # Region or regions to run the sandbox on.
cpu: Optional[float] = None, # How many CPU cores to request. This is a soft limit.
memory: Optional[
Union[int, Tuple[int, int]]
Expand Down Expand Up @@ -769,6 +779,7 @@ async def spawn_sandbox(
workdir=workdir,
gpu=gpu,
cloud=cloud,
region=region,
cpu=cpu,
memory=memory,
network_file_systems=network_file_systems,
Expand Down
17 changes: 15 additions & 2 deletions modal/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,11 @@ def shell(
),
cloud: Optional[str] = typer.Option(
default=None,
help="Cloud provider to run the function on. Possible values are `aws`, `gcp`, `oci`, `auto` (if not using FUNC_REF).",
help="Cloud provider to run the shell on. Possible values are `aws`, `gcp`, `oci`, `auto` (if not using FUNC_REF).",
),
region: Optional[str] = typer.Option(
default=None,
help="Region(s) to run the shell on. Can be a single region or a comma-separated list to choose from (if not using FUNC_REF).",
),
):
"""Run an interactive shell inside a Modal image.
Expand Down Expand Up @@ -392,10 +396,19 @@ def shell(
cpu=function_spec.cpu,
memory=function_spec.memory,
volumes=function_spec.volumes,
scheduler_placement=function_spec.scheduler_placement,
_allow_background_volume_commits=True,
)
else:
modal_image = Image.from_registry(image, add_python=add_python) if image else None
start_shell = partial(interactive_shell, image=modal_image, cpu=cpu, memory=memory, gpu=gpu, cloud=cloud)
start_shell = partial(
interactive_shell,
image=modal_image,
cpu=cpu,
memory=memory,
gpu=gpu,
cloud=cloud,
region=region.split(",") if region else [],
)

start_shell(app, cmd=[cmd], environment_name=env, timeout=3600)
10 changes: 5 additions & 5 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ class _FunctionSpec:
cloud: Optional[str]
cpu: Optional[float]
memory: Optional[Union[int, Tuple[int, int]]]
scheduler_placement: Optional[SchedulerPlacement]


class _Function(_Object, type_prefix="fu"):
Expand Down Expand Up @@ -303,7 +304,7 @@ def from_args(
cloud: Optional[str] = None,
_experimental_boost: bool = False,
_experimental_scheduler: bool = False,
_experimental_scheduler_placement: Optional[SchedulerPlacement] = None,
scheduler_placement: Optional[SchedulerPlacement] = None,
is_builder_function: bool = False,
is_auto_snapshot: bool = False,
enable_memory_snapshot: bool = False,
Expand Down Expand Up @@ -374,6 +375,7 @@ def from_args(
cloud=cloud,
cpu=cpu,
memory=memory,
scheduler_placement=scheduler_placement,
)

if info.cls and not is_auto_snapshot:
Expand All @@ -397,7 +399,7 @@ def from_args(
cpu=cpu,
is_builder_function=True,
is_auto_snapshot=True,
_experimental_scheduler_placement=_experimental_scheduler_placement,
scheduler_placement=scheduler_placement,
)
image = _Image._from_args(
base_images={"base": image},
Expand Down Expand Up @@ -596,9 +598,7 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona
cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts),
_experimental_boost=_experimental_boost,
_experimental_scheduler=_experimental_scheduler,
_experimental_scheduler_placement=_experimental_scheduler_placement.proto
if _experimental_scheduler_placement
else None,
scheduler_placement=scheduler_placement.proto if scheduler_placement else None,
)
request = api_pb2.FunctionCreateRequest(
app_id=resolver.app_id,
Expand Down
11 changes: 8 additions & 3 deletions modal/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def _new(
workdir: Optional[str] = None,
gpu: GPU_T = None,
cloud: Optional[str] = None,
region: Optional[Union[str, Sequence[str]]] = None,
cpu: Optional[float] = None,
memory: Optional[Union[int, Tuple[int, int]]] = None,
network_file_systems: Dict[Union[str, os.PathLike], _NetworkFileSystem] = {},
Expand All @@ -261,6 +262,12 @@ def _new(
raise InvalidError("network_file_systems must be a dict[str, NetworkFileSystem] where the keys are paths")
validated_network_file_systems = validate_mount_points("Network file system", network_file_systems)

scheduler_placement: Optional[SchedulerPlacement] = _experimental_scheduler_placement
if region:
if scheduler_placement:
raise InvalidError("`region` and `_experimental_scheduler_placement` cannot be used together")
scheduler_placement = SchedulerPlacement(region=region)

# Validate volumes
validated_volumes = validate_volumes(volumes)
cloud_bucket_mounts = [(k, v) for k, v in validated_volumes if isinstance(v, _CloudBucketMount)]
Expand Down Expand Up @@ -304,9 +311,7 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona
volume_mounts=volume_mounts,
pty_info=pty_info,
_experimental_scheduler=_experimental_scheduler,
_experimental_scheduler_placement=_experimental_scheduler_placement.proto
if _experimental_scheduler_placement
else None,
scheduler_placement=scheduler_placement.proto if scheduler_placement else None,
)

create_req = api_pb2.SandboxCreateRequest(app_id=resolver.app_id, definition=definition)
Expand Down
12 changes: 9 additions & 3 deletions modal/scheduler_placement.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright Modal Labs 2024
from typing import Optional
from typing import Optional, Sequence, Union

from modal_proto import api_pb2

Expand All @@ -11,7 +11,7 @@ class SchedulerPlacement:

def __init__(
self,
region: Optional[str] = None,
region: Optional[Union[str, Sequence[str]]] = None,
zone: Optional[str] = None,
spot: Optional[bool] = None,
):
Expand All @@ -20,8 +20,14 @@ def __init__(
if spot is not None:
_lifecycle = "spot" if spot else "on-demand"

regions = []
if region:
if isinstance(region, str):
regions = [region]
else:
regions = list(region)
self.proto = api_pb2.SchedulerPlacement(
_region=region,
regions=regions,
_zone=zone,
_lifecycle=_lifecycle,
)
12 changes: 7 additions & 5 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -915,15 +915,17 @@ message Function {
// If set, tasks will be scheduled using the new scheduler, which also knows
// to look at fine-grained placement constraints.
bool _experimental_scheduler = 49;
optional SchedulerPlacement _experimental_scheduler_placement = 50;
optional SchedulerPlacement scheduler_placement = 50;
}

message SchedulerPlacement {
// TODO(irfansharif):
// - Fold in cloud, resource needs here too.
// - Allow specifying list of regions, zones, cloud, fallback and alternative
// - Allow specifying list of zones, cloud, fallback and alternative
// GPU types.
optional string _region = 1;
optional string _region = 1 [deprecated=true];

repeated string regions = 4;
optional string _zone = 2;
optional string _lifecycle = 3; // "on-demand" or "spot", else ignored
}
Expand Down Expand Up @@ -1007,7 +1009,7 @@ message FunctionGetResponse {
FunctionHandleMetadata handle_metadata = 2;
}


message FunctionUpdateSchedulingParamsRequest {
string function_id = 1;
uint32 warm_pool_size_override = 2;
Expand Down Expand Up @@ -1572,7 +1574,7 @@ message Sandbox {
// If set, tasks will be scheduled using the new scheduler, which also knows
// to look at fine-grained placement constraints.
bool _experimental_scheduler = 16;
optional SchedulerPlacement _experimental_scheduler_placement = 17;
optional SchedulerPlacement scheduler_placement = 17;
}

message SandboxCreateRequest {
Expand Down
50 changes: 34 additions & 16 deletions test/scheduler_placement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,45 @@
spot=False,
),
)
def f():
def f1():
pass


@app.function(
region="us-east-1",
)
def f2():
pass


@app.function(
region=["us-east-1", "us-west-2"],
)
def f3():
pass


def test_fn_scheduler_placement(servicer, client):
with app.run(client=client):
assert len(servicer.app_functions) == 1
fn = servicer.app_functions["fu-1"]
assert fn._experimental_scheduler
assert fn._experimental_scheduler_placement == api_pb2.SchedulerPlacement(
_region="us-east-1",
assert len(servicer.app_functions) == 3
fn1 = servicer.app_functions["fu-1"] # f1
assert fn1._experimental_scheduler
assert fn1.scheduler_placement == api_pb2.SchedulerPlacement(
regions=["us-east-1"],
_zone="us-east-1a",
_lifecycle="on-demand",
)

fn2 = servicer.app_functions["fu-2"] # f2
assert fn2.scheduler_placement == api_pb2.SchedulerPlacement(
regions=["us-east-1"],
)

fn3 = servicer.app_functions["fu-3"] # f3
assert fn3.scheduler_placement == api_pb2.SchedulerPlacement(
regions=["us-east-1", "us-west-2"],
)


@skip_non_linux
def test_sandbox_scheduler_placement(client, servicer):
Expand All @@ -39,19 +63,13 @@ def test_sandbox_scheduler_placement(client, servicer):
"-c",
"echo bye >&2 && sleep 1 && echo hi && exit 42",
timeout=600,
region="us-east-1",
_experimental_scheduler=True,
_experimental_scheduler_placement=SchedulerPlacement(
region="us-east-1",
zone="us-east-1a",
spot=False,
),
)

assert len(servicer.sandbox_defs) == 1
sb_def = servicer.sandbox_defs[0]
assert sb_def._experimental_scheduler
assert sb_def._experimental_scheduler_placement == api_pb2.SchedulerPlacement(
_region="us-east-1",
_zone="us-east-1a",
_lifecycle="on-demand",
assert sb_def.scheduler_placement == api_pb2.SchedulerPlacement(
regions=["us-east-1"],
)
assert sb_def._experimental_scheduler

0 comments on commit e0197f2

Please sign in to comment.