diff --git a/prez/app.py b/prez/app.py index fd31523f..63eb6bf5 100755 --- a/prez/app.py +++ b/prez/app.py @@ -31,7 +31,7 @@ PrefixNotFoundException, NoEndpointNodeshapeException ) -from prez.middleware import ValidateHeaderMiddleware +from prez.middleware import create_validate_header_middleware from prez.repositories import RemoteSparqlRepo, PyoxigraphRepo, OxrdflibRepo from prez.routers.base_router import router as base_prez_router from prez.routers.custom_endpoints import create_dynamic_router @@ -220,11 +220,10 @@ def assemble_app( allow_headers=["*"], expose_headers=["*"], ) - # validate_header_middleware = create_validate_header_middleware( - # settings.required_header - # ) - # app.middleware("http")(validate_header_middleware) - app.add_middleware(ValidateHeaderMiddleware, required_header=settings.required_header) + validate_header_middleware = create_validate_header_middleware( + settings.required_header + ) + app.middleware("http")(validate_header_middleware) return app diff --git a/prez/middleware.py b/prez/middleware.py index dcc888e0..4f483933 100644 --- a/prez/middleware.py +++ b/prez/middleware.py @@ -1,50 +1,23 @@ -# from fastapi import Request -# from fastapi.responses import JSONResponse -# -# -# def create_validate_header_middleware(required_header: dict[str, str] | None): -# async def validate_header(request: Request, call_next): -# if required_header: -# header_name, expected_value = next(iter(required_header.items())) -# if ( -# header_name not in request.headers -# or request.headers[header_name] != expected_value -# ): -# return JSONResponse( # attempted to use Exception and although it was caught it did not propagate -# status_code=400, -# content={ -# "error": "Header Validation Error", -# "message": f"Missing or invalid header: {header_name}", -# "code": "HEADER_VALIDATION_ERROR", -# }, -# ) -# return await call_next(request) -# -# return validate_header - - from fastapi import Request from fastapi.responses import JSONResponse -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.types import ASGIApp -class ValidateHeaderMiddleware(BaseHTTPMiddleware): - def __init__(self, app: ASGIApp, required_header: dict[str, str] | None): - super().__init__(app) - self.required_header = required_header - - async def dispatch(self, request: Request, call_next): - if self.required_header: - header_name, expected_value = next(iter(self.required_header.items())) - if header_name not in request.headers or request.headers[header_name] != expected_value: - return JSONResponse( +def create_validate_header_middleware(required_header: dict[str, str] | None): + async def validate_header(request: Request, call_next): + if required_header: + header_name, expected_value = next(iter(required_header.items())) + if ( + header_name not in request.headers + or request.headers[header_name] != expected_value + ): + return JSONResponse( # attempted to use Exception and although it was caught it did not propagate status_code=400, content={ "error": "Header Validation Error", "message": f"Missing or invalid header: {header_name}", - "code": "HEADER_VALIDATION_ERROR" - } + "code": "HEADER_VALIDATION_ERROR", + }, ) - response = await call_next(request) - return response + return await call_next(request) + + return validate_header