Skip to content

Commit

Permalink
Setup Openid connect
Browse files Browse the repository at this point in the history
  • Loading branch information
berrydenhartog committed Oct 10, 2024
1 parent 131c8dd commit 0f51081
Show file tree
Hide file tree
Showing 13 changed files with 350 additions and 19 deletions.
6 changes: 5 additions & 1 deletion amt/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from amt.api.http_browser_caching import url_for_cache
from amt.api.navigation import NavigationItem, get_main_menu
from amt.core.authorization import get_user
from amt.core.config import VERSION, get_settings
from amt.core.internationalization import (
format_datetime,
Expand All @@ -26,7 +27,9 @@
logger = logging.getLogger(__name__)


def custom_context_processor(request: Request) -> dict[str, str | list[str] | dict[str, str] | list[NavigationItem]]:
def custom_context_processor(
request: Request,
) -> dict[str, str | None | list[str] | dict[str, str] | list[NavigationItem]]:
lang = get_requested_language(request)
translations = get_current_translation(request)
return {
Expand All @@ -35,6 +38,7 @@ def custom_context_processor(request: Request) -> dict[str, str | list[str] | di
"language": lang,
"translations": get_dynamic_field_translations(lang),
"main_menu_items": get_main_menu(request, translations),
"user": get_user(request),
}


Expand Down
3 changes: 2 additions & 1 deletion amt/api/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from fastapi import APIRouter

from amt.api.routes import health, pages, project, projects, root
from amt.api.routes import auth, health, pages, project, projects, root

api_router = APIRouter()
api_router.include_router(root.router)
api_router.include_router(health.router, prefix="/health", tags=["health"])
api_router.include_router(pages.router, prefix="/pages", tags=["pages"])
api_router.include_router(projects.router, prefix="/projects", tags=["projects"])
api_router.include_router(project.router, prefix="/project", tags=["projects"])
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
56 changes: 56 additions & 0 deletions amt/api/routes/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import logging

from authlib.integrations.starlette_client import OAuth, OAuthError # type: ignore
from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse, RedirectResponse, Response

from amt.core.authorization import get_user, no_authorization
from amt.core.exceptions import AMTAuthorizationError

router = APIRouter()
logger = logging.getLogger(__name__)


@router.get("/login")
@no_authorization
async def login(request: Request) -> HTMLResponse:
oauth: OAuth = request.app.state.oauth
redirect_uri = request.url_for("auth_callback")
return await oauth.keycloak.authorize_redirect(request, redirect_uri) # type: ignore


@router.get("/logout")
@no_authorization
async def logout(request: Request) -> RedirectResponse:
user = get_user(request)
id_token = request.session.get("id_token", None)
request.session.pop("user", None)
request.session.pop("id_token", None)

if user:
redirect_uri = request.url_for("base")
return RedirectResponse(
url=user["iss"]
+ "/protocol/openid-connect/logout?id_token_hint="
+ id_token
+ "&post_logout_redirect_uri="
+ str(redirect_uri)
)

return RedirectResponse(url="/")


@router.get("/callback", response_class=Response)
@no_authorization
async def auth_callback(request: Request) -> Response:
oauth: OAuth = request.app.state.oauth
try:
token = await oauth.keycloak.authorize_access_token(request) # type: ignore
except OAuthError as error:
raise AMTAuthorizationError() from error

user: dict = token.get("userinfo") # type: ignore
if user:
request.session["user"] = dict(user) # type: ignore
request.session["id_token"] = token["id_token"] # type: ignore
return RedirectResponse(url="/")
2 changes: 2 additions & 0 deletions amt/api/routes/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from fastapi.responses import HTMLResponse

from amt.api.deps import templates
from amt.core.authorization import no_authorization

router = APIRouter()

logger = logging.getLogger(__name__)


@router.get("/")
@no_authorization
async def base(request: Request) -> HTMLResponse:
breadcrumbs = {}

Expand Down
23 changes: 23 additions & 0 deletions amt/core/authorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from collections.abc import Callable
from functools import wraps
from typing import Any

from starlette.requests import Request


def get_user(request: Request) -> dict[str, str] | None:
user = None
if "session" in request.scope:
user = request.session.get("user", None)
return user


def no_authorization(func: Callable[..., Any]): # noqa: ANN201
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
request: Request = kwargs.get("request") # type: ignore
if request:
request.state.noauth = True
return await func(*args, **kwargs)

return wrapper
5 changes: 5 additions & 0 deletions amt/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ class Settings(BaseSettings):
DEBUG: bool = False
AUTO_CREATE_SCHEMA: bool = False

ISSUER: str = "https://keycloak.apps.digilab.network/realms/algoritmes"
OIDC_CLIENT_ID: str | None = None
OIDC_CLIENT_SECRET: str | None = None
OIDC_DISCOVERY_URL: str = "https://keycloak.apps.digilab.network/realms/algoritmes/.well-known/openid-configuration"

CARD_DIR: Path = Path("/tmp/") # TODO(berry): create better location for model cards # noqa: S108

# todo(berry): create submodel for database settings
Expand Down
6 changes: 6 additions & 0 deletions amt/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,9 @@ class AMTValueError(AMTHTTPException):
def __init__(self, field: str) -> None:
self.detail: str = _("Value not correct: {field}").format(field=field)
super().__init__(status.HTTP_400_BAD_REQUEST, self.detail)


class AMTAuthorizationError(AMTHTTPException):
def __init__(self) -> None:
self.detail: str = _("Failed to authorize, please login and try again.")
super().__init__(status.HTTP_401_UNAUTHORIZED, self.detail)
31 changes: 31 additions & 0 deletions amt/middleware/authorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os
import typing

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

from amt.core.authorization import get_user
from amt.core.exceptions import AMTAuthorizationError

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]


class AuthorizationMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
if request.url.path.startswith("/static/"):
return await call_next(request)

user = get_user(request)
if user:
return await call_next(request)

response = await call_next(request)
## todo: move to decorator function
if hasattr(request.state, "noauth"):
return response

if os.environ.get("PYTEST_CURRENT_TEST", False):
return response

raise AMTAuthorizationError()
16 changes: 16 additions & 0 deletions amt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from authlib.integrations.starlette_client import OAuth # type: ignore
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import HTMLResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.sessions import SessionMiddleware

from amt.api.main import api_router
from amt.core.config import PROJECT_DESCRIPTION, PROJECT_NAME, VERSION, get_settings
Expand All @@ -15,6 +17,7 @@
from amt.utils.mask import Mask

from .api.http_browser_caching import static_files
from .middleware.authorization import AuthorizationMiddleware
from .middleware.csrf import CSRFMiddleware, CSRFMiddlewareExceptionHandler
from .middleware.htmx import HTMXMiddleware
from .middleware.route_logging import RequestLoggingMiddleware
Expand Down Expand Up @@ -49,12 +52,25 @@ def create_app() -> FastAPI:
debug=get_settings().DEBUG,
)

app.add_middleware(AuthorizationMiddleware)
app.add_middleware(SessionMiddleware, secret_key=get_settings().SECRET_KEY)
app.add_middleware(RequestLoggingMiddleware)
app.add_middleware(CSRFMiddleware)
app.add_middleware(CSRFMiddlewareExceptionHandler)
app.add_middleware(HTMXMiddleware)
app.add_middleware(SecurityMiddleware)

oauth = OAuth()
app.state.oauth = oauth

oauth.register( # type: ignore
name="keycloak",
client_id=get_settings().OIDC_CLIENT_ID,
client_secret=get_settings().OIDC_CLIENT_SECRET,
server_metadata_url=get_settings().OIDC_DISCOVERY_URL,
client_kwargs={"scope": "openid profile email"},
)

app.mount("/static", static_files, name="static")

@app.exception_handler(StarletteHTTPException)
Expand Down
47 changes: 40 additions & 7 deletions amt/site/templates/parts/header.html.j2
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,39 @@

{% endfor %}

<li class="rvo-topnav__item">
{% for available_translation in available_translations %}
<li class="rvo-topnav__item">
<a class="rvo-link rvo-topnav__link rvo-link--logoblauw {% if available_translation == language %} selected {% endif %}"
{% if available_translation != language %}

<a class="rvo-link rvo-topnav__link rvo-link--logoblauw"
id="mobile-langselect-{{ available_translation }}"
href="javascript:setCookie('lang', '{{ available_translation }}', 9999); window.location.reload()"
>
<span
class="utrecht-icon rvo-icon rvo-icon-wereldbol rvo-icon--sm rvo-icon--wit rvo-link__icon--before"
class="utrecht-icon rvo-icon rvo-icon-wereldbol rvo-icon--sm rvo-link__icon--before"
role="img"
aria-label="{% trans %}Language{% endtrans %} {{ available_translation }}"
></span>
{{ available_translation }}
</a>
</li>
{% endfor %}
{% endif %}
{% endfor %}
</li>
{% if not user %}
<li class="rvo-topnav__item">
<a class="rvo-link rvo-topnav__link rvo-link--logoblauw" href="/auth/login">
<span class="utrecht-icon rvo-icon rvo-icon-user rvo-icon--sm rvo-link__icon--before" role="img"
aria-label="User"></span> {%trans %}Login{%endtrans %}
</a>
</li>
{% else %}
<li class="rvo-topnav__item">
<a class="rvo-link rvo-topnav__link rvo-link--logoblauw" href="/auth/logout">
<span class="utrecht-icon rvo-icon rvo-icon-user rvo-icon--sm rvo-link__icon--before" role="img"
aria-label="User"></span> {%trans %}Logout{%endtrans %}
</a>
</li>
{% endif %}
</ul>
</nav>
</div>
Expand All @@ -97,7 +115,10 @@

<li class="rvo-topnav__item rvo-topnav__item--align-right" style="display: inline-flex">
{% for available_translation in available_translations %}
<a class="rvo-link rvo-topnav__link rvo-link--logoblauw {% if available_translation == language %} selected {% endif %}"
{% if available_translation != language %}


<a class="rvo-link rvo-topnav__link rvo-link--logoblauw"
id="langselect-{{ available_translation }}"
href="javascript:setCookie('lang', '{{ available_translation }}', 9999); window.location.reload()"
>
Expand All @@ -108,9 +129,21 @@
></span>
{{ available_translation }}
</a>
{% endif %}
{% endfor %}
{% if not user %}
<a class="rvo-link rvo-topnav__link rvo-link--logoblauw" href="/auth/login">
<span class="utrecht-icon rvo-icon rvo-icon-user rvo-icon--md rvo-icon--wit" role="img" aria-label="User"></span>
{% trans %}Login{%endtrans %}
</a>
{% else %}
<a class="rvo-link rvo-topnav__link rvo-link--logoblauw" href="/auth/logout">
<span class="utrecht-icon rvo-icon rvo-icon-versleutelen rvo-icon--md rvo-icon--wit" role="img"
aria-label="User"></span>
{% trans %}Logout{%endtrans %}
</a>
{% endif %}
</li>

</ul>
</nav>
</div>
Expand Down
Loading

0 comments on commit 0f51081

Please sign in to comment.