diff --git a/modal/_utils/app_utils.py b/modal/_utils/app_utils.py index a1bba07bc..c2a5eb25e 100644 --- a/modal/_utils/app_utils.py +++ b/modal/_utils/app_utils.py @@ -1,17 +1,3 @@ -# Copyright Modal Labs 2022 -import re - -# https://www.rfc-editor.org/rfc/rfc1035 -subdomain_regex = re.compile("^(?![0-9]+$)(?!-)[a-z0-9-]{,63}(? bool: + return subdomain_regex.match(label) is not None + + +def replace_invalid_subdomain_chars(label: str) -> str: + return re.sub("[^a-z0-9-]", "-", label.lower()) + + +def is_valid_object_name(name: str) -> bool: + return len(name) <= 64 and re.match("^[a-zA-Z0-9-_.]+$", name) is not None + + +def check_object_name(name: str, object_type: str, warn: bool = False) -> None: + message = ( + f"Invalid {object_type} name: '{name}'." + "\n\nNames may contain only alphanumeric characters, dashes, periods, and underscores," + " and must be shorter than 64 characters." + ) + if warn: + message += "\n\nThis will become an error in the future. Please rename your object to preserve access to it." 
+ if not is_valid_object_name(name): + if warn: + deprecation_warning((2024, 4, 30), message, show_source=False) + else: + raise InvalidError(message) + + +is_valid_app_name = is_valid_object_name # TODO because we use the former in the server diff --git a/modal/dict.py b/modal/dict.py index 6c4c7cfdb..6b8e01d55 100644 --- a/modal/dict.py +++ b/modal/dict.py @@ -9,6 +9,7 @@ from ._serialization import deserialize, serialize from ._utils.async_utils import TaskContext, synchronize_api from ._utils.grpc_utils import retry_transient_errors, unary_stream +from ._utils.name_utils import check_object_name from .client import _Client from .config import logger from .exception import deprecation_warning @@ -134,6 +135,7 @@ def from_name( dict[123] = 456 ``` """ + check_object_name(label, "Dict", warn=True) async def _load(self: _Dict, resolver: Resolver, existing_object_id: Optional[str]): serialized = _serialize_dict(data if data is not None else {}) diff --git a/modal/mount.py b/modal/mount.py index 8c2ca998b..0734f0210 100644 --- a/modal/mount.py +++ b/modal/mount.py @@ -23,6 +23,7 @@ from ._utils.async_utils import synchronize_api from ._utils.blob_utils import FileUploadSpec, blob_upload_file, get_file_upload_spec_from_path from ._utils.grpc_utils import retry_transient_errors +from ._utils.name_utils import check_object_name from ._utils.package_utils import get_module_mount_info from .client import _Client from .config import config, logger @@ -584,6 +585,7 @@ async def _deploy( environment_name: Optional[str] = None, client: Optional[_Client] = None, ) -> "_Mount": + check_object_name(deployment_name, "Mount") self._deployment_name = deployment_name self._namespace = namespace self._environment_name = environment_name diff --git a/modal/network_file_system.py b/modal/network_file_system.py index 7aa46d00d..b1255374b 100644 --- a/modal/network_file_system.py +++ b/modal/network_file_system.py @@ -16,6 +16,7 @@ from ._utils.blob_utils import LARGE_FILE_LIMIT, 
blob_iter, blob_upload_file from ._utils.grpc_utils import retry_transient_errors, unary_stream from ._utils.hash_utils import get_sha256_hex +from ._utils.name_utils import check_object_name from .client import _Client from .exception import deprecation_warning from .object import ( @@ -133,6 +134,7 @@ def f(): pass ``` """ + check_object_name(label, "NetworkFileSystem", warn=True) async def _load(self: _NetworkFileSystem, resolver: Resolver, existing_object_id: Optional[str]): req = api_pb2.SharedVolumeGetOrCreateRequest( diff --git a/modal/queue.py b/modal/queue.py index f287e8127..0dfa6afcb 100644 --- a/modal/queue.py +++ b/modal/queue.py @@ -13,6 +13,7 @@ from ._serialization import deserialize, serialize from ._utils.async_utils import TaskContext, synchronize_api, warn_if_generator_is_not_consumed from ._utils.grpc_utils import retry_transient_errors +from ._utils.name_utils import check_object_name from .client import _Client from .exception import InvalidError, deprecation_warning from .object import EPHEMERAL_OBJECT_HEARTBEAT_SLEEP, _get_environment_name, _Object, live_method, live_method_gen @@ -163,6 +164,7 @@ def from_name( queue.put(123) ``` """ + check_object_name(label, "Queue", warn=True) async def _load(self: _Queue, resolver: Resolver, existing_object_id: Optional[str]): req = api_pb2.QueueGetOrCreateRequest( diff --git a/modal/runner.py b/modal/runner.py index eecc3d09f..704e0452b 100644 --- a/modal/runner.py +++ b/modal/runner.py @@ -15,9 +15,9 @@ from ._pty import get_pty_info from ._resolver import Resolver from ._sandbox_shell import connect_to_sandbox -from ._utils.app_utils import is_valid_app_name from ._utils.async_utils import TaskContext, synchronize_api from ._utils.grpc_utils import retry_transient_errors +from ._utils.name_utils import check_object_name from .client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client from .config import config, logger from .exception import ExecutionError, InteractiveTimeoutError, InvalidError, 
_CliUserExecutionError @@ -397,11 +397,8 @@ async def _deploy_app( "or\n" 'app = App("some-name")' ) - - if not is_valid_app_name(name): - raise InvalidError( - f"Invalid app name {name}. App names may only contain alphanumeric characters, dashes, periods, and underscores, and must be less than 64 characters in length. " - ) + else: + check_object_name(name, "App") if client is None: client = await _Client.from_env() diff --git a/modal/secret.py b/modal/secret.py index 1099b1b67..29bd1391e 100644 --- a/modal/secret.py +++ b/modal/secret.py @@ -9,6 +9,7 @@ from ._resolver import Resolver from ._utils.async_utils import synchronize_api from ._utils.grpc_utils import retry_transient_errors +from ._utils.name_utils import check_object_name from .client import _Client from .exception import InvalidError, NotFoundError from .execution_context import is_local @@ -169,6 +170,10 @@ def run(): ... ``` """ + # Unlike other objects, you can't create secrets through this method, but we will still + # warn here so that people get the message when they *look up* secrets with illegal names. + # We can just remove the check after the deprecation period, instead of converting to an error. 
+ check_object_name(label, "Secret", warn=True) async def _load(self: _Secret, resolver: Resolver, existing_object_id: Optional[str]): req = api_pb2.SecretGetOrCreateRequest( diff --git a/modal/volume.py b/modal/volume.py index 23d5bad3f..ac4667916 100644 --- a/modal/volume.py +++ b/modal/volume.py @@ -39,6 +39,7 @@ get_file_upload_spec_from_path, ) from ._utils.grpc_utils import retry_transient_errors, unary_stream +from ._utils.name_utils import check_object_name from .client import _Client from .config import logger from .exception import deprecation_error @@ -181,6 +182,7 @@ def f(): pass ``` """ + check_object_name(label, "Volume", warn=True) async def _load(self: _Volume, resolver: Resolver, existing_object_id: Optional[str]): req = api_pb2.VolumeGetOrCreateRequest( diff --git a/tasks.py b/tasks.py index f62e79e14..4a9bdff37 100644 --- a/tasks.py +++ b/tasks.py @@ -47,11 +47,11 @@ def type_check(ctx): pyright_allowlist = [ "modal/functions.py", "modal/_utils/__init__.py", - "modal/_utils/app_utils.py", "modal/_utils/async_utils.py", "modal/_utils/grpc_testing.py", "modal/_utils/hash_utils.py", "modal/_utils/http_utils.py", + "modal/_utils/name_utils.py", "modal/_utils/logger.py", "modal/_utils/mount_utils.py", "modal/_utils/package_utils.py", @@ -312,7 +312,7 @@ def visit_Attribute(self, node): def visit_Call(self, node): func_name_to_level = { "deprecation_warning": "[yellow]warning[/yellow]", - "deprecation_error": "[red]error[/red]" + "deprecation_error": "[red]error[/red]", } if isinstance(node.func, ast.Name) and node.func.id in func_name_to_level: depr_date = date(*(elt.n for elt in node.args[0].elts)) diff --git a/test/dict_test.py b/test/dict_test.py index 4959beb87..ed5ed80a7 100644 --- a/test/dict_test.py +++ b/test/dict_test.py @@ -3,7 +3,7 @@ import time from modal import Dict -from modal.exception import NotFoundError +from modal.exception import DeprecationError, NotFoundError def test_dict_app(servicer, client): @@ -49,3 +49,9 @@ def 
test_dict_lazy_hydrate_named(set_env_client, servicer): d["foo"] = 42 assert d["foo"] == 42 assert len(ctx.get_requests("DictGetOrCreate")) == 1 # just sanity check that object is only hydrated once... + + +@pytest.mark.parametrize("name", ["has space", "has/slash", "a" * 65]) +def test_invalid_name(servicer, client, name): + with pytest.raises(DeprecationError, match="Invalid Dict name"): + Dict.lookup(name) diff --git a/test/network_file_system_test.py b/test/network_file_system_test.py index b8fb6bdb9..ee9ffb8ab 100644 --- a/test/network_file_system_test.py +++ b/test/network_file_system_test.py @@ -186,3 +186,9 @@ def test_nfs_lazy_hydration_from_name(set_env_client): nfs = modal.NetworkFileSystem.from_name("nfs", create_if_missing=True) bio = BytesIO(b"content") nfs.write_file("blah", bio) + + +@pytest.mark.parametrize("name", ["has space", "has/slash", "a" * 65]) +def test_invalid_name(servicer, client, name): + with pytest.raises(DeprecationError, match="Invalid NetworkFileSystem name"): + modal.NetworkFileSystem.lookup(name) diff --git a/test/queue_test.py b/test/queue_test.py index 27978c3ae..d3100c537 100644 --- a/test/queue_test.py +++ b/test/queue_test.py @@ -4,7 +4,7 @@ import time from modal import Queue -from modal.exception import NotFoundError +from modal.exception import DeprecationError, NotFoundError from .supports.skip import skip_macos, skip_windows @@ -113,3 +113,9 @@ def test_queue_lazy_hydrate_from_name(set_env_client): q = Queue.from_name("foo", create_if_missing=True) q.put(123) assert q.get() == 123 + + +@pytest.mark.parametrize("name", ["has space", "has/slash", "a" * 65]) +def test_invalid_name(servicer, client, name): + with pytest.raises(DeprecationError, match="Invalid Queue name"): + Queue.lookup(name) diff --git a/test/utils_test.py b/test/utils_test.py index 62e890150..05fbeca76 100644 --- a/test/utils_test.py +++ b/test/utils_test.py @@ -4,8 +4,9 @@ import io import pytest -from modal._utils.app_utils import is_valid_app_name, 
is_valid_subdomain_label from modal._utils.blob_utils import BytesIOSegmentPayload +from modal._utils.name_utils import check_object_name, is_valid_object_name, is_valid_subdomain_label +from modal.exception import DeprecationError, InvalidError def test_subdomain_label(): @@ -16,12 +17,16 @@ def test_subdomain_label(): assert not is_valid_subdomain_label("ban/ana") -def test_app_name(): - assert is_valid_app_name("baNaNa") - assert is_valid_app_name("foo-123_456") - assert is_valid_app_name("a" * 64) - assert not is_valid_app_name("hello world") - assert not is_valid_app_name("a" * 65) +def test_object_name(): + assert is_valid_object_name("baNaNa") + assert is_valid_object_name("foo-123_456") + assert is_valid_object_name("a" * 64) + assert not is_valid_object_name("hello world") + assert not is_valid_object_name("a" * 65) + with pytest.raises(InvalidError, match="Invalid Volume name: 'foo/bar'"): + check_object_name("foo/bar", "Volume") + with pytest.warns(DeprecationError, match="Invalid Volume name: 'foo/bar'"): + check_object_name("foo/bar", "Volume", warn=True) @pytest.mark.asyncio diff --git a/test/volume_test.py b/test/volume_test.py index 17064e9a1..700f9a273 100644 --- a/test/volume_test.py +++ b/test/volume_test.py @@ -395,3 +395,9 @@ async def test_open_files_error_annotation(tmp_path): proc.kill() await proc.wait() assert _open_files_error_annotation(tmp_path) is None + + +@pytest.mark.parametrize("name", ["has space", "has/slash", "a" * 65]) +def test_invalid_name(servicer, client, name): + with pytest.raises(DeprecationError, match="Invalid Volume name"): + modal.Volume.lookup(name)