Skip to content

Commit

Permalink
Add WebSocket client (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
markaren authored Sep 28, 2024
1 parent 850c229 commit c47bf51
Show file tree
Hide file tree
Showing 9 changed files with 563 additions and 150 deletions.
20 changes: 20 additions & 0 deletions include/simple_socket/WebSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,26 @@ namespace simple_socket {
std::unique_ptr<Impl> pimpl_;
};

class WebSocketClient {

public:
std::function<void(WebSocketConnection*)> onOpen;
std::function<void(WebSocketConnection*)> onClose;
std::function<void(WebSocketConnection*, const std::string&)> onMessage;

WebSocketClient();

void connect(const std::string& host, uint16_t port);

void close();

~WebSocketClient();

private:
struct Impl;
std::unique_ptr<Impl> pimpl_;
};

}// namespace simple_socket

#endif//SIMPLE_SOCKET_WEBSOCKET_HPP
3 changes: 3 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ set(privateHeaders
"simple_socket/Socket.hpp"
"simple_socket/util/uuid.hpp"

"simple_socket/ws/WebSocketConnection.hpp"
"simple_socket/ws/WebSocketHandshake.hpp"
"simple_socket/ws/WebSocketHandshakeKeyGen.hpp"
)

set(sources
Expand All @@ -32,6 +34,7 @@ set(sources
"simple_socket/UnixDomainSocket.cpp"

"simple_socket/ws/WebSocket.cpp"
"simple_socket/ws/WebSocketClient.cpp"

"simple_socket/modbus/HoldingRegister.cpp"
"simple_socket/modbus/ModbusClient.cpp"
Expand Down
164 changes: 14 additions & 150 deletions src/simple_socket/ws/WebSocket.cpp
Original file line number Diff line number Diff line change
@@ -1,60 +1,29 @@

#include "simple_socket/WebSocket.hpp"

#include "simple_socket/TCPSocket.hpp"
#include "simple_socket/socket_common.hpp"

#include "simple_socket/util/uuid.hpp"

#include "simple_socket/ws/WebSocketConnection.hpp"
#include "simple_socket/ws/WebSocketHandshake.hpp"

#include <algorithm>
#include <atomic>
#include <iostream>
#include <sstream>
#include <thread>
#include <utility>
#include <vector>

using namespace simple_socket;

namespace {

std::vector<uint8_t> createFrame(const std::string& message) {
std::vector<uint8_t> frame;
frame.push_back(0x81);// FIN, text frame

if (message.size() <= 125) {
frame.push_back(static_cast<uint8_t>(message.size()));
} else if (message.size() <= 65535) {
frame.push_back(126);
frame.push_back((message.size() >> 8) & 0xFF);
frame.push_back(message.size() & 0xFF);
} else {
frame.push_back(127);
for (int i = 7; i >= 0; --i) {
frame.push_back((message.size() >> (i * 8)) & 0xFF);
}
}

frame.insert(frame.end(), message.begin(), message.end());
return frame;
}
}// namespace

class WebSocketConnectionImpl: public WebSocketConnection {

public:
WebSocketConnectionImpl(WebSocket* socket, std::unique_ptr<SimpleConnection> conn)
: socket(socket), conn(std::move(conn)) {

handshake();

thread = std::thread([this] {
listen();
});
}

void handshake() const {
void handshake(SimpleConnection& conn) {
std::vector<unsigned char> buffer(1024);
const auto bytesReceived = conn->read(buffer);
const auto bytesReceived = conn.read(buffer);
if (bytesReceived == -1) {
throwSocketError("Failed to read handshake request from client.");
}
Expand All @@ -78,116 +47,11 @@ class WebSocketConnectionImpl: public WebSocketConnection {
<< "Sec-WebSocket-Accept: " << secWebSocketAccept << "\r\n\r\n";

const std::string responseStr = response.str();
if (!conn->write(responseStr)) {
if (!conn.write(responseStr)) {
throwSocketError("Failed to send handshake response");
}
}

void send(const std::string& message) override {
const auto frame = createFrame(message);
conn->write(frame);
}

void listen() {
std::vector<unsigned char> buffer(1024);
while (!closed) {

const auto recv = conn->read(buffer);
if (recv == -1) {
break;
}

std::vector<uint8_t> frame{buffer.begin(), buffer.begin() + recv};

if (frame.size() < 2) return;

uint8_t opcode = frame[0] & 0x0F;
// bool isFinal = (frame[0] & 0x80) != 0;
bool isMasked = (frame[1] & 0x80) != 0;
uint64_t payloadLen = frame[1] & 0x7F;

size_t pos = 2;
if (payloadLen == 126) {
payloadLen = (frame[2] << 8) | frame[3];
pos += 2;
} else if (payloadLen == 127) {
payloadLen = 0;
for (int i = 0; i < 8; ++i) {
payloadLen = (payloadLen << 8) | frame[2 + i];
}
pos += 8;
}

std::vector<uint8_t> mask(4);
if (isMasked) {
for (int i = 0; i < 4; ++i) {
mask[i] = frame[pos++];
}
}

std::vector payload(frame.begin() + pos, frame.begin() + pos + payloadLen);
if (isMasked) {
for (size_t i = 0; i < payload.size(); ++i) {
payload[i] ^= mask[i % 4];
}
}

switch (opcode) {
case 0x1:// Text frame
{
std::string message(payload.begin(), payload.end());
socket->onMessage(this, message);
} break;
case 0x8:// Close frame
close(false);

break;
case 0x9:// Ping frame
std::cout << "Received ping frame" << std::endl;
{
std::vector<uint8_t> pongFrame = {0x8A};// FIN, Pong frame
pongFrame.push_back(static_cast<uint8_t>(payload.size()));
pongFrame.insert(pongFrame.end(), payload.begin(), payload.end());
conn->write(pongFrame);
}
break;
case 0xA:// Pong frame
std::cout << "Received pong frame" << std::endl;
break;
default:
std::cerr << "Unsupported opcode: " << static_cast<int>(opcode) << std::endl;
break;
}
}
}

void close(bool self) {
if (!closed) {
closed = true;

socket->onClose(this);

if (!self) {
std::vector<uint8_t> closeFrame = {0x88};// FIN, Close frame
closeFrame.push_back(0);
conn->write(closeFrame);
}
conn->close();
}
}

~WebSocketConnectionImpl() override {
close(true);
if (thread.joinable()) {
thread.join();
}
}

std::atomic_bool closed{false};
WebSocket* socket;
std::unique_ptr<SimpleConnection> conn;
std::thread thread;
};
}// namespace

struct WebSocket::Impl {

Expand All @@ -201,17 +65,17 @@ struct WebSocket::Impl {
while (!stop_) {

try {
auto ws = std::make_unique<WebSocketConnectionImpl>(scope, socket.accept());
scope->onOpen(ws.get());
auto ws = std::make_unique<WebSocketConnectionImpl>(WebSocketCallbaks{scope->onOpen, scope->onClose, scope->onMessage}, socket.accept());
ws->run(handshake);
connections.emplace_back(std::move(ws));
} catch (std::exception& ex) {
} catch (std::exception&) {
// std::cerr << ex.what() << std::endl;
}

//cleanup connections
for (auto it = connections.begin(); it != connections.end();) {

if ((*it)->closed) {
if ((*it)->closed()) {
it = connections.erase(it);
} else {
++it;
Expand All @@ -222,8 +86,8 @@ struct WebSocket::Impl {

void start() {
thread = std::thread([this] {
run();
});
run();
});
}

void stop() {
Expand Down
94 changes: 94 additions & 0 deletions src/simple_socket/ws/WebSocketClient.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@

#include "simple_socket/WebSocket.hpp"

#include "simple_socket/TCPSocket.hpp"
#include "simple_socket/socket_common.hpp"

#include "simple_socket/ws/WebSocketConnection.hpp"
#include "simple_socket/ws/WebSocketHandshakeKeyGen.hpp"

#include <sstream>

using namespace simple_socket;

namespace {
void performHandshake(SimpleConnection& conn, const std::string& host, uint16_t port) {
const std::string key;
char output[28] = {};// 28 chars for base64-encoded output
WebSocketHandshakeKeyGen::generate(key, output);

std::ostringstream request;
request << "GET "
<< "/ws"
<< " HTTP/1.1\r\n"
<< "Host: " << host << ":" << port << "\r\n"
<< "Upgrade: websocket\r\n"
<< "Connection: Upgrade\r\n"
<< "Sec-WebSocket-Key: " << key << "\r\n"
<< "Sec-WebSocket-Version: 13\r\n\r\n";

const std::string requestStr = request.str();
if (!conn.write(requestStr)) {
throwSocketError("Failed to send handshake request");
}

std::vector<uint8_t> buffer(1024);
const auto bytesReceived = conn.read(buffer);
if (bytesReceived == -1) {
throwSocketError("Failed to read handshake response from server.");
}

const std::string response(buffer.begin(), buffer.begin() + bytesReceived);
if (response.find(" 101 ") == std::string::npos) {
throwSocketError("Handshake failed with the server.");
}
}
}// namespace

struct WebSocketClient::Impl {

std::unique_ptr<WebSocketConnectionImpl> conn;

explicit Impl(WebSocketClient* scope)
: scope_(scope) {}

void connect(const std::string& host, uint16_t port) {

auto c = ctx_.connect(host, port);

conn = std::make_unique<WebSocketConnectionImpl>(WebSocketCallbaks{scope_->onOpen, scope_->onClose, scope_->onMessage}, std::move(c));
conn->run([host, port](SimpleConnection& conn) {
performHandshake(conn, host, port);
});
}

void send(const std::string& message) {
conn->send(message);
}

void close() {
conn->close(true);
}

~Impl() {
close();
}

private:
TCPClientContext ctx_;
WebSocketClient* scope_;
};

WebSocketClient::WebSocketClient()
: pimpl_(std::make_unique<Impl>(this)) {}


void WebSocketClient::connect(const std::string& host, uint16_t port) {
pimpl_->connect(host, port);
}

void WebSocketClient::close() {
pimpl_->close();
}

WebSocketClient::~WebSocketClient() = default;
Loading

0 comments on commit c47bf51

Please sign in to comment.