diff --git a/syncmaster/server/providers/auth/base_provider.py b/syncmaster/server/providers/auth/base_provider.py index 1f5c10d8..1a57e2c2 100644 --- a/syncmaster/server/providers/auth/base_provider.py +++ b/syncmaster/server/providers/auth/base_provider.py @@ -52,7 +52,7 @@ def __init__( ... @abstractmethod - async def get_current_user(self, access_token: Any, *args, **kwargs) -> User: + async def get_current_user(self, access_token: str | None, **kwargs) -> User: """ This method should return currently logged in user. diff --git a/syncmaster/server/providers/auth/dummy_provider.py b/syncmaster/server/providers/auth/dummy_provider.py index 52abb676..0a7c0ca0 100644 --- a/syncmaster/server/providers/auth/dummy_provider.py +++ b/syncmaster/server/providers/auth/dummy_provider.py @@ -37,7 +37,7 @@ def setup(cls, app: FastAPI) -> FastAPI: app.dependency_overrides[DummyAuthProviderSettings] = lambda: settings return app - async def get_current_user(self, access_token: str, *args, **kwargs) -> User: + async def get_current_user(self, access_token: str | None, **kwargs) -> User: if not access_token: raise AuthorizationError("Missing auth credentials") diff --git a/syncmaster/server/providers/auth/keycloak_provider.py b/syncmaster/server/providers/auth/keycloak_provider.py index 603b7039..0a90ba97 100644 --- a/syncmaster/server/providers/auth/keycloak_provider.py +++ b/syncmaster/server/providers/auth/keycloak_provider.py @@ -6,6 +6,7 @@ from fastapi import Depends, FastAPI, Request from keycloak import KeycloakOpenID +from syncmaster.db.models import User from syncmaster.exceptions import EntityNotFoundError from syncmaster.exceptions.auth import AuthorizationError from syncmaster.exceptions.redirect import RedirectException @@ -63,7 +64,7 @@ async def get_token_authorization_code_grant( ) -> dict[str, Any]: try: redirect_uri = redirect_uri or self.settings.keycloak.redirect_uri - token = self.keycloak_openid.token( + token = await self.keycloak_openid.a_token( grant_type="authorization_code", code=code, redirect_uri=redirect_uri, @@ -72,10 +73,8 @@ async def get_token_authorization_code_grant( except Exception as e: raise AuthorizationError("Failed to get token") from e - async def get_current_user(self, access_token: str, *args, **kwargs) -> Any: + async def get_current_user(self, access_token: str | None, **kwargs) -> User: request: Request = kwargs["request"] - refresh_token = request.session.get("refresh_token") - if not access_token: log.debug("No access token found in session.") self.redirect_to_auth(request.url.path) @@ -86,8 +85,9 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any: token_info = self.keycloak_openid.decode_token(token=access_token) except Exception as e: log.info("Access token is invalid or expired: %s", e) - token_info = None + token_info = {} + refresh_token = request.session.get("refresh_token") if not token_info and refresh_token: log.debug("Access token invalid. Attempting to refresh.") @@ -99,9 +99,7 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any: request.session["access_token"] = new_access_token request.session["refresh_token"] = new_refresh_token - token_info = self.keycloak_openid.decode_token( - token=new_access_token, - ) + token_info = self.keycloak_openid.decode_token(token=new_access_token) log.debug("Access token refreshed and decoded successfully.") except Exception as e: log.debug("Failed to refresh access token: %s", e) @@ -110,19 +108,19 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any: # these names are hardcoded in keycloak: # https://github.com/keycloak/keycloak/blob/3ca3a4ad349b4d457f6829eaf2ae05f1e01408be/core/src/main/java/org/keycloak/representations/IDToken.java user_id = token_info.get("sub") + if not user_id: + raise AuthorizationError("Invalid token payload") + login = token_info.get("preferred_username") email = token_info.get("email") first_name = token_info.get("given_name") middle_name = token_info.get("middle_name") last_name = token_info.get("family_name") - if not user_id: - raise AuthorizationError("Invalid token payload") - - async with self._uow: - try: - user = await self._uow.user.read_by_username(login) - except EntityNotFoundError: + try: + user = await self._uow.user.read_by_username(login) + except EntityNotFoundError: + async with self._uow: user = await self._uow.user.create( username=login, email=email, @@ -134,7 +132,7 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any: return user async def refresh_access_token(self, refresh_token: str) -> dict[str, Any]: - new_tokens = self.keycloak_openid.refresh_token(refresh_token) + new_tokens = await self.keycloak_openid.a_refresh_token(refresh_token) return new_tokens def redirect_to_auth(self, path: str) -> None: diff --git a/syncmaster/server/settings/auth/base.py b/syncmaster/server/settings/auth/base.py new file mode 100644 index 00000000..c004f608 --- /dev/null +++ b/syncmaster/server/settings/auth/base.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: 2023-2024 MTS PJSC +# SPDX-License-Identifier: Apache-2.0 + +from pydantic import BaseModel, Field, ImportString + + +class AuthSettings(BaseModel): + """Authorization-related settings. + + Here you can set auth provider class. + + Examples + -------- + + .. code-block:: bash + + SYNCMASTER__AUTH__PROVIDER=syncmaster.server.providers.auth.dummy_provider.DummyAuthProvider + """ + + provider: ImportString = Field( # type: ignore[assignment] + default="syncmaster.server.providers.auth.dummy_provider.DummyAuthProvider", + description="Full name of auth provider class", + validate_default=True, + ) + + class Config: + extra = "allow"