Skip to content

Commit

Permalink
[aws] catch the case where get_credentials is None (#4613)
Browse files Browse the repository at this point in the history
* [aws] catch the case where get_credentials is None

* fix test

* speed up sky check

* Update sky/adaptors/aws.py

Co-authored-by: Romil Bhardwaj <[email protected]>

---------

Co-authored-by: Romil Bhardwaj <[email protected]>
  • Loading branch information
cg505 and romilbhardwaj authored Jan 29, 2025
1 parent e5c999b commit a80208f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 9 deletions.
26 changes: 21 additions & 5 deletions sky/adaptors/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,17 @@ def _create_aws_object(creation_fn_or_cls: Callable[[], Any],
# The LRU cache needs to be thread-local to avoid multiple threads sharing the
# same session object, which is not guaranteed to be thread-safe.
@_thread_local_lru_cache()
def session():
def session(check_credentials: bool = True):
"""Create an AWS session."""
return _create_aws_object(boto3.session.Session, 'session')
s = _create_aws_object(boto3.session.Session, 'session')
if check_credentials and s.get_credentials() is None:
# s.get_credentials() can be None if there are actually no credentials,
# or if we fail to get credentials from IMDS (e.g. due to throttling).
# Technically, it could be okay to have no credentials, as certain AWS
# APIs don't actually need them. But as far as we know, everything we use
# AWS for needs credentials.
raise botocore_exceptions().NoCredentialsError()
return s


# Avoid caching the resource/client objects. If we are using the assumed role,
Expand All @@ -149,11 +157,15 @@ def resource(service_name: str, **kwargs):
config = botocore_config().Config(
retries={'max_attempts': max_attempts})
kwargs['config'] = config

check_credentials = kwargs.pop('check_credentials', True)

# Need to use the resource retrieved from the per-thread session to avoid
# thread-safety issues (directly creating the resource with boto3.resource()
# is not thread-safe). Reference: https://stackoverflow.com/a/59635814
return _create_aws_object(
lambda: session().resource(service_name, **kwargs), 'resource')
lambda: session(check_credentials=check_credentials).resource(
service_name, **kwargs), 'resource')


def client(service_name: str, **kwargs):
Expand All @@ -164,12 +176,16 @@ def client(service_name: str, **kwargs):
kwargs: Other options.
"""
_assert_kwargs_builtin_type(kwargs)

check_credentials = kwargs.pop('check_credentials', True)

# Need to use the client retrieved from the per-thread session to avoid
# thread-safety issues (Directly creating the client with boto3.client() is
# not thread-safe). Reference: https://stackoverflow.com/a/59635814

return _create_aws_object(lambda: session().client(service_name, **kwargs),
'client')
return _create_aws_object(
lambda: session(check_credentials=check_credentials).client(
service_name, **kwargs), 'client')


@common.load_lazy_modules(modules=_LAZY_MODULES)
Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ def _aws_configure_list(cls) -> Optional[bytes]:
@functools.lru_cache(maxsize=1) # Cache since getting identity is slow.
def _sts_get_caller_identity(cls) -> Optional[List[List[str]]]:
try:
sts = aws.client('sts')
sts = aws.client('sts', check_credentials=False)
# The caller identity contains 3 fields: UserId, Account, Arn.
# 1. 'UserId' is unique across all AWS entities, and looks like
# "AROADBQP57FF2AEXAMPLE:role-session-name"
Expand Down
5 changes: 3 additions & 2 deletions sky/provision/aws/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
# https://aws.amazon.com/ec2/pricing/on-demand/#Data_Transfer_within_the_same_AWS_Region


def _default_ec2_resource(region: str) -> Any:
def _default_ec2_resource(region: str, check_credentials: bool = True) -> Any:
if not hasattr(aws, 'version'):
# For backward compatibility, reload the module if the aws module was
# imported before and is stale. Used for, e.g., a live jobs controller
Expand Down Expand Up @@ -95,7 +95,8 @@ def _default_ec2_resource(region: str) -> Any:
importlib.reload(aws)
return aws.resource('ec2',
region_name=region,
max_attempts=BOTO_MAX_RETRIES)
max_attempts=BOTO_MAX_RETRIES,
check_credentials=check_credentials)


def _cluster_name_filter(cluster_name_on_cloud: str) -> List[Dict[str, Any]]:
Expand Down
3 changes: 2 additions & 1 deletion tests/unit_tests/test_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def test_aws_adaptor_resources_memory_leakage():
timeout=1)[0]
total_num = int(1e3)
for i in range(total_num):
instance._default_ec2_resource(aws_regions[i % len(aws_regions)])
instance._default_ec2_resource(aws_regions[i % len(aws_regions)],
check_credentials=False)
if math.log10(i + 1).is_integer():
print(i)
mem_usage_after = memory_profiler.memory_usage(os.getpid(),
Expand Down

0 comments on commit a80208f

Please sign in to comment.