Skip to content

Commit

Permalink
make WSASession an implementation detail
Browse files Browse the repository at this point in the history
  • Loading branch information
markaren committed Aug 7, 2024
1 parent 382b45e commit ffd7f5a
Show file tree
Hide file tree
Showing 13 changed files with 88 additions and 93 deletions.
17 changes: 0 additions & 17 deletions include/WSASession.hpp

This file was deleted.

7 changes: 6 additions & 1 deletion src/AvailablePortQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
#include <algorithm>

int getAvailablePort(int startPort, int endPort, const std::vector<int>& excludePorts) {

#ifdef WIN32
WSASession session;
#endif

SOCKET sockfd = socket(AF_INET, SOCK_STREAM, 0);
if (sockfd == SOCKET_ERROR) {
return -1;
Expand Down Expand Up @@ -35,4 +40,4 @@ int getAvailablePort(int startPort, int endPort, const std::vector<int>& exclude
closeSocket(sockfd);

return -1; // No available port found
}
}
9 changes: 8 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@

add_library(simple_socket TCPSocket.cpp UDPSocket.cpp WSASession.cpp AvailablePortQuery.cpp WebSocket.cpp)
set(sources
AvailablePortQuery.cpp
TCPSocket.cpp
UDPSocket.cpp
WebSocket.cpp
)

add_library(simple_socket ${sources})
target_compile_features(simple_socket PUBLIC "cxx_std_17")

if (WIN32)
Expand Down
3 changes: 2 additions & 1 deletion src/SocketIncludes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@


#ifdef _WIN32
#include <winsock2.h>
#include "WSASession.hpp"
#include <WinSock2.h>
#include <ws2tcpip.h>
#else
#include <arpa/inet.h>
Expand Down
38 changes: 21 additions & 17 deletions src/TCPSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

struct TCPSocket::Impl {

Impl(): sockfd(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)) {
Impl(): sockfd_(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)) {

if (sockfd == INVALID_SOCKET) {
if (sockfd_ == INVALID_SOCKET) {
throwSocketError("Failed to create socket");
}

const int optval = 1;
setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&optval), sizeof(optval));
setsockopt(sockfd_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&optval), sizeof(optval));
}

bool connect(const std::string& ip, int port) const {
Expand All @@ -25,7 +25,7 @@ struct TCPSocket::Impl {
return false;
}

return ::connect(sockfd, reinterpret_cast<sockaddr*>(&serv_addr), sizeof(serv_addr)) >= 0;
return ::connect(sockfd_, reinterpret_cast<sockaddr*>(&serv_addr), sizeof(serv_addr)) >= 0;
}

void bind(int port) const {
Expand All @@ -35,15 +35,15 @@ struct TCPSocket::Impl {
serv_addr.sin_addr.s_addr = INADDR_ANY;
serv_addr.sin_port = htons(port);

if (::bind(sockfd, reinterpret_cast<sockaddr*>(&serv_addr), sizeof(serv_addr)) < 0) {
if (::bind(sockfd_, reinterpret_cast<sockaddr*>(&serv_addr), sizeof(serv_addr)) < 0) {

throwSocketError("Bind failed");
}
}

void listen(int backlog) const {

if (::listen(sockfd, backlog) < 0) {
if (::listen(sockfd_, backlog) < 0) {

throwSocketError("Listen failed");
}
Expand All @@ -53,7 +53,7 @@ struct TCPSocket::Impl {

sockaddr_in client_addr{};
socklen_t addrlen = sizeof(client_addr);
SOCKET new_sock = ::accept(sockfd, reinterpret_cast<sockaddr*>(&client_addr), &addrlen);
SOCKET new_sock = ::accept(sockfd_, reinterpret_cast<sockaddr*>(&client_addr), &addrlen);

if (new_sock == INVALID_SOCKET) {

Expand All @@ -69,9 +69,9 @@ struct TCPSocket::Impl {
bool read(unsigned char* buffer, size_t size, size_t* bytesRead) const {

#ifdef _WIN32
auto read = recv(sockfd, reinterpret_cast<char*>(buffer), static_cast<int>(size), 0);
auto read = recv(sockfd_, reinterpret_cast<char*>(buffer), static_cast<int>(size), 0);
#else
const auto read = ::read(sockfd, buffer, size);
const auto read = ::read(sockfd_, buffer, size);
#endif
if (bytesRead) *bytesRead = read;

Expand All @@ -84,9 +84,9 @@ struct TCPSocket::Impl {
while (totalBytesReceived < size) {
const auto remainingBytes = static_cast<int>(size) - totalBytesReceived;
#ifdef _WIN32
auto read = recv(sockfd, reinterpret_cast<char*>(buffer + totalBytesReceived), remainingBytes, 0);
auto read = recv(sockfd_, reinterpret_cast<char*>(buffer + totalBytesReceived), remainingBytes, 0);
#else
auto read = ::read(sockfd, buffer + totalBytesReceived, remainingBytes);
auto read = ::read(sockfd_, buffer + totalBytesReceived, remainingBytes);
#endif
if (read == SOCKET_ERROR || read == 0) {

Expand All @@ -102,23 +102,23 @@ struct TCPSocket::Impl {

const auto size = static_cast<int>(buffer.size());

return send(sockfd, buffer.data(), size, 0) != SOCKET_ERROR;
return send(sockfd_, buffer.data(), size, 0) != SOCKET_ERROR;
}

bool write(const std::vector<unsigned char>& buffer) const {

const auto size = static_cast<int>(buffer.size());

#ifdef _WIN32
return send(sockfd, reinterpret_cast<const char*>(buffer.data()), size, 0) != SOCKET_ERROR;
return send(sockfd_, reinterpret_cast<const char*>(buffer.data()), size, 0) != SOCKET_ERROR;
#else
return ::write(sockfd, buffer.data(), buffer.size()) != SOCKET_ERROR;
return ::write(sockfd_, buffer.data(), buffer.size()) != SOCKET_ERROR;
#endif
}

void close() const {

closeSocket(sockfd);
closeSocket(sockfd_);
}

~Impl() {
Expand All @@ -127,12 +127,16 @@ struct TCPSocket::Impl {
}

private:
SOCKET sockfd;
#ifdef WIN32
WSASession session_;
#endif

SOCKET sockfd_;

void assign(SOCKET new_sock) {
close();

sockfd = new_sock;
sockfd_ = new_sock;
}
};

Expand Down
21 changes: 12 additions & 9 deletions src/UDPSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
struct UDPSocket::Impl {

explicit Impl(int localPort)
: sockfd(socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)) {
: sockfd_(socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)) {

if (sockfd == INVALID_SOCKET) {
if (sockfd_ == INVALID_SOCKET) {

throwSocketError("Failed to create socket");
}
Expand All @@ -19,7 +19,7 @@ struct UDPSocket::Impl {
local.sin_addr.s_addr = INADDR_ANY;
local.sin_port = htons(localPort);

if (::bind(sockfd, (sockaddr*) &local, sizeof(local)) == SOCKET_ERROR) {
if (::bind(sockfd_, (sockaddr*) &local, sizeof(local)) == SOCKET_ERROR) {

throwSocketError("Bind failed");
}
Expand All @@ -34,7 +34,7 @@ struct UDPSocket::Impl {
return false;
}

return sendto(sockfd, data.c_str(), data.size(), 0, reinterpret_cast<sockaddr*>(&to), sizeof(to)) != SOCKET_ERROR;
return sendto(sockfd_, data.c_str(), data.size(), 0, reinterpret_cast<sockaddr*>(&to), sizeof(to)) != SOCKET_ERROR;
}

bool sendTo(const std::string& address, uint16_t port, const std::vector<unsigned char>& data) const {
Expand All @@ -46,7 +46,7 @@ struct UDPSocket::Impl {
return false;
}

return sendto(sockfd, reinterpret_cast<const char*>(data.data()), data.size(), 0, reinterpret_cast<sockaddr*>(&to), sizeof(to)) != SOCKET_ERROR;
return sendto(sockfd_, reinterpret_cast<const char*>(data.data()), data.size(), 0, reinterpret_cast<sockaddr*>(&to), sizeof(to)) != SOCKET_ERROR;
}

int recvFrom(const std::string& address, uint16_t port, std::vector<unsigned char>& buffer) const {
Expand All @@ -60,7 +60,7 @@ struct UDPSocket::Impl {
}
socklen_t fromLength = sizeof(from);

const auto receive = recvfrom(sockfd, reinterpret_cast<char*>(buffer.data()), buffer.size(), 0, reinterpret_cast<sockaddr*>(&from), &fromLength);
const auto receive = recvfrom(sockfd_, reinterpret_cast<char*>(buffer.data()), buffer.size(), 0, reinterpret_cast<sockaddr*>(&from), &fromLength);
if (receive == SOCKET_ERROR) {
return -1;
}
Expand All @@ -81,7 +81,7 @@ struct UDPSocket::Impl {

static std::vector<unsigned char> buffer(MAX_UDP_PACKET_SIZE);

const auto receive = recvfrom(sockfd, reinterpret_cast<char*>(buffer.data()), buffer.size(), 0, reinterpret_cast<sockaddr*>(&from), &fromLength);
const auto receive = recvfrom(sockfd_, reinterpret_cast<char*>(buffer.data()), buffer.size(), 0, reinterpret_cast<sockaddr*>(&from), &fromLength);
if (receive == SOCKET_ERROR) {

return "";
Expand All @@ -92,7 +92,7 @@ struct UDPSocket::Impl {

void close() const {

closeSocket(sockfd);
closeSocket(sockfd_);
}

~Impl() {
Expand All @@ -101,7 +101,10 @@ struct UDPSocket::Impl {
}

private:
SOCKET sockfd;
#ifdef WIN32
WSASession session_;
#endif
SOCKET sockfd_;
};


Expand Down
29 changes: 0 additions & 29 deletions src/WSASession.cpp

This file was deleted.

39 changes: 39 additions & 0 deletions src/WSASession.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

#ifndef SIMPLE_SOCKET_WSASESSION_HPP
#define SIMPLE_SOCKET_WSASESSION_HPP

#include <WinSock2.h>
#include <memory>
#include <mutex>
#include <system_error>

namespace {

int ref_count_ = 0;
std::mutex mutex_;

}// namespace

class WSASession {
public:
WSASession() {
std::lock_guard<std::mutex> lock(mutex_);
if (ref_count_ == 0) {
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
throw std::system_error(WSAGetLastError(), std::system_category(), "Failed to initialize winsock");
}
}
++ref_count_;
}

~WSASession() {
std::lock_guard<std::mutex> lock(mutex_);
if (--ref_count_ == 0) {
WSACleanup();
}
}
};


#endif//SIMPLE_SOCKET_WSASESSION_HPP
3 changes: 0 additions & 3 deletions tests/integration/run_tcp_client.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@

#include "TCPSocket.hpp"
#include "WSASession.hpp"

#include <iostream>
#include <vector>

int main() {

WSASession session;

TCPClient client;
if (client.connect("127.0.0.1", 8080)) {

Expand Down
3 changes: 0 additions & 3 deletions tests/integration/run_tcp_server.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "TCPSocket.hpp"
#include "WSASession.hpp"

#include <atomic>
#include <iostream>
Expand All @@ -18,8 +17,6 @@ void socketHandler(std::unique_ptr<TCPConnection> conn) {

int main() {

WSASession session;

TCPServer server(8080);

std::atomic_bool stop = false;
Expand Down
3 changes: 0 additions & 3 deletions tests/integration/run_ws.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@

#include "WebSocket.hpp"
#include "WSASession.hpp"

#include <iostream>

int main() {

WSASession session;

WebSocket ws(8081);
ws.onOpen = [](auto) {
std::cout << "onOpen" << std::endl;
Expand Down
Loading

0 comments on commit ffd7f5a

Please sign in to comment.