Skip to content

Commit

Permalink
working solution
Browse files Browse the repository at this point in the history
  • Loading branch information
markaren committed Sep 18, 2024
1 parent 6b2cd75 commit b9799e3
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 14 deletions.
17 changes: 13 additions & 4 deletions include/TCPSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@ namespace simple_socket {

~TCPSocket() override;

private:
friend class TCPClient;
friend class TCPServer;

protected:
struct Impl;
std::unique_ptr<Impl> pimpl_;
};
Expand All @@ -69,6 +66,18 @@ namespace simple_socket {
std::unique_ptr<TCPConnection> accept();
};

class UnixDomainClient: public TCPSocket {
public:
bool connect(const std::string& domain);
};

class UnixDomainServer: public TCPSocket {
public:
explicit UnixDomainServer(const std::string& domain, int backlog = 1);

std::unique_ptr<TCPConnection> accept();
};

}// namespace simple_socket

#endif// SIMPLE_SOCKET_TCPSOCKET_HPP
98 changes: 88 additions & 10 deletions src/TCPSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,31 @@ namespace simple_socket {

struct TCPSocket::Impl {

Impl(): sockfd_(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)) {
Impl(): sockfd_(INVALID_SOCKET) {

// if (sockfd_ == INVALID_SOCKET) {
// throwSocketError("Failed to create socket");
// }

// const int optval = 1;
// setsockopt(sockfd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&optval), sizeof(optval));
}

void setupSocket(int domain, int protocol) {
sockfd_ = socket(domain, SOCK_STREAM, protocol);
if (sockfd_ == INVALID_SOCKET) {
throwSocketError("Failed to create socket");
}

const int optval = 1;
setsockopt(sockfd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&optval), sizeof(optval));
if (domain == AF_INET) {
const int optval = 1;
setsockopt(sockfd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&optval), sizeof(optval));
}
}

bool connect(const std::string& ip, int port) const {
bool connect(const std::string& ip, int port) {

setupSocket(AF_INET, IPPROTO_TCP);

sockaddr_in serv_addr{};
serv_addr.sin_family = AF_INET;
serv_addr.sin_port = htons(port);
Expand All @@ -29,7 +43,20 @@ namespace simple_socket {
return ::connect(sockfd_, reinterpret_cast<sockaddr*>(&serv_addr), sizeof(serv_addr)) >= 0;
}

void bind(int port) const {
bool connect(const std::string& domain) {

setupSocket(AF_UNIX, 0);

sockaddr_un addr{};
addr.sun_family = AF_UNIX;
strncpy(addr.sun_path, domain.c_str(), sizeof(addr.sun_path) - 1);

return ::connect(sockfd_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) >= 0;
}

void bind(int port) {

setupSocket(AF_INET, IPPROTO_TCP);

sockaddr_in serv_addr{};
serv_addr.sin_family = AF_INET;
Expand All @@ -42,6 +69,23 @@ namespace simple_socket {
}
}

void bind(const std::string& domain) {

setupSocket(AF_UNIX, 0);

unlinkPath(domain);

sockaddr_un addr{};
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
strncpy(addr.sun_path, domain.c_str(), sizeof(addr.sun_path) - 1);

if (::bind(sockfd_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) < 0) {

throwSocketError("Bind failed");
}
}

void listen(int backlog) const {

if (::listen(sockfd_, backlog) < 0) {
Expand All @@ -50,7 +94,7 @@ namespace simple_socket {
}
}

[[nodiscard]] std::unique_ptr<TCPConnection> accept() const {
[[nodiscard]] std::unique_ptr<TCPConnection> acceptTCP() const {

sockaddr_in client_addr{};
socklen_t addrlen = sizeof(client_addr);
Expand All @@ -67,6 +111,21 @@ namespace simple_socket {
return conn;
}

[[nodiscard]] std::unique_ptr<TCPConnection> acceptUnix() const {

SOCKET new_sock = ::accept(sockfd_, nullptr, nullptr);

if (new_sock == INVALID_SOCKET) {

throwSocketError("Accept failed");
}

auto conn = std::make_unique<TCPSocket>();
conn->pimpl_->assign(new_sock);

return conn;
}

bool read(unsigned char* buffer, size_t size, size_t* bytesRead) const {

#ifdef _WIN32
Expand Down Expand Up @@ -186,19 +245,38 @@ namespace simple_socket {

TCPSocket::~TCPSocket() = default;


bool TCPClient::connect(const std::string& ip, int port) {

return pimpl_->connect(ip, port);
}


TCPServer::TCPServer(int port, int backlog) {

pimpl_->bind(port);
pimpl_->listen(backlog);
}

std::unique_ptr<TCPConnection> TCPServer::accept() {

return pimpl_->accept();
return pimpl_->acceptTCP();
}

TCPServer::TCPServer(int port, int backlog) {

pimpl_->bind(port);
bool UnixDomainClient::connect(const std::string& domain) {

return pimpl_->connect(domain);
}

std::unique_ptr<TCPConnection> UnixDomainServer::accept() {

return pimpl_->acceptUnix();
}

UnixDomainServer::UnixDomainServer(const std::string& domain, int backlog) {

pimpl_->bind(domain);
pimpl_->listen(backlog);
}

Expand Down
10 changes: 10 additions & 0 deletions src/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
#ifdef _WIN32
#include "WSASession.hpp"
#include <WinSock2.h>
#include <afunix.h>
#include <ws2tcpip.h>
#else
#include <arpa/inet.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
using SOCKET = int;
#define INVALID_SOCKET (SOCKET)(~0)
Expand Down Expand Up @@ -43,6 +45,14 @@ namespace simple_socket {
}
}

inline void unlinkPath(const std::string& path) {
#ifdef _WIN32
DeleteFile(path.c_str());
#else
unlink(path.c_str());
#endif
}

}// namespace simple_socket

#endif//SIMPLE_SOCKET_COMMON_HPP
5 changes: 5 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@ add_executable(test_udp test_udp.cpp)
add_test(NAME test_udp COMMAND test_udp)
target_link_libraries(test_udp PRIVATE simple_socket Catch2::Catch2WithMain)

add_executable(test_un test_un.cpp)
add_test(NAME test_un COMMAND test_un)
target_link_libraries(test_un PRIVATE simple_socket Catch2::Catch2WithMain)

if (UNIX)
target_link_libraries(test_tcp PRIVATE pthread)
target_link_libraries(test_udp PRIVATE pthread)
target_link_libraries(test_un PRIVATE pthread)

endif ()

Expand Down
82 changes: 82 additions & 0 deletions tests/test_un.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@

#include "TCPSocket.hpp"
#include "port_query.hpp"

#include <thread>
#include <vector>

#include <catch2/catch_test_macros.hpp>

using namespace simple_socket;

namespace {

std::string generateMessage() {

return "Per";
}

std::string generateResponse(const std::string& msg) {

return "Hello " + msg + "!";
}

void socketHandler(std::unique_ptr<TCPConnection> conn) {

std::vector<unsigned char> buffer(1024);
const auto bytesRead = conn->read(buffer);

std::string expectedMessage = generateMessage();
REQUIRE(bytesRead == expectedMessage.size());

std::string msg(buffer.begin(), buffer.begin() + bytesRead);
REQUIRE(msg == expectedMessage);

REQUIRE(conn->write(generateResponse(msg)));
}

}// namespace

TEST_CASE("UNIX Domain Socket read/write") {

#ifdef _WIN32
const std::string domain{"afunix_socket"};
#else
const std::string domain{"/tmp/unix_socket"};
#endif

UnixDomainServer server(domain);
UnixDomainClient client;

std::thread serverThread([&server] {
std::unique_ptr<TCPConnection> conn;
REQUIRE_NOTHROW(conn = server.accept());
socketHandler(std::move(conn));
});

std::this_thread::sleep_for(std::chrono::milliseconds(100));

std::thread clientThread([&client, domain] {
REQUIRE(client.connect(domain));

std::string message = generateMessage();
std::string expectedResponse = generateResponse(message);

client.write(message);

std::vector<unsigned char> buffer(1024);
const auto bytesRead = client.read(buffer);
REQUIRE(bytesRead == expectedResponse.size());
std::string response(buffer.begin(), buffer.begin() + static_cast<int>(bytesRead));

CHECK(response == expectedResponse);
});

clientThread.join();
client.close();

REQUIRE(!server.write(""));

server.close();
serverThread.join();
}

0 comments on commit b9799e3

Please sign in to comment.