From 8c7e681b2eeb2316e1844ad9457b7818db5389b7 Mon Sep 17 00:00:00 2001 From: Christopher Cooper Date: Tue, 22 Oct 2024 15:57:15 -0700 Subject: [PATCH] add ability to alias in the cloud registry and add k8s alias for kubernetes --- sky/check.py | 5 ++-- sky/clouds/aws.py | 2 +- sky/clouds/azure.py | 2 +- sky/clouds/cloud_registry.py | 48 ++++++++++++++++++++++++--------- sky/clouds/cudo.py | 2 +- sky/clouds/fluidstack.py | 2 +- sky/clouds/gcp.py | 2 +- sky/clouds/ibm.py | 2 +- sky/clouds/kubernetes.py | 2 +- sky/clouds/lambda_cloud.py | 2 +- sky/clouds/oci.py | 2 +- sky/clouds/paperspace.py | 2 +- sky/clouds/runpod.py | 2 +- sky/clouds/scp.py | 2 +- sky/clouds/vsphere.py | 2 +- sky/utils/resources_utils.py | 2 +- tests/common.py | 2 +- tests/test_optimizer_dryruns.py | 13 ++++----- 18 files changed, 60 insertions(+), 36 deletions(-) diff --git a/sky/check.py b/sky/check.py index 9ac2848733c..755b80d9ff0 100644 --- a/sky/check.py +++ b/sky/check.py @@ -65,8 +65,9 @@ def get_cloud_tuple( return repr(cloud_obj), cloud_obj def get_all_clouds(): - return tuple([repr(c) for c in sky_clouds.CLOUD_REGISTRY.values()] + - [cloudflare.SKY_CHECK_NAME]) + return tuple( + [repr(c) for c in sky_clouds.CLOUD_REGISTRY.clouds.values()] + + [cloudflare.SKY_CHECK_NAME]) if clouds is not None: cloud_list = clouds diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index a0962b17cac..927ed67ac3b 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -102,7 +102,7 @@ class AWSIdentityType(enum.Enum): SHARED_CREDENTIALS_FILE = 'shared-credentials-file' -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register() class AWS(clouds.Cloud): """Amazon Web Services.""" diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index adffd32ad88..a1fd605f565 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -47,7 +47,7 @@ def _run_output(cmd): return proc.stdout.decode('ascii') -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register() class Azure(clouds.Cloud): """Azure.""" diff --git a/sky/clouds/cloud_registry.py b/sky/clouds/cloud_registry.py index 5c4b10b9fd4..f270306c2e0 100644 --- a/sky/clouds/cloud_registry.py +++ b/sky/clouds/cloud_registry.py @@ -1,7 +1,7 @@ """Clouds need to be registered in CLOUD_REGISTRY to be discovered""" import typing -from typing import Optional, Type +from typing import Dict, Optional, Type, List, Callable from sky.utils import ux_utils @@ -9,23 +9,45 @@ from sky.clouds import cloud -class _CloudRegistry(dict): +class _CloudRegistry: """Registry of clouds.""" + def __init__(self): + self.clouds: Dict[str, 'cloud.Cloud'] = {} + self.aliases: Dict[str, str] = {} + def from_str(self, name: Optional[str]) -> Optional['cloud.Cloud']: if name is None: return None - if name.lower() not in self: - with ux_utils.print_exception_no_traceback(): - raise ValueError(f'Cloud {name!r} is not a valid cloud among ' - f'{list(self.keys())}') - return self.get(name.lower()) - - def register(self, cloud_cls: Type['cloud.Cloud']) -> Type['cloud.Cloud']: - name = cloud_cls.__name__.lower() - assert name not in self, f'{name} already registered' - self[name] = cloud_cls() - return cloud_cls + search_name = name.lower() + + if search_name in self.clouds: + return self.clouds[search_name] + + if search_name in self.aliases: + return self.clouds[self.aliases[search_name]] + + with ux_utils.print_exception_no_traceback(): + raise ValueError(f'Cloud {name!r} is not a valid cloud among ' + f'{list(self.clouds.keys())}') + + def register( + self, + aliases: List[str] = [] + ) -> Callable[[Type['cloud.Cloud']], Type['cloud.Cloud']]: + + def _register(cloud_cls: Type['cloud.Cloud']) -> Type['cloud.Cloud']: + name = cloud_cls.__name__.lower() + assert name not in self.clouds, f'{name} already registered' + self.clouds[name] = cloud_cls() + + for alias in aliases: + assert alias not in self.aliases, f'alias {alias} already registered' + self.aliases[alias] = name + + return cloud_cls + + return _register CLOUD_REGISTRY: _CloudRegistry = _CloudRegistry() diff --git a/sky/clouds/cudo.py b/sky/clouds/cudo.py index 4dca442fa01..558724df002 100644 --- a/sky/clouds/cudo.py +++ b/sky/clouds/cudo.py @@ -28,7 +28,7 @@ def _run_output(cmd): return proc.stdout.decode('ascii') -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register() class Cudo(clouds.Cloud): """Cudo Compute""" _REPR = 'Cudo' diff --git a/sky/clouds/fluidstack.py b/sky/clouds/fluidstack.py index 473fceabbe3..2e50e2096b3 100644 --- a/sky/clouds/fluidstack.py +++ b/sky/clouds/fluidstack.py @@ -22,7 +22,7 @@ from sky import resources as resources_lib -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register() class Fluidstack(clouds.Cloud): """FluidStack GPU Cloud.""" diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 1b70abf914d..f8fac6ea508 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -133,7 +133,7 @@ class GCPIdentityType(enum.Enum): SHARED_CREDENTIALS_FILE = '' -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register() class GCP(clouds.Cloud): """Google Cloud Platform.""" diff --git a/sky/clouds/ibm.py b/sky/clouds/ibm.py index b78cc4287c0..c59fc470e4f 100644 --- a/sky/clouds/ibm.py +++ b/sky/clouds/ibm.py @@ -22,7 +22,7 @@ logger = sky_logging.init_logger(__name__) -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register() class IBM(clouds.Cloud): """IBM Web Services.""" diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index da85246e9ea..d2a01639870 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -33,7 +33,7 @@ _SKYPILOT_SYSTEM_NAMESPACE = 'skypilot-system' -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register(aliases=["k8s"]) class Kubernetes(clouds.Cloud): """Kubernetes.""" diff --git a/sky/clouds/lambda_cloud.py b/sky/clouds/lambda_cloud.py index 0201f4f76ad..78af1e9bce2 100644 --- a/sky/clouds/lambda_cloud.py +++ b/sky/clouds/lambda_cloud.py @@ -21,7 +21,7 @@ ] -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register() class Lambda(clouds.Cloud): """Lambda Labs GPU Cloud.""" diff --git a/sky/clouds/oci.py b/sky/clouds/oci.py index 810e43fe3b5..4b1ff49b3cc 100644 --- a/sky/clouds/oci.py +++ b/sky/clouds/oci.py @@ -45,7 +45,7 @@ _tenancy_prefix: Optional[str] = None -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register() class OCI(clouds.Cloud): """OCI: Oracle Cloud Infrastructure """ diff --git a/sky/clouds/paperspace.py b/sky/clouds/paperspace.py index 4c4fa1d695a..d7b5fa716c3 100644 --- a/sky/clouds/paperspace.py +++ b/sky/clouds/paperspace.py @@ -20,7 +20,7 @@ ] -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register() class Paperspace(clouds.Cloud): """Paperspace GPU Cloud""" diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 6cfdf11c6b4..1a38d5d266f 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -16,7 +16,7 @@ ] -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register() class RunPod(clouds.Cloud): """ RunPod GPU Cloud diff --git a/sky/clouds/scp.py b/sky/clouds/scp.py index 17a54ce1607..67975ee8418 100644 --- a/sky/clouds/scp.py +++ b/sky/clouds/scp.py @@ -30,7 +30,7 @@ _SCP_MAX_DISK_SIZE_GB = 300 -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register() class SCP(clouds.Cloud): """SCP Cloud.""" diff --git a/sky/clouds/vsphere.py b/sky/clouds/vsphere.py index 7cf56b46a8d..4306fe6616a 100644 --- a/sky/clouds/vsphere.py +++ b/sky/clouds/vsphere.py @@ -25,7 +25,7 @@ ] -@clouds.CLOUD_REGISTRY.register +@clouds.CLOUD_REGISTRY.register() class Vsphere(clouds.Cloud): """Vsphere cloud""" diff --git a/sky/utils/resources_utils.py b/sky/utils/resources_utils.py index 72aa5ac05d3..e9d47696907 100644 --- a/sky/utils/resources_utils.py +++ b/sky/utils/resources_utils.py @@ -191,7 +191,7 @@ def need_to_query_reservations() -> bool: This is useful to skip the potentially expensive reservation query for clouds that do not use reservations. """ - for cloud_str in cloud_registry.CLOUD_REGISTRY.keys(): + for cloud_str in cloud_registry.CLOUD_REGISTRY.clouds.keys(): cloud_specific_reservations = skypilot_config.get_nested( (cloud_str, 'specific_reservations'), None) cloud_prioritize_reservations = skypilot_config.get_nested( diff --git a/tests/common.py b/tests/common.py index c6f08588d99..856f3da1fb9 100644 --- a/tests/common.py +++ b/tests/common.py @@ -20,7 +20,7 @@ def enable_all_clouds_in_monkeypatch( # when the optimizer tries calling it to update enabled_clouds, it does not # raise exceptions. if enabled_clouds is None: - enabled_clouds = list(clouds.CLOUD_REGISTRY.values()) + enabled_clouds = list(clouds.CLOUD_REGISTRY.clouds.values()) monkeypatch.setattr( 'sky.check.get_cached_enabled_clouds_or_refresh', lambda *_args, **_kwargs: enabled_clouds, diff --git a/tests/test_optimizer_dryruns.py b/tests/test_optimizer_dryruns.py index f1af9a0d9ee..01a99aa6fff 100644 --- a/tests/test_optimizer_dryruns.py +++ b/tests/test_optimizer_dryruns.py @@ -646,7 +646,7 @@ def _test_optimize_speed(resources: sky.Resources): def test_optimize_speed(enable_all_clouds, monkeypatch): _test_optimize_speed(sky.Resources(cpus=4)) - for cloud in clouds.CLOUD_REGISTRY.values(): + for cloud in clouds.CLOUD_REGISTRY.clouds.values(): _test_optimize_speed(sky.Resources(cloud, cpus='4+')) _test_optimize_speed(sky.Resources(cpus='4+', memory='4+')) _test_optimize_speed( @@ -733,7 +733,7 @@ def test_ordered_resources(enable_all_clouds, monkeypatch): def test_disk_tier_mismatch(enable_all_clouds): - for cloud in clouds.CLOUD_REGISTRY.values(): + for cloud in clouds.CLOUD_REGISTRY.clouds.values(): for tier in cloud._SUPPORTED_DISK_TIERS: sky.Resources(cloud=cloud, disk_tier=tier) for unsupported_tier in (set(resources_utils.DiskTier) - @@ -756,20 +756,20 @@ def _get_all_candidate_cloud(r: sky.Resources) -> Set[clouds.Cloud]: best_tier_resources = sky.Resources(disk_tier=resources_utils.DiskTier.BEST) best_tier_candidates = _get_all_candidate_cloud(best_tier_resources) assert best_tier_candidates == set( - clouds.CLOUD_REGISTRY.values()), best_tier_candidates + clouds.CLOUD_REGISTRY.clouds.values()), best_tier_candidates # Only AWS, GCP, Azure, OCI supports LOW disk tier. low_tier_resources = sky.Resources(disk_tier=resources_utils.DiskTier.LOW) low_tier_candidates = _get_all_candidate_cloud(low_tier_resources) assert low_tier_candidates == set( - map(clouds.CLOUD_REGISTRY.get, + map(clouds.CLOUD_REGISTRY.clouds.get, ['aws', 'gcp', 'azure', 'oci'])), low_tier_candidates # Only AWS, GCP, Azure, OCI supports HIGH disk tier. high_tier_resources = sky.Resources(disk_tier=resources_utils.DiskTier.HIGH) high_tier_candidates = _get_all_candidate_cloud(high_tier_resources) assert high_tier_candidates == set( - map(clouds.CLOUD_REGISTRY.get, + map(clouds.CLOUD_REGISTRY.clouds.get, ['aws', 'gcp', 'azure', 'oci'])), high_tier_candidates # Only AWS, GCP supports ULTRA disk tier. @@ -777,4 +777,5 @@ def _get_all_candidate_cloud(r: sky.Resources) -> Set[clouds.Cloud]: disk_tier=resources_utils.DiskTier.ULTRA) ultra_tier_candidates = _get_all_candidate_cloud(ultra_tier_resources) assert ultra_tier_candidates == set( - map(clouds.CLOUD_REGISTRY.get, ['aws', 'gcp'])), ultra_tier_candidates + map(clouds.CLOUD_REGISTRY.clouds.get, + ['aws', 'gcp'])), ultra_tier_candidates