From 54da1f96d364d0b42f8f9b26a299d99184951071 Mon Sep 17 00:00:00 2001 From: recalcitrantsupplant Date: Fri, 1 Nov 2024 10:48:35 +1000 Subject: [PATCH] feat: Add ability for prez to check for a header on all requests such as an API key. --- prez/app.py | 21 ++++++++++++++++----- prez/config.py | 9 +++++++-- prez/middleware.py | 23 +++++++++++++++++++++++ 3 files changed, 46 insertions(+), 7 deletions(-) create mode 100644 prez/middleware.py diff --git a/prez/app.py b/prez/app.py index aa80373d..63eb6bf5 100755 --- a/prez/app.py +++ b/prez/app.py @@ -1,7 +1,7 @@ import logging -from pathlib import Path from contextlib import asynccontextmanager from functools import partial +from pathlib import Path from textwrap import dedent from typing import Optional, Dict, Union, Any @@ -28,14 +28,16 @@ URINotFoundException, NoProfilesException, InvalidSPARQLQueryException, - PrefixNotFoundException, NoEndpointNodeshapeException, + PrefixNotFoundException, + NoEndpointNodeshapeException ) +from prez.middleware import create_validate_header_middleware from prez.repositories import RemoteSparqlRepo, PyoxigraphRepo, OxrdflibRepo +from prez.routers.base_router import router as base_prez_router from prez.routers.custom_endpoints import create_dynamic_router from prez.routers.identifier import router as identifier_router from prez.routers.management import router as management_router, config_router from prez.routers.ogc_features_router import features_subapi -from prez.routers.base_router import router as base_prez_router from prez.routers.sparql import router as sparql_router from prez.services.app_service import ( healthcheck_sparql_endpoints, @@ -54,7 +56,8 @@ catch_uri_not_found_exception, catch_no_profiles_exception, catch_invalid_sparql_query, - catch_prefix_not_found_exception, catch_no_endpoint_nodeshape_exception, + catch_prefix_not_found_exception, + catch_no_endpoint_nodeshape_exception, ) from prez.services.generate_profiles import create_profiles_graph from prez.services.prez_logging import setup_logger @@ -186,7 +189,11 @@ def assemble_app( app.include_router(sparql_router) if _settings.configuration_mode: app.include_router(config_router) - app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static") + app.mount( + "/static", + StaticFiles(directory=Path(__file__).parent / "static"), + name="static", + ) if _settings.enable_ogc_features: app.mount( _settings.ogc_features_mount_path, @@ -213,6 +220,10 @@ def assemble_app( allow_headers=["*"], expose_headers=["*"], ) + validate_header_middleware = create_validate_header_middleware( + settings.required_header + ) + app.middleware("http")(validate_header_middleware) return app diff --git a/prez/config.py b/prez/config.py index d1a9a8dd..afd3b7bd 100755 --- a/prez/config.py +++ b/prez/config.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import toml -from pydantic import field_validator +from pydantic import field_validator, Field from pydantic_settings import BaseSettings from rdflib import DCTERMS, RDFS, SDO, URIRef from rdflib.namespace import SKOS @@ -81,12 +81,17 @@ class Settings(BaseSettings): ] enable_sparql_endpoint: bool = False enable_ogc_features: bool = True - ogc_features_mount_path: str = "/catalogs/{catalogId}/collections/{recordsCollectionId}/features" + ogc_features_mount_path: str = ( + "/catalogs/{catalogId}/collections/{recordsCollectionId}/features" + ) custom_endpoints: bool = False configuration_mode: bool = False temporal_predicate: Optional[URIRef] = SDO.temporal endpoint_to_template_query_filename: Optional[Dict[str, str]] = {} prez_ui_url: Optional[str] = None + required_header: dict[str, str] | None = Field( + default=None, description="Format: {'header_name': 'expected_value'}" + ) @field_validator("prez_version") @classmethod diff --git a/prez/middleware.py b/prez/middleware.py new file mode 100644 index 00000000..4f483933 --- /dev/null +++ b/prez/middleware.py @@ -0,0 +1,23 @@ +from fastapi import Request +from fastapi.responses import JSONResponse + + +def create_validate_header_middleware(required_header: dict[str, str] | None): + async def validate_header(request: Request, call_next): + if required_header: + header_name, expected_value = next(iter(required_header.items())) + if ( + header_name not in request.headers + or request.headers[header_name] != expected_value + ): + return JSONResponse( # attempted to use Exception and although it was caught it did not propagate + status_code=400, + content={ + "error": "Header Validation Error", + "message": f"Missing or invalid header: {header_name}", + "code": "HEADER_VALIDATION_ERROR", + }, + ) + return await call_next(request) + + return validate_header