Skip to content

Commit

Permalink
add ability to alias in the cloud registry and add k8s alias for kube…
Browse files Browse the repository at this point in the history
…rnetes
  • Loading branch information
cg505 committed Oct 22, 2024
1 parent 4fd94ab commit 8c7e681
Show file tree
Hide file tree
Showing 18 changed files with 60 additions and 36 deletions.
5 changes: 3 additions & 2 deletions sky/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ def get_cloud_tuple(
return repr(cloud_obj), cloud_obj

def get_all_clouds():
return tuple([repr(c) for c in sky_clouds.CLOUD_REGISTRY.values()] +
[cloudflare.SKY_CHECK_NAME])
return tuple(
[repr(c) for c in sky_clouds.CLOUD_REGISTRY.clouds.values()] +
[cloudflare.SKY_CHECK_NAME])

if clouds is not None:
cloud_list = clouds
Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class AWSIdentityType(enum.Enum):
SHARED_CREDENTIALS_FILE = 'shared-credentials-file'


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register()
class AWS(clouds.Cloud):
"""Amazon Web Services."""

Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _run_output(cmd):
return proc.stdout.decode('ascii')


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register()
class Azure(clouds.Cloud):
"""Azure."""

Expand Down
48 changes: 35 additions & 13 deletions sky/clouds/cloud_registry.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,53 @@
"""Clouds need to be registered in CLOUD_REGISTRY to be discovered"""

import typing
from typing import Optional, Type
from typing import Dict, Optional, Type, List, Callable

from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from sky.clouds import cloud


class _CloudRegistry(dict):
class _CloudRegistry:
"""Registry of clouds."""

def __init__(self):
self.clouds: Dict[str, 'cloud.Cloud'] = {}
self.aliases: Dict[str, str] = {}

def from_str(self, name: Optional[str]) -> Optional['cloud.Cloud']:
if name is None:
return None
if name.lower() not in self:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Cloud {name!r} is not a valid cloud among '
f'{list(self.keys())}')
return self.get(name.lower())

def register(self, cloud_cls: Type['cloud.Cloud']) -> Type['cloud.Cloud']:
name = cloud_cls.__name__.lower()
assert name not in self, f'{name} already registered'
self[name] = cloud_cls()
return cloud_cls
search_name = name.lower()

if search_name in self.clouds:
return self.clouds[search_name]

if search_name in self.aliases:
return self.clouds[self.aliases[search_name]]

with ux_utils.print_exception_no_traceback():
raise ValueError(f'Cloud {name!r} is not a valid cloud among '
f'{list(self.clouds.keys())}')

def register(
self,
aliases: List[str] = []
) -> Callable[[Type['cloud.Cloud']], Type['cloud.Cloud']]:

def _register(cloud_cls: Type['cloud.Cloud']) -> Type['cloud.Cloud']:
name = cloud_cls.__name__.lower()
assert name not in self.clouds, f'{name} already registered'
self.clouds[name] = cloud_cls()

for alias in aliases:
assert alias not in self.aliases, f'alias {alias} already registered'
self.aliases[alias] = name

return cloud_cls

return _register


CLOUD_REGISTRY: _CloudRegistry = _CloudRegistry()
2 changes: 1 addition & 1 deletion sky/clouds/cudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _run_output(cmd):
return proc.stdout.decode('ascii')


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register()
class Cudo(clouds.Cloud):
"""Cudo Compute"""
_REPR = 'Cudo'
Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/fluidstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from sky import resources as resources_lib


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register()
class Fluidstack(clouds.Cloud):
"""FluidStack GPU Cloud."""

Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class GCPIdentityType(enum.Enum):
SHARED_CREDENTIALS_FILE = ''


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register()
class GCP(clouds.Cloud):
"""Google Cloud Platform."""

Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/ibm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
logger = sky_logging.init_logger(__name__)


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register()
class IBM(clouds.Cloud):
"""IBM Web Services."""

Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
_SKYPILOT_SYSTEM_NAMESPACE = 'skypilot-system'


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register(aliases=["k8s"])
class Kubernetes(clouds.Cloud):
"""Kubernetes."""

Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/lambda_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
]


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register()
class Lambda(clouds.Cloud):
"""Lambda Labs GPU Cloud."""

Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/oci.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
_tenancy_prefix: Optional[str] = None


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register()
class OCI(clouds.Cloud):
"""OCI: Oracle Cloud Infrastructure """

Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/paperspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
]


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register()
class Paperspace(clouds.Cloud):
"""Paperspace GPU Cloud"""

Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/runpod.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
]


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register()
class RunPod(clouds.Cloud):
""" RunPod GPU Cloud
Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/scp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
_SCP_MAX_DISK_SIZE_GB = 300


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register()
class SCP(clouds.Cloud):
"""SCP Cloud."""

Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/vsphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
]


@clouds.CLOUD_REGISTRY.register
@clouds.CLOUD_REGISTRY.register()
class Vsphere(clouds.Cloud):
"""Vsphere cloud"""

Expand Down
2 changes: 1 addition & 1 deletion sky/utils/resources_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def need_to_query_reservations() -> bool:
This is useful to skip the potentially expensive reservation query for
clouds that do not use reservations.
"""
for cloud_str in cloud_registry.CLOUD_REGISTRY.keys():
for cloud_str in cloud_registry.CLOUD_REGISTRY.clouds.keys():
cloud_specific_reservations = skypilot_config.get_nested(
(cloud_str, 'specific_reservations'), None)
cloud_prioritize_reservations = skypilot_config.get_nested(
Expand Down
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def enable_all_clouds_in_monkeypatch(
# when the optimizer tries calling it to update enabled_clouds, it does not
# raise exceptions.
if enabled_clouds is None:
enabled_clouds = list(clouds.CLOUD_REGISTRY.values())
enabled_clouds = list(clouds.CLOUD_REGISTRY.clouds.values())
monkeypatch.setattr(
'sky.check.get_cached_enabled_clouds_or_refresh',
lambda *_args, **_kwargs: enabled_clouds,
Expand Down
13 changes: 7 additions & 6 deletions tests/test_optimizer_dryruns.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def _test_optimize_speed(resources: sky.Resources):

def test_optimize_speed(enable_all_clouds, monkeypatch):
_test_optimize_speed(sky.Resources(cpus=4))
for cloud in clouds.CLOUD_REGISTRY.values():
for cloud in clouds.CLOUD_REGISTRY.clouds.values():
_test_optimize_speed(sky.Resources(cloud, cpus='4+'))
_test_optimize_speed(sky.Resources(cpus='4+', memory='4+'))
_test_optimize_speed(
Expand Down Expand Up @@ -733,7 +733,7 @@ def test_ordered_resources(enable_all_clouds, monkeypatch):


def test_disk_tier_mismatch(enable_all_clouds):
for cloud in clouds.CLOUD_REGISTRY.values():
for cloud in clouds.CLOUD_REGISTRY.clouds.values():
for tier in cloud._SUPPORTED_DISK_TIERS:
sky.Resources(cloud=cloud, disk_tier=tier)
for unsupported_tier in (set(resources_utils.DiskTier) -
Expand All @@ -756,25 +756,26 @@ def _get_all_candidate_cloud(r: sky.Resources) -> Set[clouds.Cloud]:
best_tier_resources = sky.Resources(disk_tier=resources_utils.DiskTier.BEST)
best_tier_candidates = _get_all_candidate_cloud(best_tier_resources)
assert best_tier_candidates == set(
clouds.CLOUD_REGISTRY.values()), best_tier_candidates
clouds.CLOUD_REGISTRY.clouds.values()), best_tier_candidates

# Only AWS, GCP, Azure, OCI supports LOW disk tier.
low_tier_resources = sky.Resources(disk_tier=resources_utils.DiskTier.LOW)
low_tier_candidates = _get_all_candidate_cloud(low_tier_resources)
assert low_tier_candidates == set(
map(clouds.CLOUD_REGISTRY.get,
map(clouds.CLOUD_REGISTRY.clouds.get,
['aws', 'gcp', 'azure', 'oci'])), low_tier_candidates

# Only AWS, GCP, Azure, OCI supports HIGH disk tier.
high_tier_resources = sky.Resources(disk_tier=resources_utils.DiskTier.HIGH)
high_tier_candidates = _get_all_candidate_cloud(high_tier_resources)
assert high_tier_candidates == set(
map(clouds.CLOUD_REGISTRY.get,
map(clouds.CLOUD_REGISTRY.clouds.get,
['aws', 'gcp', 'azure', 'oci'])), high_tier_candidates

# Only AWS, GCP supports ULTRA disk tier.
ultra_tier_resources = sky.Resources(
disk_tier=resources_utils.DiskTier.ULTRA)
ultra_tier_candidates = _get_all_candidate_cloud(ultra_tier_resources)
assert ultra_tier_candidates == set(
map(clouds.CLOUD_REGISTRY.get, ['aws', 'gcp'])), ultra_tier_candidates
map(clouds.CLOUD_REGISTRY.clouds.get,
['aws', 'gcp'])), ultra_tier_candidates

0 comments on commit 8c7e681

Please sign in to comment.