From aa1418ad12cc6913b507bf16de79319560c968b2 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Wed, 16 Oct 2024 20:22:42 -0400 Subject: [PATCH] Websocket support (#124) * Refactor WebSocket client code and add WebSocket support to request handler * Refactor WebSocket client code and remove unnecessary forward declarations * Refactor WebSocket client code and add WebSocket support to request handler * Refactor UI layout and adjust tab stop distance --- CMakeLists.txt | 3 + cmake/FetchWebsocketpp.cmake | 57 ++++++ src/request-data.cpp | 351 +++++++++++++++++++---------------- src/request-data.h | 15 ++ src/ui/requestbuilder.ui | 117 +++++------- src/websocket-client.cpp | 193 +++++++++++++++++++ src/websocket-client.h | 9 + 7 files changed, 514 insertions(+), 231 deletions(-) create mode 100644 cmake/FetchWebsocketpp.cmake create mode 100644 src/websocket-client.cpp create mode 100644 src/websocket-client.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 3afb1e0..cb0a919 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,6 +68,8 @@ target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE inja) include(cmake/BuildLexbor.cmake) target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE liblexbor_internal) +include(cmake/FetchWebsocketpp.cmake) + target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE vendor/nlohmann-json) target_sources( @@ -76,6 +78,7 @@ target_sources( src/obs-source-util.cpp src/mapping-data.cpp src/request-data.cpp + src/websocket-client.cpp src/ui/CustomTextDocument.cpp src/ui/RequestBuilder.cpp src/ui/text-render-helper.cpp diff --git a/cmake/FetchWebsocketpp.cmake b/cmake/FetchWebsocketpp.cmake new file mode 100644 index 0000000..a8c4033 --- /dev/null +++ b/cmake/FetchWebsocketpp.cmake @@ -0,0 +1,57 @@ +if(WIN32 OR APPLE) + # Windows and macOS are supported by the prebuilt dependencies + + if(NOT buildspec) + file(READ "${CMAKE_SOURCE_DIR}/buildspec.json" buildspec) + endif() + + string( + JSON + version + GET + ${buildspec} + dependencies + prebuilt + version) + + if(MSVC) + set(arch ${CMAKE_GENERATOR_PLATFORM}) + elseif(APPLE) + set(arch universal) + endif() + + set(deps_root "${CMAKE_SOURCE_DIR}/.deps/obs-deps-${version}-${arch}") + target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE "${deps_root}/include") +else() + # Linux requires fetching the dependencies + include(FetchContent) + + FetchContent_Declare( + websocketpp + URL https://github.com/zaphoyd/websocketpp/archive/refs/tags/0.8.2.tar.gz + URL_HASH SHA256=6ce889d85ecdc2d8fa07408d6787e7352510750daa66b5ad44aacb47bea76755) + + # Only download the content, don't configure or build it + FetchContent_GetProperties(websocketpp) + + if(NOT websocketpp_POPULATED) + FetchContent_Populate(websocketpp) + endif() + + # Add WebSocket++ as an interface library + add_library(websocketpp INTERFACE) + target_include_directories(websocketpp INTERFACE ${websocketpp_SOURCE_DIR}) + + # Fetch ASIO + FetchContent_Declare( + asio + URL https://github.com/chriskohlhoff/asio/archive/asio-1-28-0.tar.gz + URL_HASH SHA256=226438b0798099ad2a202563a83571ce06dd13b570d8fded4840dbc1f97fa328) + + FetchContent_MakeAvailable(websocketpp asio) + + target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE websocketpp) + target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE ${asio_SOURCE_DIR}/asio/include/) +endif() + +message(STATUS "WebSocket++ and ASIO have been added to the project") diff --git a/src/request-data.cpp b/src/request-data.cpp index a494029..bfb5810 100644 --- a/src/request-data.cpp +++ b/src/request-data.cpp @@ -24,6 +24,7 @@ #include #include "obs-source-util.h" +#include "websocket-client.h" #define URL_SOURCE_AGG_BUFFER_MAX_SIZE 1024 @@ -164,7 +165,7 @@ void handle_empty_text(input_data &input, request_data_handler_response &respons } } -void put_inputs_on_json(url_source_request_data *request_data, CURL *curl, +void put_inputs_on_json(url_source_request_data *request_data, request_data_handler_response &response, nlohmann::json &json) { for (size_t i = 0; i < request_data->inputs.size(); i++) { @@ -225,7 +226,6 @@ void put_inputs_on_json(url_source_request_data *request_data, CURL *curl, } } if (response.status_code == URL_SOURCE_REQUEST_BENIGN_ERROR_CODE) { - curl_easy_cleanup(curl); return; } } else { @@ -237,7 +237,6 @@ void put_inputs_on_json(url_source_request_data *request_data, CURL *curl, // Return an error response response.error_message = "Failed to get source by name"; response.status_code = URL_SOURCE_REQUEST_STANDARD_ERROR_CODE; - curl_easy_cleanup(curl); return; } @@ -262,7 +261,6 @@ void put_inputs_on_json(url_source_request_data *request_data, CURL *curl, // Return an error response response.error_message = "Failed to get RGBA from source render"; response.status_code = URL_SOURCE_REQUEST_STANDARD_ERROR_CODE; - curl_easy_cleanup(curl); return; } destroy_source_render_data(&tf); @@ -276,199 +274,222 @@ void put_inputs_on_json(url_source_request_data *request_data, CURL *curl, } } -struct request_data_handler_response request_data_handler(url_source_request_data *request_data) +void prepare_inja_env(inja::Environment *env, url_source_request_data *request_data, + request_data_handler_response &response, nlohmann::json &json) { - struct request_data_handler_response response; + // Put the request inputs on the json object + put_inputs_on_json(request_data, response, json); - request_data->sequence_number++; + if (response.status_code != URL_SOURCE_REQUEST_SUCCESS) { + return; + } - if (request_data->url_or_file == "file") { - // This is a file request - // Read the file - std::ifstream file(request_data->url); - if (!file.is_open()) { - obs_log(LOG_INFO, "Failed to open file"); - // Return an error response - response.error_message = "Failed to open file"; - response.status_code = URL_SOURCE_REQUEST_STANDARD_ERROR_CODE; - return response; - } - std::string responseBody((std::istreambuf_iterator(file)), - std::istreambuf_iterator()); - file.close(); + // Add an inja callback for time formatting + env->add_callback("strftime", 2, [](inja::Arguments &args) { + std::string format = args.at(0)->get(); + std::time_t t = std::time(nullptr); + std::tm *tm = std::localtime(&t); + if (args.at(1)->get()) { + // if the second argument is true, use UTC time + tm = std::gmtime(&t); + } + char buffer[256]; + std::strftime(buffer, sizeof(buffer), format.c_str(), tm); + return std::string(buffer); + }); + + json["seq"] = request_data->sequence_number; +} - response.body = responseBody; - response.status_code = URL_SOURCE_REQUEST_SUCCESS; +request_data_handler_response http_request_handler(url_source_request_data *request_data, + request_data_handler_response &response) +{ + // Build the request with libcurl + CURL *curl = curl_easy_init(); + if (!curl) { + obs_log(LOG_INFO, "Failed to initialize curl"); + // Return an error response + response.error_message = "Failed to initialize curl"; + response.status_code = URL_SOURCE_REQUEST_STANDARD_ERROR_CODE; + return response; + } + curl_easy_setopt(curl, CURLOPT_USERAGENT, USER_AGENT.c_str()); + if (request_data->fail_on_http_error) { + curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L); + } + + std::string responseBody; + std::vector responseBodyUint8; + + // if the request is for textual data write to string + if (request_data->output_type == "JSON" || request_data->output_type == "XML (XPath)" || + request_data->output_type == "XML (XQuery)" || request_data->output_type == "HTML" || + request_data->output_type == "Text") { + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &responseBody); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeFunctionStdString); } else { - // This is a URL request - // Check if the URL is empty - if (request_data->url == "") { - obs_log(LOG_INFO, "URL is empty"); - // Return an error response - response.error_message = "URL is empty"; - response.status_code = URL_SOURCE_REQUEST_STANDARD_ERROR_CODE; - return response; - } + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &responseBodyUint8); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeFunctionUint8Vector); + } - // Build the request with libcurl - CURL *curl = curl_easy_init(); - if (!curl) { - obs_log(LOG_INFO, "Failed to initialize curl"); - // Return an error response - response.error_message = "Failed to initialize curl"; - response.status_code = URL_SOURCE_REQUEST_STANDARD_ERROR_CODE; - return response; - } - curl_easy_setopt(curl, CURLOPT_USERAGENT, USER_AGENT.c_str()); - if (request_data->fail_on_http_error) { - curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L); + if (request_data->headers.size() > 0) { + // Add request headers + struct curl_slist *headers = NULL; + for (auto header : request_data->headers) { + std::string header_string = header.first + ": " + header.second; + headers = curl_slist_append(headers, header_string.c_str()); } + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + } - std::string responseBody; - std::vector responseBodyUint8; + nlohmann::json json; // json object or variables for inja + inja::Environment env; + prepare_inja_env(&env, request_data, response, json); - // if the request is for textual data write to string - if (request_data->output_type == "JSON" || - request_data->output_type == "XML (XPath)" || - request_data->output_type == "XML (XQuery)" || - request_data->output_type == "HTML" || request_data->output_type == "Text") { - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &responseBody); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeFunctionStdString); - } else { - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &responseBodyUint8); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeFunctionUint8Vector); - } + if (response.status_code != URL_SOURCE_REQUEST_SUCCESS) { + curl_easy_cleanup(curl); + return response; + } - if (request_data->headers.size() > 0) { - // Add request headers - struct curl_slist *headers = NULL; - for (auto header : request_data->headers) { - std::string header_string = header.first + ": " + header.second; - headers = curl_slist_append(headers, header_string.c_str()); - } - curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); - } + // add a callback for escaping strings in the querystring + env.add_callback("urlencode", 1, [&curl](inja::Arguments &args) { + std::string input = args.at(0)->get(); + char *escaped = curl_easy_escape(curl, input.c_str(), 0); + input = std::string(escaped); + curl_free(escaped); + return input; + }); + + // Replace the {input} placeholder in the querystring as well + std::string url = request_data->url; + try { + url = env.render(url, json); + } catch (std::exception &e) { + obs_log(LOG_WARNING, "Failed to render URL template: %s", e.what()); + } + response.request_url = url; - nlohmann::json json; // json object or variables for inja + // validate the url + if (!hasOnlyValidURLCharacters(url)) { + obs_log(LOG_INFO, "URL '%s' is invalid", url.c_str()); + // Return an error response + response.error_message = "URL is invalid"; + response.status_code = URL_SOURCE_REQUEST_STANDARD_ERROR_CODE; + return response; + } - // Put the request inputs on the json object - put_inputs_on_json(request_data, curl, response, json); + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - if (response.status_code != URL_SOURCE_REQUEST_SUCCESS) { - curl_easy_cleanup(curl); - return response; - } + // this is needed here, out of the `if` scope below + std::string request_body_allocated; - // Replace placeholders in the URL and body with the input values - inja::Environment env; - // Add an inja callback for time formatting - env.add_callback("strftime", 2, [](inja::Arguments &args) { - std::string format = args.at(0)->get(); - std::time_t t = std::time(nullptr); - std::tm *tm = std::localtime(&t); - if (args.at(1)->get()) { - // if the second argument is true, use UTC time - tm = std::gmtime(&t); - } - char buffer[256]; - std::strftime(buffer, sizeof(buffer), format.c_str(), tm); - return std::string(buffer); - }); - // add a callback for escaping strings in the querystring - env.add_callback("urlencode", 1, [&curl](inja::Arguments &args) { - std::string input = args.at(0)->get(); - char *escaped = curl_easy_escape(curl, input.c_str(), 0); - input = std::string(escaped); - curl_free(escaped); - return input; - }); - - json["seq"] = request_data->sequence_number; - - // Replace the {input} placeholder in the querystring as well - std::string url = request_data->url; + if (request_data->method == "POST") { + curl_easy_setopt(curl, CURLOPT_POST, 1L); try { - url = env.render(url, json); + request_body_allocated = env.render(request_data->body, json); } catch (std::exception &e) { - obs_log(LOG_WARNING, "Failed to render URL template: %s", e.what()); + obs_log(LOG_WARNING, "Failed to render Body template: %s", e.what()); } - response.request_url = url; + response.request_body = request_body_allocated; + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_body_allocated.c_str()); + } else if (request_data->method == "GET") { + curl_easy_setopt(curl, CURLOPT_HTTPGET, 1L); + } - // validate the url - if (!hasOnlyValidURLCharacters(url)) { - obs_log(LOG_INFO, "URL '%s' is invalid", url.c_str()); - // Return an error response - response.error_message = "URL is invalid"; - response.status_code = URL_SOURCE_REQUEST_STANDARD_ERROR_CODE; - return response; - } + // SSL options + if (request_data->ssl_client_cert_file != "") { + curl_easy_setopt(curl, CURLOPT_SSLCERT, request_data->ssl_client_cert_file.c_str()); + } + if (request_data->ssl_client_key_file != "") { + curl_easy_setopt(curl, CURLOPT_SSLKEY, request_data->ssl_client_key_file.c_str()); + } + if (request_data->ssl_client_key_pass != "") { + curl_easy_setopt(curl, CURLOPT_SSLKEYPASSWD, + request_data->ssl_client_key_pass.c_str()); + } + if (!request_data->ssl_verify_peer) { + curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, 0L); + } - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + std::map headers; + curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, header_callback); + curl_easy_setopt(curl, CURLOPT_HEADERDATA, &headers); - // this is needed here, out of the `if` scope below - std::string request_body_allocated; + // Send the request + CURLcode code = curl_easy_perform(curl); + long http_code = 0; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); + curl_easy_cleanup(curl); - if (request_data->method == "POST") { - curl_easy_setopt(curl, CURLOPT_POST, 1L); - try { - request_body_allocated = env.render(request_data->body, json); - } catch (std::exception &e) { - obs_log(LOG_WARNING, "Failed to render Body template: %s", - e.what()); - } - response.request_body = request_body_allocated; - curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_body_allocated.c_str()); - } else if (request_data->method == "GET") { - curl_easy_setopt(curl, CURLOPT_HTTPGET, 1L); - } + response.body = responseBody; + response.body_bytes = responseBodyUint8; + response.headers = headers; + response.http_status_code = http_code; - // SSL options - if (request_data->ssl_client_cert_file != "") { - curl_easy_setopt(curl, CURLOPT_SSLCERT, - request_data->ssl_client_cert_file.c_str()); - } - if (request_data->ssl_client_key_file != "") { - curl_easy_setopt(curl, CURLOPT_SSLKEY, - request_data->ssl_client_key_file.c_str()); - } - if (request_data->ssl_client_key_pass != "") { - curl_easy_setopt(curl, CURLOPT_SSLKEYPASSWD, - request_data->ssl_client_key_pass.c_str()); - } - if (!request_data->ssl_verify_peer) { - curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, 0L); + if (code != CURLE_OK) { + obs_log(LOG_WARNING, "Failed to send request to '%s': %s", url.c_str(), + curl_easy_strerror(code)); + if (responseBody.size() > 0) { + obs_log(LOG_WARNING, "Response body: %s", responseBody.c_str()); } + // Return a formatted error response with the message and the HTTP status code + response.error_message = std::string(curl_easy_strerror(code)) + " (" + + std::to_string(http_code) + ")"; - std::map headers; - curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, header_callback); - curl_easy_setopt(curl, CURLOPT_HEADERDATA, &headers); + response.status_code = URL_SOURCE_REQUEST_STANDARD_ERROR_CODE; + return response; + } - // Send the request - CURLcode code = curl_easy_perform(curl); - long http_code = 0; - curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); - curl_easy_cleanup(curl); + response.status_code = URL_SOURCE_REQUEST_SUCCESS; - response.body = responseBody; - response.body_bytes = responseBodyUint8; - response.headers = headers; - response.http_status_code = http_code; - - if (code != CURLE_OK) { - obs_log(LOG_WARNING, "Failed to send request to '%s': %s", url.c_str(), - curl_easy_strerror(code)); - if (responseBody.size() > 0) { - obs_log(LOG_WARNING, "Response body: %s", responseBody.c_str()); - } - // Return a formatted error response with the message and the HTTP status code - response.error_message = std::string(curl_easy_strerror(code)) + " (" + - std::to_string(http_code) + ")"; + return response; +} +struct request_data_handler_response request_data_handler(url_source_request_data *request_data) +{ + struct request_data_handler_response response; + + // Check if the URL is empty + if (request_data->url == "") { + obs_log(LOG_INFO, "URL is empty"); + // Return an error response + response.error_message = "URL is empty"; + response.status_code = URL_SOURCE_REQUEST_STANDARD_ERROR_CODE; + return response; + } + + request_data->sequence_number++; + + if (request_data->url_or_file == "file") { + // This is a file request + // Read the file + std::ifstream file(request_data->url); + if (!file.is_open()) { + obs_log(LOG_INFO, "Failed to open file"); + // Return an error response + response.error_message = "Failed to open file"; response.status_code = URL_SOURCE_REQUEST_STANDARD_ERROR_CODE; return response; } + std::string responseBody((std::istreambuf_iterator(file)), + std::istreambuf_iterator()); + file.close(); + response.body = responseBody; response.status_code = URL_SOURCE_REQUEST_SUCCESS; + } else { + // This is a URL request + if (request_data->method == "WebSocket") { + // This is a websocket request + response = websocket_request_handler(request_data); + } else { + // This is an HTTP request + response = http_request_handler(request_data, response); + } + + if (response.status_code != URL_SOURCE_REQUEST_SUCCESS) { + return response; + } } // Parse the response diff --git a/src/request-data.h b/src/request-data.h index ac11339..8e5f3a5 100644 --- a/src/request-data.h +++ b/src/request-data.h @@ -83,6 +83,8 @@ inline int url_source_agg_target_string_to_enum(const std::string &agg_target) } } +struct WebSocketClientWrapper; // Forward declaration + struct url_source_request_data { std::string source_name; std::string url; @@ -117,6 +119,11 @@ struct url_source_request_data { std::string post_process_regex_replace; std::string kv_delimiter; + // WebSocket-specific fields + bool is_websocket; + WebSocketClientWrapper *ws_client_wrapper; + bool ws_connected; + // default constructor url_source_request_data() { @@ -143,6 +150,7 @@ struct url_source_request_data { post_process_regex_is_replace = false; post_process_regex_replace = std::string(""); kv_delimiter = std::string("="); + ws_client_wrapper = nullptr; } }; @@ -162,6 +170,13 @@ struct request_data_handler_response { std::string request_body; }; +namespace inja { +class Environment; +} + +void prepare_inja_env(inja::Environment *env, url_source_request_data *request_data, + request_data_handler_response &response, nlohmann::json &json); + struct request_data_handler_response request_data_handler(url_source_request_data *request_data); std::string serialize_request_data(url_source_request_data *request_data); diff --git a/src/ui/requestbuilder.ui b/src/ui/requestbuilder.ui index 2dcdef2..609c857 100644 --- a/src/ui/requestbuilder.ui +++ b/src/ui/requestbuilder.ui @@ -6,12 +6,12 @@ 0 0 - 622 - 1134 + 498 + 833 - + 0 0 @@ -31,16 +31,10 @@ 0 - + QLayout::SetDefaultConstraint - - QFormLayout::ExpandingFieldsGrow - - - 3 - 0 @@ -53,19 +47,29 @@ 0 - - + + + + + 0 + 0 + + - Source + URL - - - - - 0 - + + + + File + + + + + + 0 @@ -79,38 +83,24 @@ 0 - - - - 0 - 0 - - - - URL - - + - + - File + ... - - - - URL/File - - - - - - + + + + + 0 + 0 @@ -123,16 +113,6 @@ 0 - - - - - - - ... - - - @@ -141,14 +121,8 @@ - - - 0 - 0 - - - - URL Request Options + + true @@ -160,6 +134,9 @@ 0 + + 0 + @@ -197,6 +174,11 @@ POST + + + WebSocket + + @@ -385,6 +367,15 @@ + + + 0 + 0 + + + + 20.000000000000000 + false @@ -538,12 +529,6 @@ - - - 0 - 0 - - Output Parsing Options diff --git a/src/websocket-client.cpp b/src/websocket-client.cpp new file mode 100644 index 0000000..281480a --- /dev/null +++ b/src/websocket-client.cpp @@ -0,0 +1,193 @@ + +#pragma warning(disable : 4267) + +#define ASIO_STANDALONE +#define _WEBSOCKETPP_CPP11_TYPE_TRAITS_ +#define _WEBSOCKETPP_CPP11_RANDOM_DEVICE_ + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include "request-data.h" +#include "websocket-client.h" +#include "plugin-support.h" + +#include + +typedef websocketpp::client ws_client; + +struct WebSocketClientWrapper { + ws_client client; + websocketpp::connection_hdl connection; + std::unique_ptr asio_thread; + std::string last_received_message; + std::atomic is_connected{false}; + std::mutex mutex; + std::condition_variable cv; + + WebSocketClientWrapper() + { + try { + client.clear_access_channels(websocketpp::log::alevel::all); + client.clear_error_channels(websocketpp::log::elevel::all); + + client.init_asio(); + client.start_perpetual(); + + client.set_message_handler(std::bind(&WebSocketClientWrapper::on_message, + this, std::placeholders::_1, + std::placeholders::_2)); + + client.set_open_handler(std::bind(&WebSocketClientWrapper::on_open, this, + std::placeholders::_1)); + + client.set_close_handler(std::bind(&WebSocketClientWrapper::on_close, this, + std::placeholders::_1)); + + asio_thread = std::make_unique(&ws_client::run, &client); + } catch (const std::exception &e) { + // Log the error or handle it appropriately + throw std::runtime_error("Failed to initialize WebSocket client: " + + std::string(e.what())); + } + } + + ~WebSocketClientWrapper() + { + if (is_connected.load()) { + close(); + } + client.stop_perpetual(); + if (asio_thread && asio_thread->joinable()) { + asio_thread->join(); + } + } + + void on_message(websocketpp::connection_hdl, ws_client::message_ptr msg) + { + std::lock_guard lock(mutex); + last_received_message = msg->get_payload(); + cv.notify_one(); + } + + void on_open(websocketpp::connection_hdl hdl) + { + connection = hdl; + is_connected.store(true); + cv.notify_one(); + } + + void on_close(websocketpp::connection_hdl) + { + is_connected.store(false); + cv.notify_one(); + } + + bool connect(const std::string &uri) + { + websocketpp::lib::error_code ec; + ws_client::connection_ptr con = client.get_connection(uri, ec); + if (ec) { + return false; + } + + client.connect(con); + + std::unique_lock lock(mutex); + return cv.wait_for(lock, std::chrono::seconds(5), + [this] { return is_connected.load(); }); + } + + bool send(const std::string &message) + { + if (!is_connected.load()) { + return false; + } + + websocketpp::lib::error_code ec; + client.send(connection, message, websocketpp::frame::opcode::text, ec); + return !ec; + } + + bool receive(std::string &message, std::chrono::milliseconds timeout) + { + std::unique_lock lock(mutex); + if (cv.wait_for(lock, timeout, [this] { return !last_received_message.empty(); })) { + message = std::move(last_received_message); + last_received_message.clear(); + return true; + } + return false; + } + + void close() + { + if (is_connected.load()) { + websocketpp::lib::error_code ec; + client.close(connection, websocketpp::close::status::normal, + "Closing connection", ec); + if (ec) { + // Handle error + obs_log(LOG_WARNING, "Failed to close WebSocket connection: %s", + ec.message().c_str()); + } + } + } +}; + +struct request_data_handler_response +websocket_request_handler(url_source_request_data *request_data) +{ + request_data_handler_response response; + + try { + if (!request_data->ws_client_wrapper) { + request_data->ws_client_wrapper = new WebSocketClientWrapper(); + } + + if (!request_data->ws_connected) { + if (!request_data->ws_client_wrapper->connect(request_data->url)) { + throw std::runtime_error("Could not create WebSocket connection"); + } + request_data->ws_connected = true; + } + + nlohmann::json json; // json object or variables for inja + inja::Environment env; + prepare_inja_env(&env, request_data, response, json); + + if (response.status_code != URL_SOURCE_REQUEST_SUCCESS) { + return response; + } + + std::string message = env.render(request_data->body, json); + + if (!request_data->ws_client_wrapper->send(message)) { + throw std::runtime_error("Failed to send WebSocket message"); + } + + std::string received_message; + if (request_data->ws_client_wrapper->receive(received_message, + std::chrono::milliseconds(5000))) { + response.body = std::move(received_message); + response.status_code = URL_SOURCE_REQUEST_SUCCESS; + } else { + throw std::runtime_error("Timeout waiting for WebSocket response"); + } + } catch (const std::exception &e) { + response.status_code = URL_SOURCE_REQUEST_STANDARD_ERROR_CODE; + response.error_message = + "Error handling WebSocket request: " + std::string(e.what()); + } + + return response; +} diff --git a/src/websocket-client.h b/src/websocket-client.h new file mode 100644 index 0000000..6bb15b6 --- /dev/null +++ b/src/websocket-client.h @@ -0,0 +1,9 @@ +#ifndef WEBSOCKET_CLIENT_H +#define WEBSOCKET_CLIENT_H + +#include "request-data.h" + +struct request_data_handler_response +websocket_request_handler(url_source_request_data *request_data); + +#endif // WEBSOCKET_CLIENT_H