Skip to content

Commit

Permalink
chore: Update pyoxigraph
Browse files Browse the repository at this point in the history
  • Loading branch information
recalcitrantsupplant committed Nov 1, 2024
1 parent 54da1f9 commit e2b17b4
Show file tree
Hide file tree
Showing 7 changed files with 654 additions and 627 deletions.
1,183 changes: 586 additions & 597 deletions poetry.lock

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions prez/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
PrefixNotFoundException,
NoEndpointNodeshapeException
)
from prez.middleware import create_validate_header_middleware
from prez.middleware import ValidateHeaderMiddleware
from prez.repositories import RemoteSparqlRepo, PyoxigraphRepo, OxrdflibRepo
from prez.routers.base_router import router as base_prez_router
from prez.routers.custom_endpoints import create_dynamic_router
Expand Down Expand Up @@ -220,10 +220,11 @@ def assemble_app(
allow_headers=["*"],
expose_headers=["*"],
)
validate_header_middleware = create_validate_header_middleware(
settings.required_header
)
app.middleware("http")(validate_header_middleware)
# validate_header_middleware = create_validate_header_middleware(
# settings.required_header
# )
# app.middleware("http")(validate_header_middleware)
app.add_middleware(ValidateHeaderMiddleware, required_header=settings.required_header)

return app

Expand Down
4 changes: 1 addition & 3 deletions prez/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ class Settings(BaseSettings):
temporal_predicate: Optional[URIRef] = SDO.temporal
endpoint_to_template_query_filename: Optional[Dict[str, str]] = {}
prez_ui_url: Optional[str] = None
required_header: dict[str, str] | None = Field(
default=None, description="Format: {'header_name': 'expected_value'}"
)
required_header: dict[str, str] | None = None

@field_validator("prez_version")
@classmethod
Expand Down
55 changes: 41 additions & 14 deletions prez/middleware.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,50 @@
# from fastapi import Request
# from fastapi.responses import JSONResponse
#
#
# def create_validate_header_middleware(required_header: dict[str, str] | None):
# async def validate_header(request: Request, call_next):
# if required_header:
# header_name, expected_value = next(iter(required_header.items()))
# if (
# header_name not in request.headers
# or request.headers[header_name] != expected_value
# ):
# return JSONResponse( # attempted to use Exception and although it was caught it did not propagate
# status_code=400,
# content={
# "error": "Header Validation Error",
# "message": f"Missing or invalid header: {header_name}",
# "code": "HEADER_VALIDATION_ERROR",
# },
# )
# return await call_next(request)
#
# return validate_header


from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp


def create_validate_header_middleware(required_header: dict[str, str] | None):
async def validate_header(request: Request, call_next):
if required_header:
header_name, expected_value = next(iter(required_header.items()))
if (
header_name not in request.headers
or request.headers[header_name] != expected_value
):
return JSONResponse( # attempted to use Exception and although it was caught it did not propagate
class ValidateHeaderMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, required_header: dict[str, str] | None):
super().__init__(app)
self.required_header = required_header

async def dispatch(self, request: Request, call_next):
if self.required_header:
header_name, expected_value = next(iter(self.required_header.items()))
if header_name not in request.headers or request.headers[header_name] != expected_value:
return JSONResponse(
status_code=400,
content={
"error": "Header Validation Error",
"message": f"Missing or invalid header: {header_name}",
"code": "HEADER_VALIDATION_ERROR",
},
"code": "HEADER_VALIDATION_ERROR"
}
)
return await call_next(request)

return validate_header
response = await call_next(request)
return response
4 changes: 2 additions & 2 deletions prez/repositories/pyoxigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def _sparql(self, query: str) -> dict | Graph | bool:
elif isinstance(results, pyoxigraph.QueryTriples): # a CONSTRUCT query result
result_graph = self._handle_query_triples_results(results)
return result_graph
elif isinstance(results, bool):
results_dict = {"head": {}, "boolean": results}
elif isinstance(results, pyoxigraph.QueryBoolean):
results_dict = {"head": {}, "boolean": bool(results)}
return results_dict
else:
raise TypeError(f"Unexpected result class {type(results)}")
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ aiocache = "^0.12.2"
sparql-grammar-pydantic = "^0.1.2"
rdf2geojson = {git = "https://github.com/ashleysommer/rdf2geojson.git", rev = "v0.2.1"}
python-multipart = "^0.0.9"
pyoxigraph = "^0.3.22"
oxrdflib = "^0.3.7"
pyoxigraph = "^0.4.2"
oxrdflib = {git = "https://github.com/oxigraph/oxrdflib.git", rev = "main"}

[tool.poetry.extras]
server = ["uvicorn"]
Expand Down
20 changes: 16 additions & 4 deletions tests/test_sparql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest


def test_select(client):
"""check that a valid select query returns a 200 response."""
r = client.get(
Expand All @@ -14,11 +17,20 @@ def test_construct(client):
assert r.status_code == 200


def test_ask(client):
"""check that a valid ask query returns a 200 response."""
r = client.get(
"/sparql?query=PREFIX%20ex%3A%20%3Chttp%3A%2F%2Fexample.com%2Fdatasets%2F%3E%0APREFIX%20dcterms%3A%20%3Chttp%3A%2F%2Fpurl.org%2Fdc%2Fterms%2F%3E%0A%0AASK%0AWHERE%20%7B%0A%20%20%3Fsubject%20dcterms%3Atitle%20%3Ftitle%20.%0A%20%20FILTER%20CONTAINS(LCASE(%3Ftitle)%2C%20%22sandgate%22)%0A%7D"
@pytest.mark.parametrize("query,expected_result", [
(
"/sparql?query=PREFIX%20ex%3A%20%3Chttp%3A%2F%2Fexample.com%2Fdatasets%2F%3E%0APREFIX%20dcterms%3A%20%3Chttp%3A%2F%2Fpurl.org%2Fdc%2Fterms%2F%3E%0A%0AASK%0AWHERE%20%7B%0A%20%20%3Fsubject%20dcterms%3Atitle%20%3Ftitle%20.%0A%20%20FILTER%20CONTAINS(LCASE(%3Ftitle)%2C%20%22sandgate%22)%0A%7D",
True
),
(
"/sparql?query=ASK%20%7B%20%3Chttps%3A%2F%2Ffake%3E%20%3Fp%20%3Fo%20%7D",
False
)
])
def test_ask(client, query, expected_result):
"""Check that valid ASK queries return a 200 response with the expected boolean result."""
r = client.get(query)

assert r.status_code == 200


Expand Down

0 comments on commit e2b17b4

Please sign in to comment.