Skip to content

Commit

Permalink
accept access token from header
Browse files Browse the repository at this point in the history
  • Loading branch information
Qup42 committed Jan 20, 2025
1 parent acb6633 commit d3e8eac
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 27 deletions.
35 changes: 32 additions & 3 deletions src/engine/Server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,27 @@ void Server::run(const string& indexBaseName, bool useText, bool usePatterns,
httpServer.run();
}

std::optional<std::string> Server::extractAccessToken(
const ad_utility::httpUtils::HttpRequest auto& request,
const ad_utility::url_parser::ParamValueMap& params) {
if (request.find(http::field::authorization) != request.end()) {
string_view authorization = request[http::field::authorization];
const std::string prefix = "Bearer ";
if (!authorization.starts_with(prefix)) {
throw std::runtime_error(absl::StrCat(
"Authorization header must start with \"Bearer \". Got: \"",
authorization, "\"."));
}
authorization.remove_prefix(prefix.length());
return std::string(authorization);
}
if (params.contains("access-token")) {
return ad_utility::url_parser::getParameterCheckAtMostOnce(params,
"access-token");
}
return std::nullopt;
}

// _____________________________________________________________________________
ad_utility::url_parser::ParsedRequest Server::parseHttpRequest(
const ad_utility::httpUtils::HttpRequest auto& request) {
Expand All @@ -172,7 +193,8 @@ ad_utility::url_parser::ParsedRequest Server::parseHttpRequest(
using namespace ad_utility::url_parser::sparqlOperation;
auto parsedUrl = ad_utility::url_parser::parseRequestTarget(request.target());
ad_utility::url_parser::ParsedRequest parsedRequest{
std::move(parsedUrl.path_), std::move(parsedUrl.parameters_), None{}};
std::move(parsedUrl.path_), std::nullopt,
std::move(parsedUrl.parameters_), None{}};

// Some valid requests (e.g. QLever's custom commands like retrieving index
// statistics) don't have a query. So an empty operation is not necessarily an
Expand All @@ -186,9 +208,14 @@ ad_utility::url_parser::ParsedRequest Server::parseHttpRequest(
parsedRequest.parameters_.erase(paramName);
}
};
auto extractAccessTokenFromRequest = [&parsedRequest, &request]() {
parsedRequest.accessToken_ =
extractAccessToken(request, parsedRequest.parameters_);
};

if (request.method() == http::verb::get) {
setOperationIfSpecifiedInParams.template operator()<Query>("query");
extractAccessTokenFromRequest();
if (parsedRequest.parameters_.contains("update")) {
throw std::runtime_error("SPARQL Update is not allowed as GET request.");
}
Expand Down Expand Up @@ -258,15 +285,18 @@ ad_utility::url_parser::ParsedRequest Server::parseHttpRequest(
}
setOperationIfSpecifiedInParams.template operator()<Query>("query");
setOperationIfSpecifiedInParams.template operator()<Update>("update");
extractAccessTokenFromRequest();

return parsedRequest;
}
if (contentType.starts_with(contentTypeSparqlQuery)) {
parsedRequest.operation_ = Query{request.body()};
extractAccessTokenFromRequest();
return parsedRequest;
}
if (contentType.starts_with(contentTypeSparqlUpdate)) {
parsedRequest.operation_ = Update{request.body()};
extractAccessTokenFromRequest();
return parsedRequest;
}
throw std::runtime_error(absl::StrCat(
Expand Down Expand Up @@ -353,8 +383,7 @@ Awaitable<void> Server::process(
// Check the access token. If an access token is provided and the check fails,
// throw an exception and do not process any part of the query (even if the
// processing had been allowed without access token).
bool accessTokenOk =
checkAccessToken(checkParameter("access-token", std::nullopt));
bool accessTokenOk = checkAccessToken(parsedHttpRequest.accessToken_);

Check warning on line 386 in src/engine/Server.cpp

View check run for this annotation

Codecov / codecov/patch

src/engine/Server.cpp#L386

Added line #L386 was not covered by tests
auto requireValidAccessToken = [&accessTokenOk](
const std::string& actionName) {
if (!accessTokenOk) {
Expand Down
7 changes: 7 additions & 0 deletions src/engine/Server.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class Server {
FRIEND_TEST(ServerTest, parseHttpRequest);
FRIEND_TEST(ServerTest, getQueryId);
FRIEND_TEST(ServerTest, createMessageSender);
FRIEND_TEST(ServerTest, extractAccessToken);

public:
explicit Server(unsigned short port, size_t numThreads,
Expand Down Expand Up @@ -114,6 +115,12 @@ class Server {
static ad_utility::url_parser::ParsedRequest parseHttpRequest(
const ad_utility::httpUtils::HttpRequest auto& request);

/// Extract the Access token for that request from the `Authorization` header
/// or the URL query parameters.
static std::optional<std::string> extractAccessToken(
const ad_utility::httpUtils::HttpRequest auto& request,
const ad_utility::url_parser::ParamValueMap& params);

/// Handle a single HTTP request. Check whether a file request or a query was
/// sent, and dispatch to functions handling these cases. This function
/// requires the constraints for the `HttpHandler` in `HttpServer.h`.
Expand Down
2 changes: 2 additions & 0 deletions src/util/http/UrlParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ struct None {

// Representation of parsed HTTP request.
// - `path_` is the URL path
// - `accessToken_` is the access token for that request
// - `parameters_` is a hashmap of the parameters
// - `operation_` the operation that should be performed
struct ParsedRequest {
std::string path_;
std::optional<std::string> accessToken_;
ParamValueMap parameters_;
std::variant<sparqlOperation::Query, sparqlOperation::Update,
sparqlOperation::None>
Expand Down
132 changes: 108 additions & 24 deletions test/ServerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ using namespace ad_utility::url_parser::sparqlOperation;

namespace {
auto ParsedRequestIs = [](const std::string& path,
const std::optional<std::string>& accessToken,
const ParamValueMap& parameters,
const std::variant<Query, Update, None>& operation)
-> testing::Matcher<const ParsedRequest> {
return testing::AllOf(
AD_FIELD(ad_utility::url_parser::ParsedRequest, path_, testing::Eq(path)),
AD_FIELD(ad_utility::url_parser::ParsedRequest, accessToken_,
testing::Eq(accessToken)),
AD_FIELD(ad_utility::url_parser::ParsedRequest, parameters_,
testing::ContainerEq(parameters)),
AD_FIELD(ad_utility::url_parser::ParsedRequest, operation_,
Expand Down Expand Up @@ -55,19 +58,21 @@ TEST(ServerTest, parseHttpRequest) {
"application/x-www-form-urlencoded;charset=UTF-8";
const std::string QUERY = "application/sparql-query";
const std::string UPDATE = "application/sparql-update";
EXPECT_THAT(parse(MakeGetRequest("/")), ParsedRequestIs("/", {}, None{}));
EXPECT_THAT(parse(MakeGetRequest("/")),
ParsedRequestIs("/", std::nullopt, {}, None{}));
EXPECT_THAT(parse(MakeGetRequest("/ping")),
ParsedRequestIs("/ping", {}, None{}));
ParsedRequestIs("/ping", std::nullopt, {}, None{}));
EXPECT_THAT(parse(MakeGetRequest("/?cmd=stats")),
ParsedRequestIs("/", {{"cmd", {"stats"}}}, None{}));
ParsedRequestIs("/", std::nullopt, {{"cmd", {"stats"}}}, None{}));
EXPECT_THAT(parse(MakeGetRequest(
"/?query=SELECT+%2A%20WHERE%20%7B%7D&action=csv_export")),
ParsedRequestIs("/", {{"action", {"csv_export"}}},
ParsedRequestIs("/", std::nullopt, {{"action", {"csv_export"}}},
Query{"SELECT * WHERE {}"}));
EXPECT_THAT(
parse(MakePostRequest("/", URLENCODED,
"query=SELECT+%2A%20WHERE%20%7B%7D&send=100")),
ParsedRequestIs("/", {{"send", {"100"}}}, Query{"SELECT * WHERE {}"}));
ParsedRequestIs("/", std::nullopt, {{"send", {"100"}}},
Query{"SELECT * WHERE {}"}));
AD_EXPECT_THROW_WITH_MESSAGE(
parse(MakePostRequest("/", URLENCODED,
"ääär y=SELECT+%2A%20WHERE%20%7B%7D&send=100")),
Expand All @@ -92,18 +97,20 @@ TEST(ServerTest, parseHttpRequest) {
EXPECT_THAT(
parse(MakePostRequest("/", "application/x-www-form-urlencoded",
"query=SELECT%20%2A%20WHERE%20%7B%7D&send=100")),
ParsedRequestIs("/", {{"send", {"100"}}}, Query{"SELECT * WHERE {}"}));
EXPECT_THAT(parse(MakePostRequest("/", URLENCODED,
"query=SELECT%20%2A%20WHERE%20%7B%7D")),
ParsedRequestIs("/", {}, Query{"SELECT * WHERE {}"}));
ParsedRequestIs("/", std::nullopt, {{"send", {"100"}}},
Query{"SELECT * WHERE {}"}));
EXPECT_THAT(
parse(MakePostRequest("/", URLENCODED,
"query=SELECT%20%2A%20WHERE%20%7B%7D")),
ParsedRequestIs("/", std::nullopt, {}, Query{"SELECT * WHERE {}"}));
EXPECT_THAT(
parse(MakePostRequest(
"/", URLENCODED,
"query=SELECT%20%2A%20WHERE%20%7B%7D&default-graph-uri=https%3A%2F%"
"2Fw3.org%2Fdefault&named-graph-uri=https%3A%2F%2Fw3.org%2F1&named-"
"graph-uri=https%3A%2F%2Fw3.org%2F2")),
ParsedRequestIs(
"/",
"/", std::nullopt,
{{"default-graph-uri", {"https://w3.org/default"}},
{"named-graph-uri", {"https://w3.org/1", "https://w3.org/2"}}},
Query{"SELECT * WHERE {}"}));
Expand All @@ -112,13 +119,15 @@ TEST(ServerTest, parseHttpRequest) {
"query=SELECT%20%2A%20WHERE%20%7B%7D")),
testing::StrEq("URL-encoded POST requests must not contain query "
"parameters in the URL."));
EXPECT_THAT(parse(MakePostRequest("/", URLENCODED, "cmd=clear-cache")),
ParsedRequestIs("/", {{"cmd", {"clear-cache"}}}, None{}));
EXPECT_THAT(parse(MakePostRequest("/", QUERY, "SELECT * WHERE {}")),
ParsedRequestIs("/", {}, Query{"SELECT * WHERE {}"}));
EXPECT_THAT(
parse(MakePostRequest("/?send=100", QUERY, "SELECT * WHERE {}")),
ParsedRequestIs("/", {{"send", {"100"}}}, Query{"SELECT * WHERE {}"}));
parse(MakePostRequest("/", URLENCODED, "cmd=clear-cache")),
ParsedRequestIs("/", std::nullopt, {{"cmd", {"clear-cache"}}}, None{}));
EXPECT_THAT(
parse(MakePostRequest("/", QUERY, "SELECT * WHERE {}")),
ParsedRequestIs("/", std::nullopt, {}, Query{"SELECT * WHERE {}"}));
EXPECT_THAT(parse(MakePostRequest("/?send=100", QUERY, "SELECT * WHERE {}")),
ParsedRequestIs("/", std::nullopt, {{"send", {"100"}}},
Query{"SELECT * WHERE {}"}));
AD_EXPECT_THROW_WITH_MESSAGE(
parse(MakeBasicRequest(http::verb::patch, "/")),
testing::StrEq(
Expand All @@ -132,14 +141,35 @@ TEST(ServerTest, parseHttpRequest) {
AD_EXPECT_THROW_WITH_MESSAGE(
parse(MakeGetRequest("/?update=DELETE%20%2A%20WHERE%20%7B%7D")),
testing::StrEq("SPARQL Update is not allowed as GET request."));
EXPECT_THAT(parse(MakePostRequest("/", UPDATE, "DELETE * WHERE {}")),
ParsedRequestIs("/", {}, Update{"DELETE * WHERE {}"}));
EXPECT_THAT(parse(MakePostRequest("/", URLENCODED,
"update=DELETE%20%2A%20WHERE%20%7B%7D")),
ParsedRequestIs("/", {}, Update{"DELETE * WHERE {}"}));
EXPECT_THAT(parse(MakePostRequest("/", URLENCODED,
"update=DELETE+%2A+WHERE%20%7B%7D")),
ParsedRequestIs("/", {}, Update{"DELETE * WHERE {}"}));
EXPECT_THAT(
parse(MakePostRequest("/", UPDATE, "DELETE * WHERE {}")),
ParsedRequestIs("/", std::nullopt, {}, Update{"DELETE * WHERE {}"}));
EXPECT_THAT(
parse(MakePostRequest("/", URLENCODED,
"update=DELETE%20%2A%20WHERE%20%7B%7D")),
ParsedRequestIs("/", std::nullopt, {}, Update{"DELETE * WHERE {}"}));
EXPECT_THAT(
parse(
MakePostRequest("/", URLENCODED, "update=DELETE+%2A+WHERE%20%7B%7D")),
ParsedRequestIs("/", std::nullopt, {}, Update{"DELETE * WHERE {}"}));
// TODO<qup42>: there could be some more here, but i'll wait until #1668
EXPECT_THAT(
parse(MakeGetRequest("/?query=a&access-token=foo")),
ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, Query{"a"}));
EXPECT_THAT(
parse(MakePostRequest("/", URLENCODED,
"update=DELETE%20WHERE%20%7B%7D&access-token=foo")),
ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}},
Update{"DELETE WHERE {}"}));
EXPECT_THAT(parse(MakePostRequest(
"/", URLENCODED,
"query=SELECT%20%2A%20WHERE%20%7B%7D&access-token=foo")),
ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}},
Query{"SELECT * WHERE {}"}));
EXPECT_THAT(
parse(MakePostRequest("/?access-token=foo", UPDATE, "DELETE * WHERE {}")),
ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}},
Update{"DELETE * WHERE {}"}));
}

TEST(ServerTest, checkParameter) {
Expand Down Expand Up @@ -284,3 +314,57 @@ TEST(ServerTest, createMessageSender) {
"SELECT * WHERE { ?a ?b ?c }"),
testing::HasSubstr("Assertion `queryHubLock` failed."));
}

TEST(ServerTest, extractAccessToken) {
auto extract = [](const ad_utility::httpUtils::HttpRequest auto& request) {
auto parsedUrl = parseRequestTarget(request.target());
return Server::extractAccessToken(request, parsedUrl.parameters_);
};
// TODO<qup42>: replace once #1668 is merged
auto makeRequest =
[](const http::verb method = http::verb::get,
const std::string& target = "/",
const ad_utility::HashMap<http::field, std::string>& headers = {},
const std::optional<std::string>& body = std::nullopt) {
// version 11 stands for HTTP/1.1
auto req = http::request<http::string_body>{method, target, 11};
for (const auto& [key, value] : headers) {
req.set(key, value);
}
if (body.has_value()) {
req.body() = body.value();
req.prepare_payload();
}
return req;
};
EXPECT_THAT(extract(MakeGetRequest("/")), testing::Eq(std::nullopt));
EXPECT_THAT(extract(MakeGetRequest("/?access-token=foo")),
testing::Optional(testing::Eq("foo")));
EXPECT_THAT(
extract(makeRequest(http::verb::get, "/",
{{http::field::authorization, "Bearer foo"}})),
testing::Optional(testing::Eq("foo")));
// The header takes precedence over the query parameter.
EXPECT_THAT(
extract(makeRequest(http::verb::get, "/?access-token=bar",
{{http::field::authorization, "Bearer foo"}})),
testing::Optional(testing::Eq("foo")));
AD_EXPECT_THROW_WITH_MESSAGE(
extract(makeRequest(http::verb::get, "/",
{{http::field::authorization, "foo"}})),
testing::HasSubstr(
"Authorization header must start with \"Bearer \". Got: \"foo\"."));
EXPECT_THAT(extract(MakePostRequest("/", "text/turtle", "")),
testing::Eq(std::nullopt));
EXPECT_THAT(extract(MakePostRequest("/?access-token=foo", "text/turtle", "")),
testing::Optional(testing::Eq("foo")));
EXPECT_THAT(
extract(makeRequest(http::verb::post, "/?access-token=bar",
{{http::field::authorization, "Bearer foo"}})),
testing::Optional(testing::Eq("foo")));
AD_EXPECT_THROW_WITH_MESSAGE(
extract(makeRequest(http::verb::post, "/?access-token=bar",
{{http::field::authorization, "foo"}})),
testing::HasSubstr(
"Authorization header must start with \"Bearer \". Got: \"foo\"."));
}

0 comments on commit d3e8eac

Please sign in to comment.