Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:add request check callback #384

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions include/cinatra/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,19 @@ class connection : public base_connection,
asio::io_service &io_service, ssl_configure ssl_conf,
std::size_t max_req_size, long keep_alive_timeout, http_handler &handler,
std::string &static_dir,
std::function<bool(request &req, response &res)> *upload_check)
std::function<bool(request &req, response &res)> *upload_check,
std::string &checker_resp,
std::function<bool(request &req, response &res)> *req_body_check)
: socket_(io_service),
MAX_REQ_SIZE_(max_req_size),
KEEP_ALIVE_TIMEOUT_(keep_alive_timeout),
timer_(io_service),
http_handler_(handler),
req_(res_),
static_dir_(static_dir),
upload_check_(upload_check) {
upload_check_(upload_check),
checker_resp_(checker_resp),
req_body_check_(req_body_check) {
if constexpr (is_ssl_) {
init_ssl_context(std::move(ssl_conf));
}
Expand Down Expand Up @@ -442,7 +446,16 @@ class connection : public base_connection,
void handle_request(std::size_t bytes_transferred) {
auto type = get_content_type();
req_.set_http_type(type);

if (req_.has_body()) {
if (req_body_check_) {
bool r = (*req_body_check_)(req_, res_);
if (!r) {
response_back(status_type::entity_too_large,
std::move(checker_resp_));
return;
}
}
switch (type) {
case cinatra::content_type::string:
case cinatra::content_type::websocket:
Expand Down Expand Up @@ -1470,6 +1483,8 @@ class connection : public base_connection,
// callback handler to application layer
const http_handler &http_handler_;
std::function<bool(request &req, response &res)> *upload_check_ = nullptr;
std::string checker_resp_;
std::function<bool(request &req, response &res)> *req_body_check_ = nullptr;
std::any tag_;
std::function<void(request &, std::string &)> multipart_begin_ = nullptr;

Expand Down
12 changes: 11 additions & 1 deletion include/cinatra/http_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,13 @@ class http_server_ : private noncopyable {
upload_check_ = std::move(checker);
}

// should be called before listen
void set_body_check(std::function<bool(request &req, response &res)> checker,
std::string resp_str = "") {
checker_resp_ = resp_str;
req_body_check_ = std::move(checker);
}

void mapping_to_root_path(std::string relate_path) {
relate_paths_.emplace_back("." + std::move(relate_path));
}
Expand Down Expand Up @@ -280,7 +287,8 @@ class http_server_ : private noncopyable {
auto new_conn = std::make_shared<connection<ScoketType>>(
io_service_pool_.get_io_service(), ssl_conf_, max_req_buf_size_,
keep_alive_timeout_, http_handler_, upload_dir_,
upload_check_ ? &upload_check_ : nullptr);
upload_check_ ? &upload_check_ : nullptr, checker_resp_,
req_body_check_ ? &req_body_check_ : nullptr);

acceptor_->async_accept(
new_conn->tcp_socket(), [this, new_conn](const std::error_code &e) {
Expand Down Expand Up @@ -603,6 +611,8 @@ class http_server_ : private noncopyable {
std::function<bool(request &req, response &res)> download_check_;
std::vector<std::string> relate_paths_;
std::function<bool(request &req, response &res)> upload_check_ = nullptr;
std::string checker_resp_;
std::function<bool(request &req, response &res)> req_body_check_ = nullptr;

std::function<void(request &req, response &res)> not_found_ = nullptr;
std::function<void(request &, std::string &)> multipart_begin_ = nullptr;
Expand Down
15 changes: 15 additions & 0 deletions include/cinatra/response_cv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ enum class status_type {
forbidden = 403,
not_found = 404,
conflict = 409,
entity_too_large = 413,
internal_server_error = 500,
not_implemented = 501,
bad_gateway = 502,
Expand Down Expand Up @@ -109,6 +110,12 @@ inline std::string_view conflict =
"<body><h1>409 Conflict</h1></body>"
"</html>";

inline std::string_view entity_too_large =
"<html>"
"<head><title>RequestEntityTooLarge</title></head>"
"<body><h1>413 RequestEntityTooLarge</h1></body>"
"</html>";

inline std::string_view internal_server_error =
"<html>"
"<head><title>Internal Server Error</title></head>"
Expand Down Expand Up @@ -159,6 +166,7 @@ inline constexpr std::string_view rep_unauthorized =
inline constexpr std::string_view rep_forbidden = "HTTP/1.1 403 Forbidden\r\n";
inline constexpr std::string_view rep_not_found = "HTTP/1.1 404 Not Found\r\n";
inline constexpr std::string_view rep_conflict = "HTTP/1.1 409 Conflict\r\n";
inline constexpr std::string_view rep_entity_too_large = "HTTP/1.1 413 Request Too Large\r\n";
inline constexpr std::string_view rep_internal_server_error =
"HTTP/1.1 500 Internal Server Error\r\n";
inline constexpr std::string_view rep_not_implemented =
Expand Down Expand Up @@ -289,6 +297,8 @@ inline decltype(auto) to_buffer(status_type status) {
return T(rep_not_found.data(), rep_not_found.length());
case status_type::conflict:
return T(rep_conflict.data(), rep_conflict.length());
case status_type::entity_too_large:
return T(rep_entity_too_large.data(), rep_entity_too_large.length());
case status_type::internal_server_error:
return T(rep_internal_server_error.data(),
rep_internal_server_error.length());
Expand Down Expand Up @@ -355,6 +365,9 @@ inline constexpr std::string_view to_rep_string(status_type status) {
case cinatra::status_type::conflict:
return rep_conflict;
break;
case cinatra::status_type::entity_too_large:
return rep_entity_too_large;
break;
case cinatra::status_type::internal_server_error:
return rep_internal_server_error;
break;
Expand Down Expand Up @@ -403,6 +416,8 @@ inline std::string_view to_string(status_type status) {
return not_found;
case status_type::conflict:
return conflict;
case status_type::entity_too_large:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no ut for 413?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no ut for 413?

coro client get code is 404,don't know why?postman test it that return http code is 413

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

413 test
image

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's can be merged. drogon framework has function setClientMaxBodySize to limit client request max body.We can use this.
image

return entity_too_large;
case status_type::internal_server_error:
return internal_server_error;
case status_type::not_implemented:
Expand Down
81 changes: 81 additions & 0 deletions tests/test_cinatra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,4 +943,85 @@ TEST_CASE(
std::chrono::microseconds::period::num /
std::chrono::microseconds::period::den
<< "s" << std::endl;
}

TEST_CASE("test server body limit upload request body") {
http_server server(std::thread::hardware_concurrency());
// server.enable_timeout(false);

auto control_func = [](request &req, response &res) -> bool {
int max_body_size = 2048; // 1k
if (max_body_size < req.body_len()) {
return false;
}
return true;
};

server.set_body_check(control_func, "Request Entity Is Too Large");

bool r = server.listen("0.0.0.0", "8090");
if (!r) {
std::cout << "listen failed."
<< "\n";
}

server.set_http_handler<POST>("/multipart", [](request &req, response &res) {
assert(req.get_content_type() == content_type::multipart);
auto &files = req.get_upload_files();
for (auto &file : files) {
std::cout << file.get_file_path() << " " << file.get_file_size()
<< std::endl;
}
std::cout << "multipart finished\n";
res.render_string("multipart finished");
});

std::promise<void> pr;
std::future<void> f = pr.get_future();
std::thread server_thread([&server, &pr]() {
pr.set_value();
server.run();
});
f.wait();

std::this_thread::sleep_for(std::chrono::milliseconds(100));

coro_http_client client{};
std::string uri = "http://127.0.0.1:8090/multipart";

client.set_max_single_part_size(2048);
std::string test_file_name = "test1.txt";
std::ofstream test_file;
test_file.open(test_file_name,
std::ios::binary | std::ios::out | std::ios::trunc);
std::vector<char> test_file_data(10 * 1024 * 1024, '0');
test_file.write(test_file_data.data(), test_file_data.size());
test_file.close();
auto result = async_simple::coro::syncAwait(
client.async_upload_multipart(uri, "test", test_file_name));

CHECK(result.status == 404);

coro_http_client sec_client{};
sec_client.set_max_single_part_size(512);
std::string small_file_name = "test2.txt";
std::ofstream small_file;
small_file.open(small_file_name,
std::ios::binary | std::ios::out | std::ios::trunc);
std::vector<char> small_file_data(512, '0');
small_file.write(small_file_data.data(), small_file_data.size());
small_file.close();
result = async_simple::coro::syncAwait(
sec_client.async_upload_multipart(uri, "test", small_file_name));
CHECK(result.status == 200);
CHECK(result.resp_body == "multipart finished");

std::filesystem::remove(std::filesystem::path(test_file_name));
std::filesystem::remove(std::filesystem::path(small_file_name));

client.async_close();
sec_client.async_close();

server.stop();
server_thread.join();
}
Loading