diff --git a/abandonauth/routers/ui.py b/abandonauth/routers/ui.py index dfce08a..40e8804 100644 --- a/abandonauth/routers/ui.py +++ b/abandonauth/routers/ui.py @@ -1,11 +1,12 @@ from typing import Annotated +from uuid import UUID import httpx from fastapi import APIRouter, Form, Request, HTTPException from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates from prisma.models import DeveloperApplication -from starlette.status import HTTP_403_FORBIDDEN, HTTP_400_BAD_REQUEST, HTTP_303_SEE_OTHER +from starlette.status import HTTP_403_FORBIDDEN, HTTP_400_BAD_REQUEST, HTTP_303_SEE_OTHER, HTTP_404_NOT_FOUND from abandonauth import templates # type: ignore @@ -319,7 +320,7 @@ async def edit_dev_application_callback_uris( @router.get("/login", response_class=HTMLResponse, include_in_schema=False) -async def oauth_login(request: Request, application_id: str | None = None, callback_uri: str | None = None): +async def oauth_login(request: Request, application_id: UUID | None = None, callback_uri: str | None = None): """Login for initiating the OAuth flow This page is used to start the OAuth flow for applications using AbandonAuth. @@ -335,13 +336,18 @@ async def oauth_login(request: Request, application_id: str | None = None, callb ) else: dev_app = await DeveloperApplication.prisma().find_unique( - where={"id": application_id}, + where={"id": str(application_id)}, include={"callback_uris": True} ) # This check is a convenience in order to provide accurate and immediate feedback to users # The security check for application IDs and callback URIs must be done later in the auth flow - if not dev_app or not dev_app.callback_uris or callback_uri not in [x.uri for x in dev_app.callback_uris]: + if not dev_app: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail="Invalid application ID", + ) + if not dev_app.callback_uris or callback_uri not in [x.uri for x in dev_app.callback_uris]: raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="Invalid application ID or callback_uri",