Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DOP-23122] Use async methods of Keycloak client #177

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ onetl = {extras = ["spark", "s3", "hdfs"], version = "^0.12.0"}
faker = ">=28.4.1,<34.0.0"
coverage = "^7.6.1"
gevent = "^24.2.1"
responses = "*"
respx = "*"

[tool.poetry.group.dev.dependencies]
mypy = "^1.11.2"
Expand Down Expand Up @@ -196,6 +196,10 @@ ignore_missing_imports = true
module = "keycloak.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "jwcrypto.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "avro.*"
ignore_missing_imports = true
Expand Down
2 changes: 1 addition & 1 deletion syncmaster/server/providers/auth/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
...

@abstractmethod
async def get_current_user(self, access_token: Any, *args, **kwargs) -> User:
async def get_current_user(self, access_token: str | None, **kwargs) -> User:
"""
This method should return currently logged in user.

Expand Down
2 changes: 1 addition & 1 deletion syncmaster/server/providers/auth/dummy_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def setup(cls, app: FastAPI) -> FastAPI:
app.dependency_overrides[DummyAuthProviderSettings] = lambda: settings
return app

async def get_current_user(self, access_token: str, *args, **kwargs) -> User:
async def get_current_user(self, access_token: str | None, **kwargs) -> User:
if not access_token:
raise AuthorizationError("Missing auth credentials")

Expand Down
58 changes: 26 additions & 32 deletions syncmaster/server/providers/auth/keycloak_provider.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Annotated, Any
from typing import Annotated, Any, NoReturn

from fastapi import Depends, FastAPI, Request
from keycloak import KeycloakOpenID
from jwcrypto.common import JWException
from keycloak import KeycloakOpenID, KeycloakOperationError

from syncmaster.db.models import User
from syncmaster.exceptions import EntityNotFoundError
from syncmaster.exceptions.auth import AuthorizationError
from syncmaster.exceptions.redirect import RedirectException
Expand Down Expand Up @@ -63,83 +65,75 @@ async def get_token_authorization_code_grant(
) -> dict[str, Any]:
try:
redirect_uri = redirect_uri or self.settings.keycloak.redirect_uri
token = self.keycloak_openid.token(
token = await self.keycloak_openid.a_token(
grant_type="authorization_code",
code=code,
redirect_uri=redirect_uri,
)
return token
except Exception as e:
except KeycloakOperationError as e:
raise AuthorizationError("Failed to get token") from e

async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
async def get_current_user(self, access_token: str | None, **kwargs) -> User:
request: Request = kwargs["request"]
refresh_token = request.session.get("refresh_token")

if not access_token:
log.debug("No access token found in session.")
self.redirect_to_auth(request.url.path)
await self.redirect_to_auth(request.url.path)

try:
# if user is disabled or blocked in Keycloak after the token is issued, he will
# remain authorized until the token expires (not more than 15 minutes in MTS SSO)
token_info = self.keycloak_openid.decode_token(token=access_token)
except Exception as e:
token_info = await self.keycloak_openid.a_decode_token(token=access_token)
except (KeycloakOperationError, JWException) as e:
log.info("Access token is invalid or expired: %s", e)
token_info = None
token_info = {}

refresh_token = request.session.get("refresh_token")
if not token_info and refresh_token:
log.debug("Access token invalid. Attempting to refresh.")

try:
new_tokens = await self.refresh_access_token(refresh_token)
new_tokens = await self.keycloak_openid.a_refresh_token(refresh_token)

new_access_token = new_tokens.get("access_token")
new_refresh_token = new_tokens.get("refresh_token")
request.session["access_token"] = new_access_token
request.session["refresh_token"] = new_refresh_token

token_info = self.keycloak_openid.decode_token(
token=new_access_token,
)
token_info = await self.keycloak_openid.a_decode_token(token=new_access_token)
log.debug("Access token refreshed and decoded successfully.")
except Exception as e:
except (KeycloakOperationError, JWException) as e:
log.debug("Failed to refresh access token: %s", e)
self.redirect_to_auth(request.url.path)
await self.redirect_to_auth(request.url.path)

# these names are hardcoded in keycloak:
# https://github.com/keycloak/keycloak/blob/3ca3a4ad349b4d457f6829eaf2ae05f1e01408be/core/src/main/java/org/keycloak/representations/IDToken.java
user_id = token_info.get("sub")
if not user_id:
raise AuthorizationError("Invalid token payload")

login = token_info.get("preferred_username")
email = token_info.get("email")
first_name = token_info.get("given_name")
middle_name = token_info.get("middle_name")
last_name = token_info.get("family_name")

if not user_id:
raise AuthorizationError("Invalid token payload")

async with self._uow:
try:
user = await self._uow.user.read_by_username(login)
except EntityNotFoundError:
user = await self._uow.user.create(
try:
return await self._uow.user.read_by_username(login)
except EntityNotFoundError:
async with self._uow:
return await self._uow.user.create(
username=login,
email=email,
first_name=first_name,
middle_name=middle_name,
last_name=last_name,
is_active=True,
)
return user

async def refresh_access_token(self, refresh_token: str) -> dict[str, Any]:
new_tokens = self.keycloak_openid.refresh_token(refresh_token)
return new_tokens

def redirect_to_auth(self, path: str) -> None:
async def redirect_to_auth(self, path: str) -> NoReturn:
state = generate_state(path)
auth_url = self.keycloak_openid.auth_url(
auth_url = await self.keycloak_openid.a_auth_url(
redirect_uri=self.settings.keycloak.redirect_uri,
scope=self.settings.keycloak.scope,
state=state,
Expand Down
27 changes: 27 additions & 0 deletions syncmaster/server/settings/auth/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0

from pydantic import BaseModel, Field, ImportString


class AuthSettings(BaseModel):
"""Authorization-related settings.

Here you can set auth provider class.

Examples
--------

.. code-block:: bash

SYNCMASTER__AUTH__PROVIDER=syncmaster.server.providers.auth.dummy_provider.DummyAuthProvider
"""

provider: ImportString = Field( # type: ignore[assignment]
default="syncmaster.server.providers.auth.dummy_provider.DummyAuthProvider",
description="Full name of auth provider class",
validate_default=True,
)

class Config:
extra = "allow"
10 changes: 5 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,27 +131,27 @@ async def session(sessionmaker: async_sessionmaker[AsyncSession]):
await session.close()


@pytest.fixture(scope="session")
@pytest.fixture
def mocked_celery() -> Celery:
celery_app = Mock(Celery)
celery_app.send_task = AsyncMock()
return celery_app


@pytest_asyncio.fixture(scope="session")
@pytest_asyncio.fixture
async def app(settings: Settings, mocked_celery: Celery) -> FastAPI:
app = application_factory(settings=settings)
app.dependency_overrides[Celery] = lambda: mocked_celery
return app


@pytest_asyncio.fixture(scope="session")
@pytest_asyncio.fixture
async def client_with_mocked_celery(app: FastAPI) -> AsyncGenerator:
async with AsyncClient(app=app, base_url="http://testserver") as client:
yield client


@pytest_asyncio.fixture(scope="session")
@pytest_asyncio.fixture
async def client(settings: Settings) -> AsyncGenerator:
logger.info("START CLIENT FIXTURE")
app = application_factory(settings=settings)
Expand All @@ -160,7 +160,7 @@ async def client(settings: Settings) -> AsyncGenerator:
logger.info("END CLIENT FIXTURE")


@pytest.fixture(scope="session", params=[{}])
@pytest.fixture
def celery(worker_settings: WorkerAppSettings) -> Celery:
celery_app = celery_factory(worker_settings)
return celery_app
Expand Down
1 change: 1 addition & 0 deletions tests/test_unit/test_auth/auth_fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from tests.test_unit.test_auth.auth_fixtures.keycloak_fixture import (
create_session_cookie,
mock_keycloak_api,
mock_keycloak_realm,
mock_keycloak_token_refresh,
mock_keycloak_well_known,
Expand Down
Loading
Loading