Skip to content

Commit

Permalink
add SocketContext
Browse files Browse the repository at this point in the history
  • Loading branch information
markaren committed Oct 20, 2024
1 parent 9a1c323 commit 9d58186
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 59 deletions.
32 changes: 32 additions & 0 deletions include/simple_socket/SocketContext.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

#ifndef SIMPLE_SOCKET_SOCKETCONTEXT_HPP
#define SIMPLE_SOCKET_SOCKETCONTEXT_HPP

#include "simple_socket/SimpleConnection.hpp"

#include <memory>
#include <string>

namespace simple_socket {

class SocketContext {
public:
SocketContext();

SocketContext(const SocketContext& other) = delete;
SocketContext& operator=(const SocketContext& other) = delete;
SocketContext(SocketContext&& other) = delete;
SocketContext& operator=(SocketContext&& other) = delete;

[[nodiscard]] virtual std::unique_ptr<SimpleConnection> connect(const std::string&) = 0;

virtual ~SocketContext();

private:
struct Impl;
std::unique_ptr<Impl> pimpl_;
};

}// namespace simple_socket

#endif//SIMPLE_SOCKET_SOCKETCONTEXT_HPP
19 changes: 4 additions & 15 deletions include/simple_socket/TCPSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,19 @@
#ifndef SIMPLE_SOCKET_TCPSOCKET_HPP
#define SIMPLE_SOCKET_TCPSOCKET_HPP

#include "simple_socket/SimpleConnection.hpp"
#include "simple_socket/SocketContext.hpp"

#include <cstdint>
#include <memory>
#include <string>

namespace simple_socket {

class TCPClientContext {
class TCPClientContext: public SocketContext {
public:
TCPClientContext();
[[nodiscard]] std::unique_ptr<SimpleConnection> connect(const std::string& ip, uint16_t port);

TCPClientContext(const TCPClientContext& other) = delete;
TCPClientContext& operator=(const TCPClientContext& other) = delete;
TCPClientContext(TCPClientContext&& other) = delete;
TCPClientContext& operator=(TCPClientContext&& other) = delete;

std::unique_ptr<SimpleConnection> connect(const std::string& ip, uint16_t port);

~TCPClientContext();

private:
struct Impl;
std::unique_ptr<Impl> pimpl_;
[[nodiscard]] std::unique_ptr<SimpleConnection> connect(const std::string& host) override;
};

class TCPServer {
Expand Down
14 changes: 3 additions & 11 deletions include/simple_socket/UnixDomainSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,16 @@
#ifndef SIMPLE_SOCKET_UNIXDOMAINSOCKET_HPP
#define SIMPLE_SOCKET_UNIXDOMAINSOCKET_HPP

#include "simple_socket/SimpleConnection.hpp"
#include "simple_socket/SocketContext.hpp"

#include <memory>
#include <string>

namespace simple_socket {

class UnixDomainClientContext {
class UnixDomainClientContext: public SocketContext {
public:
UnixDomainClientContext();

[[nodiscard]] std::unique_ptr<SimpleConnection> connect(const std::string& domain);

~UnixDomainClientContext();

private:
struct Impl;
std::unique_ptr<Impl> pimpl_;
[[nodiscard]] std::unique_ptr<SimpleConnection> connect(const std::string& domain) override;
};

class UnixDomainServer {
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set(publicHeaderDir "${PROJECT_SOURCE_DIR}/include")
set(publicHeaders

"simple_socket/SimpleConnection.hpp"
"simple_socket/SocketContext.hpp"
"simple_socket/TCPSocket.hpp"
"simple_socket/UDPSocket.hpp"
"simple_socket/UnixDomainSocket.hpp"
Expand All @@ -29,6 +30,7 @@ set(privateHeaders

set(sources
"simple_socket/util/port_query.cpp"
"simple_socket/SocketContext.cpp"
"simple_socket/TCPSocket.cpp"
"simple_socket/UDPSocket.cpp"
"simple_socket/UnixDomainSocket.cpp"
Expand Down
18 changes: 18 additions & 0 deletions src/simple_socket/SocketContext.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@

#include "simple_socket/SocketContext.hpp"

#include "WSASession.hpp"

using namespace simple_socket;

struct SocketContext::Impl {

#ifdef _WIN32
WSASession session{};
#endif
};

SocketContext::SocketContext()
: pimpl_(std::make_unique<Impl>()) {}

SocketContext::~SocketContext() = default;
32 changes: 22 additions & 10 deletions src/simple_socket/TCPSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,27 @@ namespace {
return sockfd;
}

}// namespace
std::pair<std::string, uint16_t> parseHostPort(const std::string& input) {
const size_t colonPos = input.find(':');
if (colonPos == std::string::npos) {
throw std::invalid_argument("Invalid input format. Expected 'host:port'.");
}

std::string host = input.substr(0, colonPos);
std::string portStr = input.substr(colonPos + 1);

struct TCPClientContext::Impl {
// Convert port string to uint16_t
uint16_t port;
try {
port = static_cast<uint16_t>(std::stoi(portStr));
} catch (const std::exception&) {
throw std::invalid_argument("Invalid port number.");
}

#ifdef _WIN32
WSASession session;
#endif
};
return std::make_pair(host, port);
}

}// namespace


struct TCPServer::Impl {
Expand Down Expand Up @@ -93,9 +105,6 @@ void TCPServer::close() {
TCPServer::~TCPServer() = default;


TCPClientContext::TCPClientContext()
: pimpl_(std::make_unique<Impl>()) {}

[[nodiscard]] std::unique_ptr<SimpleConnection> TCPClientContext::connect(const std::string& ip, uint16_t port) {

SOCKET sock = createSocket();
Expand All @@ -116,4 +125,7 @@ TCPClientContext::TCPClientContext()
return nullptr;
}

TCPClientContext::~TCPClientContext() = default;
std::unique_ptr<SimpleConnection> TCPClientContext::connect(const std::string& host) {
const auto [ip, port] = parseHostPort(host);
return connect(ip, port);
}
13 changes: 1 addition & 12 deletions src/simple_socket/UnixDomainSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace {

return sockfd;
}

void unlinkPath(const std::string& path) {
#ifdef _WIN32
DeleteFile(path.c_str());
Expand Down Expand Up @@ -91,16 +92,6 @@ std::unique_ptr<SimpleConnection> UnixDomainServer::accept() {

UnixDomainServer::~UnixDomainServer() = default;

struct UnixDomainClientContext::Impl {

#ifdef _WIN32
WSASession session;
#endif
};

UnixDomainClientContext::UnixDomainClientContext()
: pimpl_(std::make_unique<Impl>()) {}


std::unique_ptr<SimpleConnection> UnixDomainClientContext::connect(const std::string& domain) {

Expand All @@ -117,5 +108,3 @@ std::unique_ptr<SimpleConnection> UnixDomainClientContext::connect(const std::st

return nullptr;
}

UnixDomainClientContext::~UnixDomainClientContext() = default;
38 changes: 27 additions & 11 deletions tests/test_tcp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,48 @@ TEST_CASE("TCP read/write") {
TCPServer server(*port);

std::thread serverThread([&server] {
std::unique_ptr<SimpleConnection> conn;
REQUIRE_NOTHROW(conn = server.accept());
socketHandler(std::move(conn));
try {
while (true) {
std::unique_ptr<SimpleConnection> conn = server.accept();
socketHandler(std::move(conn));
}
} catch (std::exception&) {}
});

std::this_thread::sleep_for(std::chrono::milliseconds(100));

std::thread clientThread([port] {
TCPClientContext client;
const auto conn = client.connect("127.0.0.1", *port);
REQUIRE(conn);

auto msgHandler = [](SimpleConnection& conn) {
std::string message = generateMessage();
std::string expectedResponse = generateResponse(message);

conn->write(message);
conn.write(message);

std::vector<unsigned char> buffer(1024);
const auto bytesRead = conn->read(buffer);
const auto bytesRead = conn.read(buffer);
REQUIRE(bytesRead == expectedResponse.size());
std::string response(buffer.begin(), buffer.begin() + bytesRead);

CHECK(response == expectedResponse);
};

TCPClientContext clientCtx;
std::thread clientThread1([&clientCtx, port, msgHandler] {

const auto conn = clientCtx.connect("127.0.0.1", *port);
REQUIRE(conn);

msgHandler(*conn);
});

clientThread.join();
std::thread clientThread2([&clientCtx, port, msgHandler] {
const auto conn = clientCtx.connect("127.0.0.1:" + std::to_string(*port));
REQUIRE(conn);

msgHandler(*conn);
});

clientThread1.join();
clientThread2.join();

server.close();
serverThread.join();
Expand Down

0 comments on commit 9d58186

Please sign in to comment.