Skip to content

Commit

Permalink
feat: Add ability for prez to check for a header on all requests such…
Browse files Browse the repository at this point in the history
… as an API key.
  • Loading branch information
recalcitrantsupplant committed Nov 1, 2024
1 parent 2a17d2c commit 54da1f9
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 7 deletions.
21 changes: 16 additions & 5 deletions prez/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from pathlib import Path
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path
from textwrap import dedent
from typing import Optional, Dict, Union, Any

Expand All @@ -28,14 +28,16 @@
URINotFoundException,
NoProfilesException,
InvalidSPARQLQueryException,
PrefixNotFoundException, NoEndpointNodeshapeException,
PrefixNotFoundException,
NoEndpointNodeshapeException
)
from prez.middleware import create_validate_header_middleware
from prez.repositories import RemoteSparqlRepo, PyoxigraphRepo, OxrdflibRepo
from prez.routers.base_router import router as base_prez_router
from prez.routers.custom_endpoints import create_dynamic_router
from prez.routers.identifier import router as identifier_router
from prez.routers.management import router as management_router, config_router
from prez.routers.ogc_features_router import features_subapi
from prez.routers.base_router import router as base_prez_router
from prez.routers.sparql import router as sparql_router
from prez.services.app_service import (
healthcheck_sparql_endpoints,
Expand All @@ -54,7 +56,8 @@
catch_uri_not_found_exception,
catch_no_profiles_exception,
catch_invalid_sparql_query,
catch_prefix_not_found_exception, catch_no_endpoint_nodeshape_exception,
catch_prefix_not_found_exception,
catch_no_endpoint_nodeshape_exception,
)
from prez.services.generate_profiles import create_profiles_graph
from prez.services.prez_logging import setup_logger
Expand Down Expand Up @@ -186,7 +189,11 @@ def assemble_app(
app.include_router(sparql_router)
if _settings.configuration_mode:
app.include_router(config_router)
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
app.mount(
"/static",
StaticFiles(directory=Path(__file__).parent / "static"),
name="static",
)
if _settings.enable_ogc_features:
app.mount(
_settings.ogc_features_mount_path,
Expand All @@ -213,6 +220,10 @@ def assemble_app(
allow_headers=["*"],
expose_headers=["*"],
)
validate_header_middleware = create_validate_header_middleware(
settings.required_header
)
app.middleware("http")(validate_header_middleware)

return app

Expand Down
9 changes: 7 additions & 2 deletions prez/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import toml
from pydantic import field_validator
from pydantic import field_validator, Field
from pydantic_settings import BaseSettings
from rdflib import DCTERMS, RDFS, SDO, URIRef
from rdflib.namespace import SKOS
Expand Down Expand Up @@ -81,12 +81,17 @@ class Settings(BaseSettings):
]
enable_sparql_endpoint: bool = False
enable_ogc_features: bool = True
ogc_features_mount_path: str = "/catalogs/{catalogId}/collections/{recordsCollectionId}/features"
ogc_features_mount_path: str = (
"/catalogs/{catalogId}/collections/{recordsCollectionId}/features"
)
custom_endpoints: bool = False
configuration_mode: bool = False
temporal_predicate: Optional[URIRef] = SDO.temporal
endpoint_to_template_query_filename: Optional[Dict[str, str]] = {}
prez_ui_url: Optional[str] = None
required_header: dict[str, str] | None = Field(
default=None, description="Format: {'header_name': 'expected_value'}"
)

@field_validator("prez_version")
@classmethod
Expand Down
23 changes: 23 additions & 0 deletions prez/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from fastapi import Request
from fastapi.responses import JSONResponse


def create_validate_header_middleware(required_header: dict[str, str] | None):
async def validate_header(request: Request, call_next):
if required_header:
header_name, expected_value = next(iter(required_header.items()))
if (
header_name not in request.headers
or request.headers[header_name] != expected_value
):
return JSONResponse( # attempted to use Exception and although it was caught it did not propagate
status_code=400,
content={
"error": "Header Validation Error",
"message": f"Missing or invalid header: {header_name}",
"code": "HEADER_VALIDATION_ERROR",
},
)
return await call_next(request)

return validate_header

0 comments on commit 54da1f9

Please sign in to comment.