Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate invalid object names and standardize error message #1777

Merged
merged 14 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions modal/_utils/app_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# Copyright Modal Labs 2022
import re

# https://www.rfc-editor.org/rfc/rfc1035
subdomain_regex = re.compile("^(?![0-9]+$)(?!-)[a-z0-9-]{,63}(?<!-)$")


def is_valid_subdomain_label(label: str):
return subdomain_regex.match(label) is not None


def replace_invalid_subdomain_chars(label: str):
return re.sub("[^a-z0-9-]", "-", label.lower())


def is_valid_app_name(name: str):
return len(name) <= 64 and re.match("^[a-zA-Z0-9-_.]+$", name) is not None
# Copyright Modal Labs 2024
# Temporary shim as we use this in the server
from .name_utils import * # noqa
37 changes: 37 additions & 0 deletions modal/_utils/name_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright Modal Labs 2022
import re

from ..exception import InvalidError, deprecation_warning

# https://www.rfc-editor.org/rfc/rfc1035
subdomain_regex = re.compile("^(?![0-9]+$)(?!-)[a-z0-9-]{,63}(?<!-)$")


def is_valid_subdomain_label(label: str) -> bool:
return subdomain_regex.match(label) is not None


def replace_invalid_subdomain_chars(label: str) -> str:
return re.sub("[^a-z0-9-]", "-", label.lower())


def is_valid_object_name(name: str) -> bool:
return len(name) <= 64 and re.match("^[a-zA-Z0-9-_.]+$", name) is not None


def check_object_name(name: str, object_type: str, warn: bool = False) -> None:
message = (
f"Invalid {object_type} name: '{name}'."
"\n\nNames may contain only alphanumeric characters, dashes, periods, and underscores,"
" and must be shorter than 64 characters."
)
if warn:
message += "\n\nThis will become an error in the future. Please rename your object to preserve access to it."
if not is_valid_object_name(name):
if warn:
deprecation_warning((2024, 4, 30), message, show_source=False)
else:
raise InvalidError(message)


is_valid_app_name = is_valid_object_name # TODO becaue we use the former in the server
2 changes: 2 additions & 0 deletions modal/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._serialization import deserialize, serialize
from ._utils.async_utils import TaskContext, synchronize_api
from ._utils.grpc_utils import retry_transient_errors, unary_stream
from ._utils.name_utils import check_object_name
from .client import _Client
from .config import logger
from .exception import deprecation_warning
Expand Down Expand Up @@ -134,6 +135,7 @@ def from_name(
dict[123] = 456
```
"""
check_object_name(label, "Dict", warn=True)

async def _load(self: _Dict, resolver: Resolver, existing_object_id: Optional[str]):
serialized = _serialize_dict(data if data is not None else {})
Expand Down
2 changes: 2 additions & 0 deletions modal/mount.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ._utils.async_utils import synchronize_api
from ._utils.blob_utils import FileUploadSpec, blob_upload_file, get_file_upload_spec_from_path
from ._utils.grpc_utils import retry_transient_errors
from ._utils.name_utils import check_object_name
from ._utils.package_utils import get_module_mount_info
from .client import _Client
from .config import config, logger
Expand Down Expand Up @@ -584,6 +585,7 @@ async def _deploy(
environment_name: Optional[str] = None,
client: Optional[_Client] = None,
) -> "_Mount":
check_object_name(deployment_name, "Mount")
self._deployment_name = deployment_name
self._namespace = namespace
self._environment_name = environment_name
Expand Down
2 changes: 2 additions & 0 deletions modal/network_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ._utils.blob_utils import LARGE_FILE_LIMIT, blob_iter, blob_upload_file
from ._utils.grpc_utils import retry_transient_errors, unary_stream
from ._utils.hash_utils import get_sha256_hex
from ._utils.name_utils import check_object_name
from .client import _Client
from .exception import deprecation_warning
from .object import (
Expand Down Expand Up @@ -133,6 +134,7 @@ def f():
pass
```
"""
check_object_name(label, "NetworkFileSystem", warn=True)

async def _load(self: _NetworkFileSystem, resolver: Resolver, existing_object_id: Optional[str]):
req = api_pb2.SharedVolumeGetOrCreateRequest(
Expand Down
2 changes: 2 additions & 0 deletions modal/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ._serialization import deserialize, serialize
from ._utils.async_utils import TaskContext, synchronize_api, warn_if_generator_is_not_consumed
from ._utils.grpc_utils import retry_transient_errors
from ._utils.name_utils import check_object_name
from .client import _Client
from .exception import InvalidError, deprecation_warning
from .object import EPHEMERAL_OBJECT_HEARTBEAT_SLEEP, _get_environment_name, _Object, live_method, live_method_gen
Expand Down Expand Up @@ -163,6 +164,7 @@ def from_name(
queue.put(123)
```
"""
check_object_name(label, "Queue", warn=True)

async def _load(self: _Queue, resolver: Resolver, existing_object_id: Optional[str]):
req = api_pb2.QueueGetOrCreateRequest(
Expand Down
9 changes: 3 additions & 6 deletions modal/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from ._pty import get_pty_info
from ._resolver import Resolver
from ._sandbox_shell import connect_to_sandbox
from ._utils.app_utils import is_valid_app_name
from ._utils.async_utils import TaskContext, synchronize_api
from ._utils.grpc_utils import retry_transient_errors
from ._utils.name_utils import check_object_name
from .client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
from .config import config, logger
from .exception import ExecutionError, InteractiveTimeoutError, InvalidError, _CliUserExecutionError
Expand Down Expand Up @@ -397,11 +397,8 @@ async def _deploy_app(
"or\n"
'app = App("some-name")'
)

if not is_valid_app_name(name):
raise InvalidError(
f"Invalid app name {name}. App names may only contain alphanumeric characters, dashes, periods, and underscores, and must be less than 64 characters in length. "
)
else:
check_object_name(name, "App")

if client is None:
client = await _Client.from_env()
Expand Down
5 changes: 5 additions & 0 deletions modal/secret.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._resolver import Resolver
from ._utils.async_utils import synchronize_api
from ._utils.grpc_utils import retry_transient_errors
from ._utils.name_utils import check_object_name
from .client import _Client
from .exception import InvalidError, NotFoundError
from .execution_context import is_local
Expand Down Expand Up @@ -169,6 +170,10 @@ def run():
...
```
"""
# Unlike other objects, you can't create secrets through this method, but we will still
# warn here so that people get the message when they *look up* secrets with illegal names.
# We can just remove the check after the deprecation period, instead of converting to an error.
check_object_name(label, "Secret", warn=True)

async def _load(self: _Secret, resolver: Resolver, existing_object_id: Optional[str]):
req = api_pb2.SecretGetOrCreateRequest(
Expand Down
2 changes: 2 additions & 0 deletions modal/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
get_file_upload_spec_from_path,
)
from ._utils.grpc_utils import retry_transient_errors, unary_stream
from ._utils.name_utils import check_object_name
from .client import _Client
from .config import logger
from .exception import deprecation_error
Expand Down Expand Up @@ -181,6 +182,7 @@ def f():
pass
```
"""
check_object_name(label, "Volume", warn=True)

async def _load(self: _Volume, resolver: Resolver, existing_object_id: Optional[str]):
req = api_pb2.VolumeGetOrCreateRequest(
Expand Down
4 changes: 2 additions & 2 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def type_check(ctx):
pyright_allowlist = [
"modal/functions.py",
"modal/_utils/__init__.py",
"modal/_utils/app_utils.py",
"modal/_utils/async_utils.py",
"modal/_utils/grpc_testing.py",
"modal/_utils/hash_utils.py",
"modal/_utils/http_utils.py",
"modal/_utils/name_utils.py",
"modal/_utils/logger.py",
"modal/_utils/mount_utils.py",
"modal/_utils/package_utils.py",
Expand Down Expand Up @@ -312,7 +312,7 @@ def visit_Attribute(self, node):
def visit_Call(self, node):
func_name_to_level = {
"deprecation_warning": "[yellow]warning[/yellow]",
"deprecation_error": "[red]error[/red]"
"deprecation_error": "[red]error[/red]",
}
if isinstance(node.func, ast.Name) and node.func.id in func_name_to_level:
depr_date = date(*(elt.n for elt in node.args[0].elts))
Expand Down
8 changes: 7 additions & 1 deletion test/dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time

from modal import Dict
from modal.exception import NotFoundError
from modal.exception import DeprecationError, NotFoundError


def test_dict_app(servicer, client):
Expand Down Expand Up @@ -49,3 +49,9 @@ def test_dict_lazy_hydrate_named(set_env_client, servicer):
d["foo"] = 42
assert d["foo"] == 42
assert len(ctx.get_requests("DictGetOrCreate")) == 1 # just sanity check that object is only hydrated once...


@pytest.mark.parametrize("name", ["has space", "has/slash", "a" * 65])
def test_invalid_name(servicer, client, name):
with pytest.raises(DeprecationError, match="Invalid Dict name"):
Dict.lookup(name)
6 changes: 6 additions & 0 deletions test/network_file_system_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,9 @@ def test_nfs_lazy_hydration_from_name(set_env_client):
nfs = modal.NetworkFileSystem.from_name("nfs", create_if_missing=True)
bio = BytesIO(b"content")
nfs.write_file("blah", bio)


@pytest.mark.parametrize("name", ["has space", "has/slash", "a" * 65])
def test_invalid_name(servicer, client, name):
with pytest.raises(DeprecationError, match="Invalid NetworkFileSystem name"):
modal.NetworkFileSystem.lookup(name)
8 changes: 7 additions & 1 deletion test/queue_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time

from modal import Queue
from modal.exception import NotFoundError
from modal.exception import DeprecationError, NotFoundError

from .supports.skip import skip_macos, skip_windows

Expand Down Expand Up @@ -113,3 +113,9 @@ def test_queue_lazy_hydrate_from_name(set_env_client):
q = Queue.from_name("foo", create_if_missing=True)
q.put(123)
assert q.get() == 123


@pytest.mark.parametrize("name", ["has space", "has/slash", "a" * 65])
def test_invalid_name(servicer, client, name):
with pytest.raises(DeprecationError, match="Invalid Queue name"):
Queue.lookup(name)
19 changes: 12 additions & 7 deletions test/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import io
import pytest

from modal._utils.app_utils import is_valid_app_name, is_valid_subdomain_label
from modal._utils.blob_utils import BytesIOSegmentPayload
from modal._utils.name_utils import check_object_name, is_valid_object_name, is_valid_subdomain_label
from modal.exception import DeprecationError, InvalidError


def test_subdomain_label():
Expand All @@ -16,12 +17,16 @@ def test_subdomain_label():
assert not is_valid_subdomain_label("ban/ana")


def test_app_name():
assert is_valid_app_name("baNaNa")
assert is_valid_app_name("foo-123_456")
assert is_valid_app_name("a" * 64)
assert not is_valid_app_name("hello world")
assert not is_valid_app_name("a" * 65)
def test_object_name():
assert is_valid_object_name("baNaNa")
assert is_valid_object_name("foo-123_456")
assert is_valid_object_name("a" * 64)
assert not is_valid_object_name("hello world")
assert not is_valid_object_name("a" * 65)
with pytest.raises(InvalidError, match="Invalid Volume name: 'foo/bar'"):
check_object_name("foo/bar", "Volume")
with pytest.warns(DeprecationError, match="Invalid Volume name: 'foo/bar'"):
check_object_name("foo/bar", "Volume", warn=True)


@pytest.mark.asyncio
Expand Down
6 changes: 6 additions & 0 deletions test/volume_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,9 @@ async def test_open_files_error_annotation(tmp_path):
proc.kill()
await proc.wait()
assert _open_files_error_annotation(tmp_path) is None


@pytest.mark.parametrize("name", ["has space", "has/slash", "a" * 65])
def test_invalid_name(servicer, client, name):
with pytest.raises(DeprecationError, match="Invalid Volume name"):
modal.Volume.lookup(name)
Loading