Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

David/check header #294

Merged
merged 3 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,183 changes: 586 additions & 597 deletions poetry.lock

Large diffs are not rendered by default.

21 changes: 16 additions & 5 deletions prez/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from pathlib import Path
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path
from textwrap import dedent
from typing import Optional, Dict, Union, Any

Expand All @@ -28,14 +28,16 @@
URINotFoundException,
NoProfilesException,
InvalidSPARQLQueryException,
PrefixNotFoundException, NoEndpointNodeshapeException,
PrefixNotFoundException,
NoEndpointNodeshapeException
)
from prez.middleware import create_validate_header_middleware
from prez.repositories import RemoteSparqlRepo, PyoxigraphRepo, OxrdflibRepo
from prez.routers.base_router import router as base_prez_router
from prez.routers.custom_endpoints import create_dynamic_router
from prez.routers.identifier import router as identifier_router
from prez.routers.management import router as management_router, config_router
from prez.routers.ogc_features_router import features_subapi
from prez.routers.base_router import router as base_prez_router
from prez.routers.sparql import router as sparql_router
from prez.services.app_service import (
healthcheck_sparql_endpoints,
Expand All @@ -54,7 +56,8 @@
catch_uri_not_found_exception,
catch_no_profiles_exception,
catch_invalid_sparql_query,
catch_prefix_not_found_exception, catch_no_endpoint_nodeshape_exception,
catch_prefix_not_found_exception,
catch_no_endpoint_nodeshape_exception,
)
from prez.services.generate_profiles import create_profiles_graph
from prez.services.prez_logging import setup_logger
Expand Down Expand Up @@ -186,7 +189,11 @@ def assemble_app(
app.include_router(sparql_router)
if _settings.configuration_mode:
app.include_router(config_router)
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
app.mount(
"/static",
StaticFiles(directory=Path(__file__).parent / "static"),
name="static",
)
if _settings.enable_ogc_features:
app.mount(
_settings.ogc_features_mount_path,
Expand All @@ -213,6 +220,10 @@ def assemble_app(
allow_headers=["*"],
expose_headers=["*"],
)
validate_header_middleware = create_validate_header_middleware(
settings.required_header
)
app.middleware("http")(validate_header_middleware)

return app

Expand Down
7 changes: 5 additions & 2 deletions prez/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import toml
from pydantic import field_validator
from pydantic import field_validator, Field
from pydantic_settings import BaseSettings
from rdflib import DCTERMS, RDFS, SDO, URIRef
from rdflib.namespace import SKOS
Expand Down Expand Up @@ -81,12 +81,15 @@ class Settings(BaseSettings):
]
enable_sparql_endpoint: bool = False
enable_ogc_features: bool = True
ogc_features_mount_path: str = "/catalogs/{catalogId}/collections/{recordsCollectionId}/features"
ogc_features_mount_path: str = (
"/catalogs/{catalogId}/collections/{recordsCollectionId}/features"
)
custom_endpoints: bool = False
configuration_mode: bool = False
temporal_predicate: Optional[URIRef] = SDO.temporal
endpoint_to_template_query_filename: Optional[Dict[str, str]] = {}
prez_ui_url: Optional[str] = None
required_header: dict[str, str] | None = None

@field_validator("prez_version")
@classmethod
Expand Down
23 changes: 23 additions & 0 deletions prez/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from fastapi import Request
from fastapi.responses import JSONResponse


def create_validate_header_middleware(required_header: dict[str, str] | None):
async def validate_header(request: Request, call_next):
if required_header:
header_name, expected_value = next(iter(required_header.items()))
if (
header_name not in request.headers
or request.headers[header_name] != expected_value
):
return JSONResponse( # attempted to use Exception and although it was caught it did not propagate
status_code=400,
content={
"error": "Header Validation Error",
"message": f"Missing or invalid header: {header_name}",
"code": "HEADER_VALIDATION_ERROR",
},
)
return await call_next(request)

return validate_header
4 changes: 2 additions & 2 deletions prez/repositories/pyoxigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def _sparql(self, query: str) -> dict | Graph | bool:
elif isinstance(results, pyoxigraph.QueryTriples): # a CONSTRUCT query result
result_graph = self._handle_query_triples_results(results)
return result_graph
elif isinstance(results, bool):
results_dict = {"head": {}, "boolean": results}
elif isinstance(results, pyoxigraph.QueryBoolean):
results_dict = {"head": {}, "boolean": bool(results)}
return results_dict
else:
raise TypeError(f"Unexpected result class {type(results)}")
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ aiocache = "^0.12.2"
sparql-grammar-pydantic = "^0.1.2"
rdf2geojson = {git = "https://github.com/ashleysommer/rdf2geojson.git", rev = "v0.2.1"}
python-multipart = "^0.0.9"
pyoxigraph = "^0.3.22"
oxrdflib = "^0.3.7"
pyoxigraph = "^0.4.2"
oxrdflib = {git = "https://github.com/oxigraph/oxrdflib.git", rev = "main"}

[tool.poetry.extras]
server = ["uvicorn"]
Expand Down
20 changes: 16 additions & 4 deletions tests/test_sparql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest


def test_select(client):
"""check that a valid select query returns a 200 response."""
r = client.get(
Expand All @@ -14,11 +17,20 @@ def test_construct(client):
assert r.status_code == 200


def test_ask(client):
"""check that a valid ask query returns a 200 response."""
r = client.get(
"/sparql?query=PREFIX%20ex%3A%20%3Chttp%3A%2F%2Fexample.com%2Fdatasets%2F%3E%0APREFIX%20dcterms%3A%20%3Chttp%3A%2F%2Fpurl.org%2Fdc%2Fterms%2F%3E%0A%0AASK%0AWHERE%20%7B%0A%20%20%3Fsubject%20dcterms%3Atitle%20%3Ftitle%20.%0A%20%20FILTER%20CONTAINS(LCASE(%3Ftitle)%2C%20%22sandgate%22)%0A%7D"
@pytest.mark.parametrize("query,expected_result", [
(
"/sparql?query=PREFIX%20ex%3A%20%3Chttp%3A%2F%2Fexample.com%2Fdatasets%2F%3E%0APREFIX%20dcterms%3A%20%3Chttp%3A%2F%2Fpurl.org%2Fdc%2Fterms%2F%3E%0A%0AASK%0AWHERE%20%7B%0A%20%20%3Fsubject%20dcterms%3Atitle%20%3Ftitle%20.%0A%20%20FILTER%20CONTAINS(LCASE(%3Ftitle)%2C%20%22sandgate%22)%0A%7D",
True
),
(
"/sparql?query=ASK%20%7B%20%3Chttps%3A%2F%2Ffake%3E%20%3Fp%20%3Fo%20%7D",
False
)
])
def test_ask(client, query, expected_result):
"""Check that valid ASK queries return a 200 response with the expected boolean result."""
r = client.get(query)

assert r.status_code == 200


Expand Down
Loading