Skip to content

Commit

Permalink
[DOP-23122] Use async methods of Keycloak client
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Dec 23, 2024
1 parent 162b0a0 commit ef38031
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 18 deletions.
2 changes: 1 addition & 1 deletion syncmaster/server/providers/auth/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
...

@abstractmethod
async def get_current_user(self, access_token: Any, *args, **kwargs) -> User:
async def get_current_user(self, access_token: str | None, **kwargs) -> User:
"""
This method should return currently logged in user.
Expand Down
2 changes: 1 addition & 1 deletion syncmaster/server/providers/auth/dummy_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def setup(cls, app: FastAPI) -> FastAPI:
app.dependency_overrides[DummyAuthProviderSettings] = lambda: settings
return app

async def get_current_user(self, access_token: str, *args, **kwargs) -> User:
async def get_current_user(self, access_token: str | None, **kwargs) -> User:
if not access_token:
raise AuthorizationError("Missing auth credentials")

Expand Down
30 changes: 14 additions & 16 deletions syncmaster/server/providers/auth/keycloak_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi import Depends, FastAPI, Request
from keycloak import KeycloakOpenID

from syncmaster.db.models import User
from syncmaster.exceptions import EntityNotFoundError
from syncmaster.exceptions.auth import AuthorizationError
from syncmaster.exceptions.redirect import RedirectException
Expand Down Expand Up @@ -63,7 +64,7 @@ async def get_token_authorization_code_grant(
) -> dict[str, Any]:
try:
redirect_uri = redirect_uri or self.settings.keycloak.redirect_uri
token = self.keycloak_openid.token(
token = await self.keycloak_openid.a_token(
grant_type="authorization_code",
code=code,
redirect_uri=redirect_uri,
Expand All @@ -72,10 +73,8 @@ async def get_token_authorization_code_grant(
except Exception as e:
raise AuthorizationError("Failed to get token") from e

async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
async def get_current_user(self, access_token: str | None, **kwargs) -> User:
request: Request = kwargs["request"]
refresh_token = request.session.get("refresh_token")

if not access_token:
log.debug("No access token found in session.")
self.redirect_to_auth(request.url.path)
Expand All @@ -86,8 +85,9 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
token_info = self.keycloak_openid.decode_token(token=access_token)
except Exception as e:
log.info("Access token is invalid or expired: %s", e)
token_info = None
token_info = {}

refresh_token = request.session.get("refresh_token")
if not token_info and refresh_token:
log.debug("Access token invalid. Attempting to refresh.")

Expand All @@ -99,9 +99,7 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
request.session["access_token"] = new_access_token
request.session["refresh_token"] = new_refresh_token

token_info = self.keycloak_openid.decode_token(
token=new_access_token,
)
token_info = self.keycloak_openid.decode_token(token=new_access_token)
log.debug("Access token refreshed and decoded successfully.")
except Exception as e:
log.debug("Failed to refresh access token: %s", e)
Expand All @@ -110,19 +108,19 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
# these names are hardcoded in keycloak:
# https://github.com/keycloak/keycloak/blob/3ca3a4ad349b4d457f6829eaf2ae05f1e01408be/core/src/main/java/org/keycloak/representations/IDToken.java
user_id = token_info.get("sub")
if not user_id:
raise AuthorizationError("Invalid token payload")

login = token_info.get("preferred_username")
email = token_info.get("email")
first_name = token_info.get("given_name")
middle_name = token_info.get("middle_name")
last_name = token_info.get("family_name")

if not user_id:
raise AuthorizationError("Invalid token payload")

async with self._uow:
try:
user = await self._uow.user.read_by_username(login)
except EntityNotFoundError:
try:
user = await self._uow.user.read_by_username(login)
except EntityNotFoundError:
async with self._uow:
user = await self._uow.user.create(
username=login,
email=email,
Expand All @@ -134,7 +132,7 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
return user

async def refresh_access_token(self, refresh_token: str) -> dict[str, Any]:
new_tokens = self.keycloak_openid.refresh_token(refresh_token)
new_tokens = await self.keycloak_openid.a_refresh_token(refresh_token)
return new_tokens

def redirect_to_auth(self, path: str) -> None:
Expand Down
27 changes: 27 additions & 0 deletions syncmaster/server/settings/auth/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0

from pydantic import BaseModel, Field, ImportString


class AuthSettings(BaseModel):
"""Authorization-related settings.
Here you can set auth provider class.
Examples
--------
.. code-block:: bash
SYNCMASTER__AUTH__PROVIDER=syncmaster.server.providers.auth.dummy_provider.DummyAuthProvider
"""

provider: ImportString = Field( # type: ignore[assignment]
default="syncmaster.server.providers.auth.dummy_provider.DummyAuthProvider",
description="Full name of auth provider class",
validate_default=True,
)

class Config:
extra = "allow"

0 comments on commit ef38031

Please sign in to comment.