Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add type hints to credentials #1605

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
7 changes: 4 additions & 3 deletions google/auth/_credentials_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
"""Interface for base credentials."""

import abc
from typing import Optional

from google.auth import _helpers


class _BaseCredentials(metaclass=abc.ABCMeta):
class BaseCredentials(metaclass=abc.ABCMeta):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we want to change this to a public class?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_BaseCredentials is imported from outside the _credentials_base.py file, so it's not private in that scope. You can see that it's captured as an error by running:

pyright google/auth/credentials.py

"""Base class for all credentials.

All credentials have a :attr:`token` that is used for authentication and
Expand All @@ -44,7 +45,7 @@ class _BaseCredentials(metaclass=abc.ABCMeta):
"""

def __init__(self):
self.token = None
self.token: Optional[str] = None

@abc.abstractmethod
def refresh(self, request):
Expand All @@ -62,7 +63,7 @@ def refresh(self, request):
# (pylint doesn't recognize that this is abstract)
raise NotImplementedError("Refresh must be implemented")

def _apply(self, headers, token=None):
def _apply(self, headers: dict[str, str], token: Optional[str] = None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we use Mapping from typing (here and in other places)? Also update the docstring to Mapping[str, str].

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This operation requires the type of headers has __setitem__ method that's why being a Mapping is not enough:

headers["authorization"] = "Bearer {}".format(
    _helpers.from_bytes(token or self.token)
)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want me to modify the doctring instead?

"""Apply the token to the authentication header.

Args:
Expand Down
8 changes: 5 additions & 3 deletions google/auth/_refresh_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import logging
import threading

from google.auth.credentials import Credentials
import google.auth.exceptions as e
from google.auth.transport import Request

_LOGGER = logging.getLogger(__name__)

Expand All @@ -32,7 +34,7 @@ def __init__(self):
self._worker = None
self._lock = threading.Lock() # protects access to worker threads.

def start_refresh(self, cred, request):
def start_refresh(self, cred: Credentials, request: Request):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def start_refresh(self, cred: Credentials, request: Request):
def start_refresh(self, cred: Credentials, request: Request): -> bool

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""Starts a refresh thread for the given credentials.
The credentials are refreshed using the request parameter.
request and cred MUST not be None
Expand Down Expand Up @@ -61,8 +63,8 @@ def start_refresh(self, cred, request):

def clear_error(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def clear_error(self):
def clear_error(self): -> None

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""
Removes any errors that were stored from previous background refreshes.
"""
Removes any errors that were stored from previous background refreshes.
"""
with self._lock:
if self._worker:
self._worker._error_info = None
Expand Down
5 changes: 2 additions & 3 deletions google/auth/aio/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@

"""Interfaces for asynchronous credentials."""


from google.auth import _helpers
from google.auth import exceptions
from google.auth._credentials_base import _BaseCredentials
from google.auth._credentials_base import BaseCredentials


class Credentials(_BaseCredentials):
class Credentials(BaseCredentials):
"""Base class for all asynchronous credentials.

All credentials have a :attr:`token` that is used for authentication and
Expand Down
74 changes: 45 additions & 29 deletions google/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,21 @@
import abc
from enum import Enum
import os
from typing import Optional, Self, Sequence

from google.auth import _helpers, environment_vars
from google.auth import exceptions
from google.auth import metrics
from google.auth._credentials_base import _BaseCredentials
from google.auth._credentials_base import BaseCredentials
from google.auth._refresh_worker import RefreshThreadManager
from google.auth.credentials import Credentials
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're importing from the file itself? This doesn't seem right.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from google.auth.credentials import Credentials

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yeah, sorry you're right

from google.auth.crypt import Signer
from google.auth.transport import Request

DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"


class Credentials(_BaseCredentials):
class Credentials(BaseCredentials):
"""Base class for all credentials.

All credentials have a :attr:`token` that is used for authentication and
Expand Down Expand Up @@ -67,7 +71,7 @@ def __init__(self):
self._refresh_worker = RefreshThreadManager()

@property
def expired(self):
def expired(self) -> bool:
"""Checks if the credentials are expired.

Note that credentials can be invalid but not expired because
Expand All @@ -85,7 +89,7 @@ def expired(self):
return _helpers.utcnow() >= skewed_expiry

@property
def valid(self):
def valid(self) -> bool:
"""Checks the validity of the credentials.

This is True if the credentials have a :attr:`token` and the token
Expand Down Expand Up @@ -140,7 +144,7 @@ def get_cred_info(self):
return None

@abc.abstractmethod
def refresh(self, request):
def refresh(self, request: Request) -> None:
"""Refreshes the access token.

Args:
Expand Down Expand Up @@ -170,7 +174,7 @@ def _metric_header_for_usage(self):
"""
return None

def apply(self, headers, token=None):
def apply(self, headers: dict[str, str], token: Optional[str] = None):
"""Apply the token to the authentication header.

Args:
Expand All @@ -197,11 +201,11 @@ def apply(self, headers, token=None):
if self.quota_project_id:
headers["x-goog-user-project"] = self.quota_project_id

def _blocking_refresh(self, request):
def _blocking_refresh(self, request: Request):
if not self.valid:
self.refresh(request)

def _non_blocking_refresh(self, request):
def _non_blocking_refresh(self, request: Request):
use_blocking_refresh_fallback = False

if self.token_state == TokenState.STALE:
Expand All @@ -216,7 +220,9 @@ def _non_blocking_refresh(self, request):
# background thread.
self._refresh_worker.clear_error()

def before_request(self, request, method, url, headers):
def before_request(
self, request: Request, method: str, url: str, headers: dict[str, str]
):
"""Performs credential-specific before request logic.

Refreshes the credentials if necessary, then calls :meth:`apply` to
Expand Down Expand Up @@ -248,7 +254,7 @@ def with_non_blocking_refresh(self):
class CredentialsWithQuotaProject(Credentials):
"""Abstract base for credentials supporting ``with_quota_project`` factory"""

def with_quota_project(self, quota_project_id):
def with_quota_project(self, quota_project_id: str) -> Self:
"""Returns a copy of these credentials with a modified quota project.

Args:
Expand All @@ -260,7 +266,7 @@ def with_quota_project(self, quota_project_id):
"""
raise NotImplementedError("This credential does not support quota project.")

def with_quota_project_from_environment(self):
def with_quota_project_from_environment(self) -> Self:
quota_from_env = os.environ.get(environment_vars.GOOGLE_CLOUD_QUOTA_PROJECT)
if quota_from_env:
return self.with_quota_project(quota_from_env)
Expand All @@ -270,7 +276,7 @@ def with_quota_project_from_environment(self):
class CredentialsWithTokenUri(Credentials):
"""Abstract base for credentials supporting ``with_token_uri`` factory"""

def with_token_uri(self, token_uri):
def with_token_uri(self, token_uri: str) -> Self:
"""Returns a copy of these credentials with a modified token uri.

Args:
Expand All @@ -285,7 +291,7 @@ def with_token_uri(self, token_uri):
class CredentialsWithUniverseDomain(Credentials):
"""Abstract base for credentials supporting ``with_universe_domain`` factory"""

def with_universe_domain(self, universe_domain):
def with_universe_domain(self, universe_domain: str) -> Self:
"""Returns a copy of these credentials with a modified universe domain.

Args:
Expand All @@ -307,21 +313,21 @@ class AnonymousCredentials(Credentials):
"""

@property
def expired(self):
def expired(self) -> bool:
"""Returns `False`, anonymous credentials never expire."""
return False

@property
def valid(self):
def valid(self) -> bool:
"""Returns `True`, anonymous credentials are always valid."""
return True

def refresh(self, request):
def refresh(self, request: Request):
"""Raises :class:``InvalidOperation``, anonymous credentials cannot be
refreshed."""
raise exceptions.InvalidOperation("Anonymous credentials cannot be refreshed.")

def apply(self, headers, token=None):
def apply(self, headers: dict[str, str], token: Optional[str] = None):
"""Anonymous credentials do nothing to the request.

The optional ``token`` argument is not supported.
Expand All @@ -332,7 +338,9 @@ def apply(self, headers, token=None):
if token is not None:
raise exceptions.InvalidValue("Anonymous credentials don't support tokens.")

def before_request(self, request, method, url, headers):
def before_request(
self, request: Request, method: str, url: str, headers: dict[str, str]
):
"""Anonymous credentials do nothing to the request."""


Expand Down Expand Up @@ -380,13 +388,13 @@ def default_scopes(self):
"""Sequence[str]: the credentials' current set of default scopes."""
return self._default_scopes

@abc.abstractproperty
@property
@abc.abstractmethod
def requires_scopes(self):
"""True if these credentials require scopes to obtain an access token.
"""
"""True if these credentials require scopes to obtain an access token."""
return False

def has_scopes(self, scopes):
def has_scopes(self, scopes: Sequence[str]) -> bool:
"""Checks if the credentials have the given scopes.

.. warning: This method is not guaranteed to be accurate if the
Expand Down Expand Up @@ -434,7 +442,9 @@ class Scoped(ReadOnlyScoped):
"""

@abc.abstractmethod
def with_scopes(self, scopes, default_scopes=None):
def with_scopes(
self, scopes: Sequence[str], default_scopes: Optional[Sequence[str]] = None
) -> Self:
"""Create a copy of these credentials with the specified scopes.

Args:
Expand All @@ -449,7 +459,11 @@ def with_scopes(self, scopes, default_scopes=None):
raise NotImplementedError("This class does not require scoping.")


def with_scopes_if_required(credentials, scopes, default_scopes=None):
def with_scopes_if_required(
credentials: Credentials,
scopes: Sequence[str],
default_scopes: Optional[Sequence[str]] = None,
) -> Credentials:
"""Creates a copy of the credentials with scopes if scoping is required.

This helper function is useful when you do not know (or care to know) the
Expand Down Expand Up @@ -481,7 +495,7 @@ class Signing(metaclass=abc.ABCMeta):
"""Interface for credentials that can cryptographically sign messages."""

@abc.abstractmethod
def sign_bytes(self, message):
def sign_bytes(self, message: bytes) -> bytes:
"""Signs the given message.

Args:
Expand All @@ -494,15 +508,17 @@ def sign_bytes(self, message):
# (pylint doesn't recognize that this is abstract)
raise NotImplementedError("Sign bytes must be implemented.")

@abc.abstractproperty
def signer_email(self):
@property
@abc.abstractmethod
def signer_email(self) -> Optional[str]:
"""Optional[str]: An email address that identifies the signer."""
# pylint: disable=missing-raises-doc
# (pylint doesn't recognize that this is abstract)
raise NotImplementedError("Signer email must be implemented.")

@abc.abstractproperty
def signer(self):
@property
@abc.abstractmethod
def signer(self) -> Signer:
"""google.auth.crypt.Signer: The signer used to sign bytes."""
# pylint: disable=missing-raises-doc
# (pylint doesn't recognize that this is abstract)
Expand Down
7 changes: 5 additions & 2 deletions google/auth/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

""" We use x-goog-api-client header to report metrics. This module provides
"""We use x-goog-api-client header to report metrics. This module provides
the constants and helper methods to construct x-goog-api-client header.
"""

import platform
from typing import Mapping, Optional

from google.auth import version

Expand Down Expand Up @@ -48,6 +49,7 @@ def python_and_auth_lib_version():

# Token request metric header values


# x-goog-api-client header value for access token request via metadata server.
# Example: "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/mds"
def token_request_access_token_mds():
Expand Down Expand Up @@ -108,6 +110,7 @@ def token_request_user():

# Miscellenous metrics


# x-goog-api-client header value for metadata server ping.
# Example: "gl-python/3.7 auth/1.1 auth-request-type/mds"
def mds_ping():
Expand Down Expand Up @@ -135,7 +138,7 @@ def byoid_metrics_header(metrics_options):
return header


def add_metric_header(headers, metric_header_value):
def add_metric_header(headers: Mapping[str, str], metric_header_value: Optional[str]):
"""Add x-goog-api-client header with the given value.

Args:
Expand Down