From d3e8eac4c6119c5f27ad9bbf59cbb22e5371ff56 Mon Sep 17 00:00:00 2001 From: Julian Mundhahs Date: Mon, 20 Jan 2025 13:53:46 +0100 Subject: [PATCH] accept access token from header --- src/engine/Server.cpp | 35 +++++++++- src/engine/Server.h | 7 ++ src/util/http/UrlParser.h | 2 + test/ServerTest.cpp | 132 +++++++++++++++++++++++++++++++------- 4 files changed, 149 insertions(+), 27 deletions(-) diff --git a/src/engine/Server.cpp b/src/engine/Server.cpp index 08fc6f9607..5614572b32 100644 --- a/src/engine/Server.cpp +++ b/src/engine/Server.cpp @@ -164,6 +164,27 @@ void Server::run(const string& indexBaseName, bool useText, bool usePatterns, httpServer.run(); } +std::optional Server::extractAccessToken( + const ad_utility::httpUtils::HttpRequest auto& request, + const ad_utility::url_parser::ParamValueMap& params) { + if (request.find(http::field::authorization) != request.end()) { + string_view authorization = request[http::field::authorization]; + const std::string prefix = "Bearer "; + if (!authorization.starts_with(prefix)) { + throw std::runtime_error(absl::StrCat( + "Authorization header must start with \"Bearer \". Got: \"", + authorization, "\".")); + } + authorization.remove_prefix(prefix.length()); + return std::string(authorization); + } + if (params.contains("access-token")) { + return ad_utility::url_parser::getParameterCheckAtMostOnce(params, + "access-token"); + } + return std::nullopt; +} + // _____________________________________________________________________________ ad_utility::url_parser::ParsedRequest Server::parseHttpRequest( const ad_utility::httpUtils::HttpRequest auto& request) { @@ -172,7 +193,8 @@ ad_utility::url_parser::ParsedRequest Server::parseHttpRequest( using namespace ad_utility::url_parser::sparqlOperation; auto parsedUrl = ad_utility::url_parser::parseRequestTarget(request.target()); ad_utility::url_parser::ParsedRequest parsedRequest{ - std::move(parsedUrl.path_), std::move(parsedUrl.parameters_), None{}}; + std::move(parsedUrl.path_), std::nullopt, + std::move(parsedUrl.parameters_), None{}}; // Some valid requests (e.g. QLever's custom commands like retrieving index // statistics) don't have a query. So an empty operation is not necessarily an @@ -186,9 +208,14 @@ ad_utility::url_parser::ParsedRequest Server::parseHttpRequest( parsedRequest.parameters_.erase(paramName); } }; + auto extractAccessTokenFromRequest = [&parsedRequest, &request]() { + parsedRequest.accessToken_ = + extractAccessToken(request, parsedRequest.parameters_); + }; if (request.method() == http::verb::get) { setOperationIfSpecifiedInParams.template operator()("query"); + extractAccessTokenFromRequest(); if (parsedRequest.parameters_.contains("update")) { throw std::runtime_error("SPARQL Update is not allowed as GET request."); } @@ -258,15 +285,18 @@ ad_utility::url_parser::ParsedRequest Server::parseHttpRequest( } setOperationIfSpecifiedInParams.template operator()("query"); setOperationIfSpecifiedInParams.template operator()("update"); + extractAccessTokenFromRequest(); return parsedRequest; } if (contentType.starts_with(contentTypeSparqlQuery)) { parsedRequest.operation_ = Query{request.body()}; + extractAccessTokenFromRequest(); return parsedRequest; } if (contentType.starts_with(contentTypeSparqlUpdate)) { parsedRequest.operation_ = Update{request.body()}; + extractAccessTokenFromRequest(); return parsedRequest; } throw std::runtime_error(absl::StrCat( @@ -353,8 +383,7 @@ Awaitable Server::process( // Check the access token. If an access token is provided and the check fails, // throw an exception and do not process any part of the query (even if the // processing had been allowed without access token). - bool accessTokenOk = - checkAccessToken(checkParameter("access-token", std::nullopt)); + bool accessTokenOk = checkAccessToken(parsedHttpRequest.accessToken_); auto requireValidAccessToken = [&accessTokenOk]( const std::string& actionName) { if (!accessTokenOk) { diff --git a/src/engine/Server.h b/src/engine/Server.h index 3ccc070cb7..15c37a85fd 100644 --- a/src/engine/Server.h +++ b/src/engine/Server.h @@ -34,6 +34,7 @@ class Server { FRIEND_TEST(ServerTest, parseHttpRequest); FRIEND_TEST(ServerTest, getQueryId); FRIEND_TEST(ServerTest, createMessageSender); + FRIEND_TEST(ServerTest, extractAccessToken); public: explicit Server(unsigned short port, size_t numThreads, @@ -114,6 +115,12 @@ class Server { static ad_utility::url_parser::ParsedRequest parseHttpRequest( const ad_utility::httpUtils::HttpRequest auto& request); + /// Extract the Access token for that request from the `Authorization` header + /// or the URL query parameters. + static std::optional extractAccessToken( + const ad_utility::httpUtils::HttpRequest auto& request, + const ad_utility::url_parser::ParamValueMap& params); + /// Handle a single HTTP request. Check whether a file request or a query was /// sent, and dispatch to functions handling these cases. This function /// requires the constraints for the `HttpHandler` in `HttpServer.h`. diff --git a/src/util/http/UrlParser.h b/src/util/http/UrlParser.h index 33ebc86b1d..80525e8f35 100644 --- a/src/util/http/UrlParser.h +++ b/src/util/http/UrlParser.h @@ -62,10 +62,12 @@ struct None { // Representation of parsed HTTP request. // - `path_` is the URL path +// - `accessToken_` is the access token for that request // - `parameters_` is a hashmap of the parameters // - `operation_` the operation that should be performed struct ParsedRequest { std::string path_; + std::optional accessToken_; ParamValueMap parameters_; std::variant diff --git a/test/ServerTest.cpp b/test/ServerTest.cpp index 292d77f12b..685e85c7e3 100644 --- a/test/ServerTest.cpp +++ b/test/ServerTest.cpp @@ -17,11 +17,14 @@ using namespace ad_utility::url_parser::sparqlOperation; namespace { auto ParsedRequestIs = [](const std::string& path, + const std::optional& accessToken, const ParamValueMap& parameters, const std::variant& operation) -> testing::Matcher { return testing::AllOf( AD_FIELD(ad_utility::url_parser::ParsedRequest, path_, testing::Eq(path)), + AD_FIELD(ad_utility::url_parser::ParsedRequest, accessToken_, + testing::Eq(accessToken)), AD_FIELD(ad_utility::url_parser::ParsedRequest, parameters_, testing::ContainerEq(parameters)), AD_FIELD(ad_utility::url_parser::ParsedRequest, operation_, @@ -55,19 +58,21 @@ TEST(ServerTest, parseHttpRequest) { "application/x-www-form-urlencoded;charset=UTF-8"; const std::string QUERY = "application/sparql-query"; const std::string UPDATE = "application/sparql-update"; - EXPECT_THAT(parse(MakeGetRequest("/")), ParsedRequestIs("/", {}, None{})); + EXPECT_THAT(parse(MakeGetRequest("/")), + ParsedRequestIs("/", std::nullopt, {}, None{})); EXPECT_THAT(parse(MakeGetRequest("/ping")), - ParsedRequestIs("/ping", {}, None{})); + ParsedRequestIs("/ping", std::nullopt, {}, None{})); EXPECT_THAT(parse(MakeGetRequest("/?cmd=stats")), - ParsedRequestIs("/", {{"cmd", {"stats"}}}, None{})); + ParsedRequestIs("/", std::nullopt, {{"cmd", {"stats"}}}, None{})); EXPECT_THAT(parse(MakeGetRequest( "/?query=SELECT+%2A%20WHERE%20%7B%7D&action=csv_export")), - ParsedRequestIs("/", {{"action", {"csv_export"}}}, + ParsedRequestIs("/", std::nullopt, {{"action", {"csv_export"}}}, Query{"SELECT * WHERE {}"})); EXPECT_THAT( parse(MakePostRequest("/", URLENCODED, "query=SELECT+%2A%20WHERE%20%7B%7D&send=100")), - ParsedRequestIs("/", {{"send", {"100"}}}, Query{"SELECT * WHERE {}"})); + ParsedRequestIs("/", std::nullopt, {{"send", {"100"}}}, + Query{"SELECT * WHERE {}"})); AD_EXPECT_THROW_WITH_MESSAGE( parse(MakePostRequest("/", URLENCODED, "ääär y=SELECT+%2A%20WHERE%20%7B%7D&send=100")), @@ -92,10 +97,12 @@ TEST(ServerTest, parseHttpRequest) { EXPECT_THAT( parse(MakePostRequest("/", "application/x-www-form-urlencoded", "query=SELECT%20%2A%20WHERE%20%7B%7D&send=100")), - ParsedRequestIs("/", {{"send", {"100"}}}, Query{"SELECT * WHERE {}"})); - EXPECT_THAT(parse(MakePostRequest("/", URLENCODED, - "query=SELECT%20%2A%20WHERE%20%7B%7D")), - ParsedRequestIs("/", {}, Query{"SELECT * WHERE {}"})); + ParsedRequestIs("/", std::nullopt, {{"send", {"100"}}}, + Query{"SELECT * WHERE {}"})); + EXPECT_THAT( + parse(MakePostRequest("/", URLENCODED, + "query=SELECT%20%2A%20WHERE%20%7B%7D")), + ParsedRequestIs("/", std::nullopt, {}, Query{"SELECT * WHERE {}"})); EXPECT_THAT( parse(MakePostRequest( "/", URLENCODED, @@ -103,7 +110,7 @@ TEST(ServerTest, parseHttpRequest) { "2Fw3.org%2Fdefault&named-graph-uri=https%3A%2F%2Fw3.org%2F1&named-" "graph-uri=https%3A%2F%2Fw3.org%2F2")), ParsedRequestIs( - "/", + "/", std::nullopt, {{"default-graph-uri", {"https://w3.org/default"}}, {"named-graph-uri", {"https://w3.org/1", "https://w3.org/2"}}}, Query{"SELECT * WHERE {}"})); @@ -112,13 +119,15 @@ TEST(ServerTest, parseHttpRequest) { "query=SELECT%20%2A%20WHERE%20%7B%7D")), testing::StrEq("URL-encoded POST requests must not contain query " "parameters in the URL.")); - EXPECT_THAT(parse(MakePostRequest("/", URLENCODED, "cmd=clear-cache")), - ParsedRequestIs("/", {{"cmd", {"clear-cache"}}}, None{})); - EXPECT_THAT(parse(MakePostRequest("/", QUERY, "SELECT * WHERE {}")), - ParsedRequestIs("/", {}, Query{"SELECT * WHERE {}"})); EXPECT_THAT( - parse(MakePostRequest("/?send=100", QUERY, "SELECT * WHERE {}")), - ParsedRequestIs("/", {{"send", {"100"}}}, Query{"SELECT * WHERE {}"})); + parse(MakePostRequest("/", URLENCODED, "cmd=clear-cache")), + ParsedRequestIs("/", std::nullopt, {{"cmd", {"clear-cache"}}}, None{})); + EXPECT_THAT( + parse(MakePostRequest("/", QUERY, "SELECT * WHERE {}")), + ParsedRequestIs("/", std::nullopt, {}, Query{"SELECT * WHERE {}"})); + EXPECT_THAT(parse(MakePostRequest("/?send=100", QUERY, "SELECT * WHERE {}")), + ParsedRequestIs("/", std::nullopt, {{"send", {"100"}}}, + Query{"SELECT * WHERE {}"})); AD_EXPECT_THROW_WITH_MESSAGE( parse(MakeBasicRequest(http::verb::patch, "/")), testing::StrEq( @@ -132,14 +141,35 @@ TEST(ServerTest, parseHttpRequest) { AD_EXPECT_THROW_WITH_MESSAGE( parse(MakeGetRequest("/?update=DELETE%20%2A%20WHERE%20%7B%7D")), testing::StrEq("SPARQL Update is not allowed as GET request.")); - EXPECT_THAT(parse(MakePostRequest("/", UPDATE, "DELETE * WHERE {}")), - ParsedRequestIs("/", {}, Update{"DELETE * WHERE {}"})); - EXPECT_THAT(parse(MakePostRequest("/", URLENCODED, - "update=DELETE%20%2A%20WHERE%20%7B%7D")), - ParsedRequestIs("/", {}, Update{"DELETE * WHERE {}"})); - EXPECT_THAT(parse(MakePostRequest("/", URLENCODED, - "update=DELETE+%2A+WHERE%20%7B%7D")), - ParsedRequestIs("/", {}, Update{"DELETE * WHERE {}"})); + EXPECT_THAT( + parse(MakePostRequest("/", UPDATE, "DELETE * WHERE {}")), + ParsedRequestIs("/", std::nullopt, {}, Update{"DELETE * WHERE {}"})); + EXPECT_THAT( + parse(MakePostRequest("/", URLENCODED, + "update=DELETE%20%2A%20WHERE%20%7B%7D")), + ParsedRequestIs("/", std::nullopt, {}, Update{"DELETE * WHERE {}"})); + EXPECT_THAT( + parse( + MakePostRequest("/", URLENCODED, "update=DELETE+%2A+WHERE%20%7B%7D")), + ParsedRequestIs("/", std::nullopt, {}, Update{"DELETE * WHERE {}"})); + // TODO: there could be some more here, but i'll wait until #1668 + EXPECT_THAT( + parse(MakeGetRequest("/?query=a&access-token=foo")), + ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, Query{"a"})); + EXPECT_THAT( + parse(MakePostRequest("/", URLENCODED, + "update=DELETE%20WHERE%20%7B%7D&access-token=foo")), + ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, + Update{"DELETE WHERE {}"})); + EXPECT_THAT(parse(MakePostRequest( + "/", URLENCODED, + "query=SELECT%20%2A%20WHERE%20%7B%7D&access-token=foo")), + ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, + Query{"SELECT * WHERE {}"})); + EXPECT_THAT( + parse(MakePostRequest("/?access-token=foo", UPDATE, "DELETE * WHERE {}")), + ParsedRequestIs("/", "foo", {{"access-token", {"foo"}}}, + Update{"DELETE * WHERE {}"})); } TEST(ServerTest, checkParameter) { @@ -284,3 +314,57 @@ TEST(ServerTest, createMessageSender) { "SELECT * WHERE { ?a ?b ?c }"), testing::HasSubstr("Assertion `queryHubLock` failed.")); } + +TEST(ServerTest, extractAccessToken) { + auto extract = [](const ad_utility::httpUtils::HttpRequest auto& request) { + auto parsedUrl = parseRequestTarget(request.target()); + return Server::extractAccessToken(request, parsedUrl.parameters_); + }; + // TODO: replace once #1668 is merged + auto makeRequest = + [](const http::verb method = http::verb::get, + const std::string& target = "/", + const ad_utility::HashMap& headers = {}, + const std::optional& body = std::nullopt) { + // version 11 stands for HTTP/1.1 + auto req = http::request{method, target, 11}; + for (const auto& [key, value] : headers) { + req.set(key, value); + } + if (body.has_value()) { + req.body() = body.value(); + req.prepare_payload(); + } + return req; + }; + EXPECT_THAT(extract(MakeGetRequest("/")), testing::Eq(std::nullopt)); + EXPECT_THAT(extract(MakeGetRequest("/?access-token=foo")), + testing::Optional(testing::Eq("foo"))); + EXPECT_THAT( + extract(makeRequest(http::verb::get, "/", + {{http::field::authorization, "Bearer foo"}})), + testing::Optional(testing::Eq("foo"))); + // The header takes precedence over the query parameter. + EXPECT_THAT( + extract(makeRequest(http::verb::get, "/?access-token=bar", + {{http::field::authorization, "Bearer foo"}})), + testing::Optional(testing::Eq("foo"))); + AD_EXPECT_THROW_WITH_MESSAGE( + extract(makeRequest(http::verb::get, "/", + {{http::field::authorization, "foo"}})), + testing::HasSubstr( + "Authorization header must start with \"Bearer \". Got: \"foo\".")); + EXPECT_THAT(extract(MakePostRequest("/", "text/turtle", "")), + testing::Eq(std::nullopt)); + EXPECT_THAT(extract(MakePostRequest("/?access-token=foo", "text/turtle", "")), + testing::Optional(testing::Eq("foo"))); + EXPECT_THAT( + extract(makeRequest(http::verb::post, "/?access-token=bar", + {{http::field::authorization, "Bearer foo"}})), + testing::Optional(testing::Eq("foo"))); + AD_EXPECT_THROW_WITH_MESSAGE( + extract(makeRequest(http::verb::post, "/?access-token=bar", + {{http::field::authorization, "foo"}})), + testing::HasSubstr( + "Authorization header must start with \"Bearer \". Got: \"foo\".")); +}