Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add back the SharedTokenCacheCredential to handle token which is cached by the InteractiveBrowserCredential #603

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
87 changes: 79 additions & 8 deletions azure-quantum/azure/quantum/_authentication/_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import sys
import logging
import re
from typing import Optional
Expand All @@ -16,10 +17,13 @@
InteractiveBrowserCredential,
DeviceCodeCredential,
_internal as AzureIdentityInternals,
TokenCachePersistenceOptions,
SharedTokenCacheCredential,
_persistent_cache as AzureIdentityPersistentCache
)
from azure.quantum._constants import ConnectionConstants
from ._chained import _ChainedTokenCredential
from ._token import _TokenFileCredential
from azure.quantum._constants import ConnectionConstants

_LOGGER = logging.getLogger(__name__)
WWW_AUTHENTICATE_REGEX = re.compile(
Expand Down Expand Up @@ -61,7 +65,7 @@ def __init__(
client_id: Optional[str] = None,
tenant_id: Optional[str] = None,
authority: Optional[str] = None,
):
) -> None:
if arm_endpoint is None:
raise ValueError("arm_endpoint is mandatory parameter")
if subscription_id is None:
Expand All @@ -84,22 +88,90 @@ def _authority_or_default(self, authority: str, arm_endpoint: str):
return ConnectionConstants.DOGFOOD_AUTHORITY
return ConnectionConstants.AUTHORITY

def _initialize_credentials(self):
def _get_cache_options(self) -> Optional[TokenCachePersistenceOptions]:
"""
Returns a valid TokenCachePersistenceOptions
if the AzureIdentity Persistent Cache is accessible.
Returns None otherwise.
"""
cache_options = TokenCachePersistenceOptions(
allow_unencrypted_storage=False,
name="AzureQuantumSDK"
)
try:
# pylint: disable=protected-access
cache = AzureIdentityPersistentCache._load_persistent_cache(cache_options)
try:
# Try to get the location of the cache for
# tracing purpose.
_LOGGER.info(
"Using Azure.Identity Token Cache at %s.",
cache._persistence.get_location()
)
except: # pylint: disable=bare-except
_LOGGER.info("Using Azure.Identity Token Cache.")
return cache_options
except Exception as ex: # pylint: disable=broad-except
# Check if the cache issue on linux is due
# libsecret not functioning to provider better
# information to the user.
if sys.platform.startswith("linux"):
try:
# pylint: disable=import-outside-toplevel
from msal_extensions.libsecret import trial_run
trial_run()
except Exception as libsecret_ex: # pylint: disable=broad-except
_LOGGER.warning(
"libsecret dependencies are not installed or are unusable.\n"
"Please install the necessary dependencies as instructed in "
"https://github.com/AzureAD/microsoft-authentication-extensions-for-python/wiki/Encryption-on-Linux" # pylint: disable=line-too-long
"Exception:\n%s",
libsecret_ex,
exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
)

_LOGGER.warning(
'Error trying to access Azure.Identity Token Cache. '
"Raised unexpected exception:\n%s",
ex,
exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
)
return None

def _initialize_credentials(self) -> None:
self._discover_tenant_id_(
arm_endpoint=self.arm_endpoint,
subscription_id=self.subscription_id)
cache_options = self._get_cache_options()
credentials = []
credentials.append(_TokenFileCredential())
credentials.append(EnvironmentCredential())
if self.client_id:
credentials.append(ManagedIdentityCredential(client_id=self.client_id))
if self.authority and self.tenant_id:
credentials.append(VisualStudioCodeCredential(authority=self.authority, tenant_id=self.tenant_id))
credentials.append(VisualStudioCodeCredential(
authority=self.authority,
tenant_id=self.tenant_id))
credentials.append(AzureCliCredential(tenant_id=self.tenant_id))
credentials.append(AzurePowerShellCredential(tenant_id=self.tenant_id))
credentials.append(InteractiveBrowserCredential(authority=self.authority, tenant_id=self.tenant_id))
# The SharedTokenCacheCredential is used when the token cache
# is available to attempt loading a token stored in the cache
# by the InteractiveBrowserCredential.
if cache_options:
credentials.append(SharedTokenCacheCredential(
authority=self.authority,
Copy link
Contributor

@kikomiss kikomiss Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we might also need to pass tenant_id=self.tenant_id to here. It will allow the filtering part to pick-up accounts only the discovered tenant

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor Author

@ArthurKamalov ArthurKamalov Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I mentioned in this comment #603 (comment), on Linux I have a mismatch of a current tenant id and id from the cache, which leads to exception in the SharedTokenCacheCredential. Still trying to find the cause, still not sure if that's only on my machine or not.

tenant_id=self.tenant_id,
cache_persistence_options=cache_options))
credentials.append(
InteractiveBrowserCredential(
authority=self.authority,
tenant_id=self.tenant_id,
cache_persistence_options=cache_options))
if self.client_id:
credentials.append(DeviceCodeCredential(authority=self.authority, client_id=self.client_id, tenant_id=self.tenant_id))
credentials.append(DeviceCodeCredential(
authority=self.authority,
client_id=self.client_id,
tenant_id=self.tenant_id))
self.credentials = credentials

def get_token(self, *scopes: str, **kwargs) -> AccessToken:
Expand Down Expand Up @@ -145,8 +217,7 @@ def _discover_tenant_id_(self, arm_endpoint:str, subscription_id:str):
match = re.search(WWW_AUTHENTICATE_REGEX, www_authenticate)
if match:
self.tenant_id = match.group("tenant_id")
# pylint: disable=broad-exception-caught
except Exception as ex:
except Exception as ex: # pylint: disable=broad-exception-caught
_LOGGER.error(ex)

# apply default values
Expand Down
Loading