Skip to content

Commit

Permalink
Ensure for SPARQL passthrough queries headers are passed through from…
Browse files Browse the repository at this point in the history
… the original request and not the defaults supplied by httpx AsyncClient (e.g. compression). (#295)

Return more sensible error messages.
  • Loading branch information
recalcitrantsupplant authored Nov 5, 2024
1 parent 8b9a374 commit 2ce1b3f
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions prez/repositories/remote_sparql.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,45 @@ async def tabular_query_to_table(self, query: str, context: URIRef = None):
return context, response.json()["results"]["bindings"]

async def sparql(
self, query: str, raw_headers: list[tuple[bytes, bytes]], method: str = "GET"
self, query: str, raw_headers: list[tuple[bytes, bytes]], method: str = "GET"
):
"""Sends a starlette Request object (containing a SPARQL query in the URL parameters) to a proxied SPARQL
endpoint."""
"""Sends a request (containing a SPARQL query in the URL parameters) to a proxied SPARQL endpoint."""
# Convert raw_headers to a dict, excluding the 'host' header
headers = {k.decode('utf-8'): v.decode('utf-8') for k, v in raw_headers if k.lower() != b'host'}

headers = []
for header in raw_headers:
if header[0] != b"host":
headers.append(header)

# TODO: Global app settings should be passed in as a function argument.
if method == 'GET':
query_escaped_as_bytes = f"query={quote_plus(query)}".encode("utf-8")
url = httpx.URL(url=settings.sparql_endpoint, query=query_escaped_as_bytes)
rp_req = self.async_client.build_request(method, url, headers=headers)
query_escaped = quote_plus(query)
url = f"{settings.sparql_endpoint}?query={query_escaped}"
request = httpx.Request(method, url, headers=headers)
else:
url = httpx.URL(url=settings.sparql_endpoint)
rp_req = self.async_client.build_request(method, url, headers=headers, data={'query': query})
url = settings.sparql_endpoint
# Prepare form data
form_data = f"query={quote_plus(query)}"

# Set correct headers for form data
headers['content-type'] = 'application/x-www-form-urlencoded'
headers['content-length'] = str(len(form_data))

request = httpx.Request(
method,
url,
headers=headers,
content=form_data.encode('utf-8')
)

# Add the correct 'host' header
request.headers['host'] = httpx.URL(url).host

headers.append((b"host", str(url.host).encode("utf-8")))
response = await self.async_client.send(request, stream=True)
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
await response.aread()
print(f"Error content: {response.text}")
raise httpx.HTTPStatusError(
f"HTTP Error {response.status_code}: {response.text}",
request=request,
response=response
) from e

return await self.async_client.send(rp_req, stream=True)
return response

0 comments on commit 2ce1b3f

Please sign in to comment.