Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: wait ws coroutine quit #380

Merged
merged 10 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions include/cinatra/coro_http_client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class coro_http_client {

coro_http_client(asio::io_context::executor_type executor)
: socket_(std::make_shared<socket_t>(executor)),
read_buf_(socket_->read_buf_),
executor_wrapper_(executor),
timer_(&executor_wrapper_) {}

Expand Down Expand Up @@ -1017,6 +1018,7 @@ class coro_http_client {
struct socket_t {
asio::ip::tcp::socket impl_;
std::atomic<bool> has_closed_ = true;
asio::streambuf read_buf_;
template <typename ioc_t>
socket_t(ioc_t &&ioc) : impl_(std::forward<ioc_t>(ioc)) {}
};
Expand Down Expand Up @@ -1510,24 +1512,22 @@ class coro_http_client {

read_buf_.consume(read_buf_.size());
size_t header_size = 2;

std::shared_ptr sock = socket_;
auto on_ws_msg = std::move(on_ws_msg_);
websocket ws{};
while (true) {
std::weak_ptr socket = socket_;
if (auto [ec, _] = co_await async_read(read_buf_, header_size); ec) {
data.net_err = ec;
data.status = 404;
auto sock = socket.lock();
if (!sock) {

if (sock->has_closed_) {
co_return;
}
if (!sock->has_closed_) {
close_socket(*sock);
}

if (on_ws_msg_)
on_ws_msg_(data);
close_socket(*sock);

if (on_ws_msg)
on_ws_msg(data);
co_return;
}

Expand Down Expand Up @@ -1669,10 +1669,9 @@ class coro_http_client {
}

coro_io::ExecutorWrapper<> executor_wrapper_;
std::unique_ptr<asio::io_context::work> work_;
coro_io::period_timer timer_;
std::shared_ptr<socket_t> socket_;
asio::streambuf read_buf_;
asio::streambuf &read_buf_;
simple_buffer body_{};

std::unordered_map<std::string, std::string> req_headers_;
Expand Down Expand Up @@ -1710,6 +1709,7 @@ class coro_http_client {
std::chrono::steady_clock::duration req_timeout_duration_ =
std::chrono::seconds(60);
std::string resp_chunk_str_;

#ifdef BENCHMARK_TEST
std::string req_str_;
bool stop_bench_ = false;
Expand Down
33 changes: 6 additions & 27 deletions tests/test_cinatra_websocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,14 @@ TEST_CASE("test wss client") {
}
#endif

async_simple::coro::Lazy<void> test_websocket(coro_http_client &client,
std::promise<void> &promise) {
async_simple::coro::Lazy<void> test_websocket(coro_http_client &client) {
client.on_ws_close([](std::string_view reason) {
std::cout << "web socket close " << reason << std::endl;
CHECK(reason == "ws close");
});
client.on_ws_msg([&](resp_data data) {
if (data.net_err) {
std::cout << data.net_err.message() << "\n";
promise.set_value();
return;
}

Expand Down Expand Up @@ -137,11 +135,9 @@ TEST_CASE("test websocket") {
coro_http_client client;
client.set_ws_sec_key("s//GYHa/XO7Hd2F2eOGfyA==");

std::promise<void> promise;
async_simple::coro::syncAwait(test_websocket(client, promise));
async_simple::coro::syncAwait(test_websocket(client));

client.async_close();
promise.get_future().wait();

std::this_thread::sleep_for(std::chrono::milliseconds(300));

Expand Down Expand Up @@ -176,39 +172,27 @@ void test_websocket_content(size_t len) {
REQUIRE(async_simple::coro::syncAwait(
client.async_ws_connect("ws://localhost:8090")));

std::pair<std::promise<void>, bool> msg_pair_promise{};

std::string send_str(len, 'a');

std::promise<void> quit_promise{};

client.on_ws_msg([&, send_str](resp_data data) {
if (data.net_err) {
std::cout << "ws_msg net error " << data.net_err.message() << "\n";
quit_promise.set_value();
if (!msg_pair_promise.second) {
msg_pair_promise.first.set_value();
}

return;
}

std::cout << "ws msg len: " << data.resp_body.size() << std::endl;
REQUIRE(data.resp_body.size() == send_str.size());
CHECK(data.resp_body == send_str);
msg_pair_promise.first.set_value();
msg_pair_promise.second = true;
});

async_simple::coro::syncAwait(client.async_send_ws(send_str));
msg_pair_promise.first.get_future().wait();

std::this_thread::sleep_for(std::chrono::milliseconds(300));

server.stop();
server_thread.join();

client.async_close();

quit_promise.get_future().wait();
}

TEST_CASE("test websocket content lt 126") {
Expand Down Expand Up @@ -243,12 +227,8 @@ TEST_CASE("test send after server stop") {
REQUIRE(async_simple::coro::syncAwait(
client->async_ws_connect("ws://localhost:8090")));

std::promise<void> promise;
client->on_ws_msg([&client, &promise](resp_data data) {
if (data.net_err) {
client->async_close();
}
promise.set_value();
client->on_ws_msg([](resp_data data) {
std::cout << data.net_err.message() << "\n";
});

server.stop();
Expand All @@ -259,5 +239,4 @@ TEST_CASE("test send after server stop") {
CHECK(result.net_err);

server_thread.join();
promise.get_future().wait();
}
Loading