// clockkit got this from
// https://raw.githubusercontent.com/Ybalrid/kissnet/master/kissnet.hpp
// and added a few corrections.

/*
 * MIT License
 *
 * Copyright (c) 2018-2020 Arthur Brainville (Ybalrid) and with the help of
 * Comunity Contributors!
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 * INTRODUCTION
 * ============
 *
 * Kissnet is a simple C++17 layer around the raw OS provided socket API to be
 * used on IP networks with the TCP and UDP protocols.
 *
 * Kissnet is not a networking framework, and it will not process your data or
 * assist you in any way. Kissnet's only goal is to provide a simple API to send
 * and receive bytes, without having to play around with a bunch of structures,
 * file descriptors, handles and pointers given to a C-style API. The other goal
 * of kissnet is to provide an API that works in a cross-platform setting.
 *
 * Kissnet will automatically manage the eventual startup/shutdown of the
 * library needed to perform socket operations on a particular platform (e.g.
 * the Windows Socket API on MS-Windows).
 *
 * Kissnet leverages (and expects you to use) multiple features from C++17,
 * including: std::byte, if constexpr, structured bindings, if-initializer and
 * template parameter type deduction.
 *
 * The library is structured across 4 exposed data types:
 *
 * - buffer<N> : a static array of std::byte implemented via std::array.
 * This is what you should use to hold raw data you are getting from a socket,
 * before extracting what you need from the bytes
 * - port_t : a 16 bit unsigned number. Represents a network port number
 * - endpoint : a structure that represents a location where you need to connect
 * to. Contains a hostname (as std::string) and a port number (as port_t)
 * - socket<protocol> : a templated class that represents an ipv4 or ipv6
 * socket. Protocol is either TCP or UDP
 *
 * Kissnet does error handling in 2 ways:
 *
 * 1:
 * When an operation can generate an error that the user should handle by hand
 * anyway, a tuple containing the expected type returned, and an object that
 * represents the status of what happened is returned.
 *
 * For example, socket send/receive operations can discover that the connection
 * was closed, or was shut down properly. It could also be the fact that a
 * socket was configured "non blocking" and would have blocked in this
 * situation. On both occasions, these methods will return the fact that 0 bytes
 * came across as the transaction size, and the status will indicate either an
 * error (socket no longer valid), or an actual status message (connection
 * closed, socket would have blocked)
 *
 * These status objects will behave like a const bool that equals "false" when
 * an error occurred, and "true" when it's just a status notification
 *
 * 2:
 * Fatal errors are by default handled by throwing a runtime_error exception.
 * But, for many reasons, you may want to not use exceptions entirely.
 *
 * kissnet gives you some facilities to get fatal error information back, and
 * to choose how to handle it. Kissnet gives you a few levers you can use:
 *
 * - You can deactivate the exception support by #defining KISSNET_NO_EXCEP
 * before #including kissnet.hpp. Instead, kissnet will use a function based
 * error handler
 * - By default, the error handler prints the error message to stderr, and
 * aborts the program
 * - kissnet::error::callback is a function pointer that gets a string, and a
 * context pointer. The string is the error message, and the context pointer
 * whatever you gave kissnet for the occasion. This is a global pointer that
 * you can set as you want. This will override the "print to stderr" behavior
 * at fatal error time.
 * - kissnet::error::ctx is a void*, this will be passed to your error handler
 * as a "context" pointer. If you need your handler to write to a log,
 * or to turn on the HTCPCP enabled teapot on John's desk, you can.
 * - kissnet::error::abortOnFatalError is a boolean that will control the call
 * to abort(). This is independent of whether or not you set an error
 * callback. Please note that any object involved with the operation that
 * triggered the fatal error is probably in an invalid state, and probably
 * deserves to be thrown away.
 */

#ifndef KISS_NET
#define KISS_NET

///Define this to not use exceptions
#ifndef KISSNET_NO_EXCEP
#define kissnet_fatal_error(STR) throw std::runtime_error(STR)
#else
// No trailing ';' here: every use site supplies its own, and an extra one
// would break constructs like `if (...) kissnet_fatal_error(...); else ...`.
#define kissnet_fatal_error(STR) kissnet::error::handle(STR)
#endif

#include <array>
#include <cassert>
#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>

///Exception-less error handling infrastructure.
///Declared before any platform specific code, because kissnet_fatal_error() may
///expand to kissnet::error::handle() there (e.g. in the WSA constructor).
namespace kissnet
{
	namespace error
	{
		///Optional user handler for fatal errors; receives the message and error::ctx.
		///`inline` (C++17) so every translation unit shares the same variable.
		inline void (*callback)(const std::string&, void* ctx) = nullptr;

		///Opaque context pointer forwarded to the callback
		inline void* ctx = nullptr;

		///When true (the default) handle() calls abort() after reporting the error
		inline bool abortOnFatalError = true;

		///Report a fatal error through the callback (or stderr), then abort if requested
		inline void handle(const std::string& str)
		{
			//if the error::callback function has been provided, call that
			if (callback)
			{
				callback(str, ctx);
			}
			//Print error into the standard error output
			else
			{
				fputs(str.c_str(), stderr);
			}

			//If the error abort hasn't been deactivated
			if (abortOnFatalError)
			{
				abort();
			}
		}
	}
}

#ifdef _WIN32

#ifndef _WINSOCK_DEPRECATED_NO_WARNINGS
#define _WINSOCK_DEPRECATED_NO_WARNINGS
#endif
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#ifndef NOMINMAX
#define NOMINMAX
#endif //endif nominmax

#include <winsock2.h>
#include <ws2tcpip.h>
#include <windows.h>

using ioctl_setting = u_long;
using buffsize_t = int;

#ifndef AI_ADDRCONFIG
#define AI_ADDRCONFIG 0x00000400
#endif

#ifndef SHUT_RDWR
#define SHUT_RDWR SD_BOTH
#endif

// taken from: https://github.com/rxi/dyad/blob/915ae4939529b9aaaf6ebfd2f65c6cff45fc0eac/src/dyad.c#L58
inline const char* inet_ntop(int af, const void* src, char* dst, socklen_t size)
{
	union
	{
		struct sockaddr sa;
		struct sockaddr_in sai;
		struct sockaddr_in6 sai6;
	} addr;
	int res;
	std::memset(&addr, 0, sizeof(addr));
	addr.sa.sa_family = af;
	if (af == AF_INET6)
	{
		std::memcpy(&addr.sai6.sin6_addr, src, sizeof(addr.sai6.sin6_addr));
	}
	else
	{
		std::memcpy(&addr.sai.sin_addr, src, sizeof(addr.sai.sin_addr));
	}
	res = WSAAddressToStringA(&addr.sa, sizeof(addr), 0, dst, reinterpret_cast<LPDWORD>(&size));
	if (res != 0) return NULL;
	return dst;
}

//Handle WinSock2/Windows Socket API initialization and cleanup
#pragma comment(lib, "Ws2_32.lib")

namespace kissnet
{
	namespace win32_specific
	{
		///Forward declare the object that will permit to manage the WSAStartup/Cleanup automatically
		struct WSA;

		///Enclose the global pointer in this namespace. Only use this inside a shared_ptr
		namespace internal_state
		{
			static WSA* global_WSA = nullptr;
		}

		///WSA object. Only to be constructed with std::make_shared()
		struct WSA : std::enable_shared_from_this<WSA>
		{
			//For safety, only initialize Windows Socket API once, and delete it once
			///Prevent copy construct
			WSA(const WSA&) = delete;
			///Prevent copy assignment
			WSA& operator=(const WSA&) = delete;
			///Prevent moving
			WSA(WSA&&) = delete;
			///Prevent move assignment
			WSA& operator=(WSA&&) = delete;

			///data storage
			WSADATA wsa_data;

			///Startup
			WSA() :
			 wsa_data {}
			{
				if (const auto status = WSAStartup(MAKEWORD(2, 2), &wsa_data); status != 0)
				{
					std::string error_message;
					switch (status) // https://docs.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-wsastartup#return-value
					{
						default:
							error_message = "Unknown error happened.";
							break;
						case WSASYSNOTREADY:
							error_message = "The underlying network subsystem is not ready for network communication.";
							break;
						case WSAVERNOTSUPPORTED: //unlikely, we specify 2.2!
							error_message = " The version of Windows Sockets support requested "
											"(2.2)" //we know here the version was 2.2, add that to the error message copied from MSDN
											" is not provided by this particular Windows Sockets implementation.";
							break;
						case WSAEINPROGRESS:
							error_message = "A blocking Windows Sockets 1.1 operation is in progress.";
							break;
						case WSAEPROCLIM:
							error_message = "A limit on the number of tasks supported by the Windows Sockets implementation has been reached.";
							break;
						case WSAEFAULT: //unlikely, if this ctor is running, wsa_data is part of this object's "stack" data
							error_message = "The lpWSAData parameter is not a valid pointer.";
							break;
					}

					kissnet_fatal_error(error_message);
				}
#ifdef KISSNET_WSA_DEBUG
				std::cerr << "Initialized Windows Socket API\n";
#endif
			}

			///Cleanup
			~WSA()
			{
				WSACleanup();
				internal_state::global_WSA = nullptr;
#ifdef KISSNET_WSA_DEBUG
				std::cerr << "Cleanup Windows Socket API\n";
#endif
			}

			///get the shared pointer
			std::shared_ptr<WSA> getPtr()
			{
				return shared_from_this();
			}
		};

		///Get-or-create the global pointer
		inline std::shared_ptr<WSA> getWSA()
		{
			//If it has been created already:
			if (internal_state::global_WSA)
				return internal_state::global_WSA->getPtr(); //fetch the smart pointer from the naked pointer

			//Create in wsa
			auto wsa = std::make_shared<WSA>();

			//Save the raw address in the global state
			internal_state::global_WSA = wsa.get();

			//Return the smart pointer
			return wsa;
		}
	}

#define KISSNET_OS_SPECIFIC_PAYLOAD_NAME wsa_ptr
#define KISSNET_OS_SPECIFIC std::shared_ptr<kissnet::win32_specific::WSA> KISSNET_OS_SPECIFIC_PAYLOAD_NAME
#define KISSNET_OS_INIT KISSNET_OS_SPECIFIC_PAYLOAD_NAME = kissnet::win32_specific::getWSA()

	///Return the last error code
	inline int get_error_code()
	{
		const auto error = WSAGetLastError();

		//We need to posixify the values that we are actually using inside this header.
		switch (error)
		{
			case WSAEWOULDBLOCK:
				return EWOULDBLOCK;
			case WSAEBADF:
				return EBADF;
			case WSAEINTR:
				return EINTR;
			default:
				return error;
		}
	}
}

#else //UNIX platform

#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/ioctl.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

using ioctl_setting = int;
using buffsize_t = size_t;

//To get consistent socket API between Windows and Linux:
static const int INVALID_SOCKET = -1;
static const int SOCKET_ERROR = -1;
using SOCKET = int;
using SOCKADDR_IN = sockaddr_in;
using SOCKADDR = sockaddr;
using IN_ADDR = in_addr;

//Wrap them in their WIN32 names
inline int closesocket(SOCKET in)
{
	return close(in);
}

template <typename... Params>
inline int ioctlsocket(int fd, int request, Params&&... params)
{
	return ioctl(fd, request, params...);
}

#define KISSNET_OS_SPECIFIC_PAYLOAD_NAME dummy
#define KISSNET_OS_SPECIFIC char dummy
#define KISSNET_OS_INIT dummy = 42;

namespace unix_specific
{
}

///Return the last error code
inline int get_error_code()
{
	return errno;
}

#endif //ifdef WIN32

#ifdef KISSNET_USE_OPENSSL
#include <mutex>
#include <vector>

#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#endif //Kissnet use OpenSSL

#ifndef SOL_TCP
#define SOL_TCP IPPROTO_TCP
#endif

///Main namespace of kissnet
namespace kissnet
{
	///low level protocol used, between TCP, TCP_SSL and UDP
	enum class protocol
	{
		tcp,
		tcp_ssl,
		udp
	};

	///Address information structs
	struct addr_collection
	{
		sockaddr_storage adrinf = {};
		socklen_t sock_size = 0;
	};

	///File descriptor set types
	static constexpr int fds_read = 0x1;
	static constexpr int fds_write = 0x2;
	static constexpr int fds_except = 0x4;

	///buffer is an array of std::byte
	template <size_t buff_size>
	using buffer = std::array<std::byte, buff_size>;

	///port_t is the port
	using port_t = uint16_t;

	///An endpoint is where the network will connect to (address and port)
	struct endpoint
	{
		///The address to connect to
		std::string address {};

		///The port to connect to
		port_t port {};

		///Default constructor, the endpoint is not valid at that point, but you can set the address/port manually
		endpoint() = default;

		///Basically create the endpoint with what you give it
		endpoint(std::string addr, port_t prt) :
		 address { std::move(addr) }, port { prt }
		{
		}

		///True if n fits in a 16 bit port number
		static bool is_valid_port_number(unsigned long n)
		{
			return n < 1 << 16;
		}

		///Construct the endpoint from "address:port".
		///The port must be a decimal number in [0; 65535]; anything else is a fatal error.
		endpoint(std::string addr)
		{
			const auto separator = addr.find_last_of(':');

			//Check if input wasn't malformed
			if (separator == std::string::npos)
				kissnet_fatal_error("string is not of address:port form");
			if (separator == addr.size() - 1)
				kissnet_fatal_error("string has ':' as last character. Expected port number here");

			//Isolate address
			address = addr.substr(0, separator);

			//Read from string as unsigned, and require the whole remainder to be consumed
			//so that e.g. "host:http" is rejected instead of silently becoming port 0
			const std::string port_text = addr.substr(separator + 1);
			char* parse_end = nullptr;
			const auto parsed_port = strtoul(port_text.c_str(), &parse_end, 10);
			if (parse_end == port_text.c_str() || *parse_end != '\0')
				kissnet_fatal_error("Invalid port number \"" + port_text + "\"");

			//In all other cases, port was always given as a port_t type, strongly preventing it to be a number outside of the [0; 65535] range. Here it's not the case.
			//To detect errors early, check it here :
			if (!is_valid_port_number(parsed_port))
				kissnet_fatal_error("Invalid port number " + std::to_string(parsed_port));

			//Store it
			port = static_cast<port_t>(parsed_port);
		}

		///Construct an endpoint from a SOCKADDR (AF_INET or AF_INET6 only)
		endpoint(const SOCKADDR* addr)
		{
			switch (addr->sa_family)
			{
				case AF_INET:
				{
					const auto* ip_addr = reinterpret_cast<const SOCKADDR_IN*>(addr);
					if (const char* text = inet_ntoa(ip_addr->sin_addr))
						address = text;
					port = ntohs(ip_addr->sin_port);
				}
				break;

				case AF_INET6:
				{
					const auto* ip_addr = reinterpret_cast<const sockaddr_in6*>(addr);
					char text_buffer[INET6_ADDRSTRLEN] = {};
					//inet_ntop returns NULL on failure; never feed that to std::string
					if (const char* text = inet_ntop(AF_INET6, &(ip_addr->sin6_addr), text_buffer, INET6_ADDRSTRLEN))
						address = text;
					port = ntohs(ip_addr->sin6_port);
				}
				break;

				default:
				{
					kissnet_fatal_error("Trying to construct an endpoint for a protocol familly that is neither AF_INET or AF_INET6");
				}
			}

			if (address.empty())
				kissnet_fatal_error("Couldn't construct endpoint from sockaddr(_storage) struct");
		}
	};

	//Wrap "system calls" here to avoid conflicts with the names used in the socket class

	///socket()
	inline auto syscall_socket = [](int af, int type, int protocol) {
		return ::socket(af, type, protocol);
	};

	///select()
	inline auto syscall_select = [](int nfds, fd_set* readfds, fd_set* writefds, fd_set* exceptfds, struct timeval* timeout) {
		return ::select(nfds, readfds, writefds, exceptfds, timeout);
	};

	///recv()
	inline auto syscall_recv = [](SOCKET s, char* buff, buffsize_t len, int flags) {
		return ::recv(s, buff, len, flags);
	};

	///send()
	inline auto syscall_send = [](SOCKET s, const char* buff, buffsize_t len, int flags) {
		return ::send(s, buff, len, flags);
	};

	///bind()
	inline auto syscall_bind = [](SOCKET s, const struct sockaddr* name, socklen_t namelen) {
		return ::bind(s, name, namelen);
	};

	///connect()
	inline auto syscall_connect = [](SOCKET s, const struct sockaddr* name, socklen_t namelen) {
		return ::connect(s, name, namelen);
	};

	///listen()
	inline auto syscall_listen = [](SOCKET s, int backlog) {
		return ::listen(s, backlog);
	};

	///accept()
	inline auto syscall_accept = [](SOCKET s, struct sockaddr* addr, socklen_t* addrlen) {
		return ::accept(s, addr, addrlen);
	};

	///shutdown()
	inline auto syscall_shutdown = [](SOCKET s) {
		return ::shutdown(s, SHUT_RDWR);
	};

	///Represent the status of a socket as returned by a socket operation (send, received). Implicitly convertible to bool
	struct socket_status
	{
		///Enumeration of socket status, with a 1 byte footprint
		enum values : int8_t
		{
			errored = 0x0,
			valid = 0x1,
			cleanly_disconnected = 0x2,
			non_blocking_would_have_blocked = 0x3,
			timed_out = 0x4

			/* ... any other info on a "still valid socket" goes here ... */

		};

		///Actual value of the socket_status.
		const values value;

		///Use the default constructor
		socket_status() :
		 value { errored } { }

		///Construct a "errored/valid" status for a true/false
		explicit socket_status(bool state) :
		 value(values(state ? valid : errored)) { }

		socket_status(values v) :
		 value(v) { }

		///Copy socket status by default
		socket_status(const socket_status&) = default;

		///Move socket status by default
		socket_status(socket_status&&) = default;

		///implicitly convert this object to const bool (as the status should not change)
		operator bool() const
		{
			//See the above enum: every value <= 0 correspond to an error, and will return false. Every value > 0 returns true
			return value > 0;
		}

		///Raw numeric value of the status
		int8_t get_value() const
		{
			return value;
		}

		bool operator==(values v) const
		{
			return v == value;
		}
	};

#ifdef KISSNET_USE_OPENSSL
#if OPENSSL_VERSION_NUMBER < 0x10100000L
	///One mutex per OpenSSL lock, needed by OpenSSL < 1.1.0 for thread safety.
	///`inline` so the locking callback (an inline function) sees a single object.
	inline std::shared_ptr<std::vector<std::mutex>> SSL_lock_cs;

	class ThreadSafe_SSL
	{
	public:
		ThreadSafe_SSL()
		{
			SSL_lock_cs = std::make_shared<std::vector<std::mutex>>(CRYPTO_num_locks());
			CRYPTO_set_locking_callback(win32_locking_callback);
		}
		~ThreadSafe_SSL()
		{
			CRYPTO_set_locking_callback(nullptr);
		}

	private:
		static void win32_locking_callback(int mode, int type, const char* /*file*/, int /*line*/)
		{
			auto& locks = *SSL_lock_cs;

			if (mode & CRYPTO_LOCK)
			{
				locks[type].lock();
			}
			else
			{
				locks[type].unlock();
			}
		}
	};
#endif

	///Initializes the OpenSSL library once for the whole program
	class Initialize_SSL
	{
	public:
		Initialize_SSL()
		{
#if OPENSSL_VERSION_NUMBER < 0x1010001fL
			SSL_load_error_strings();
			SSL_library_init();
#else
			OPENSSL_init_ssl(
				OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL);

			OPENSSL_init_crypto(
				OPENSSL_INIT_LOAD_CONFIG | OPENSSL_INIT_ADD_ALL_CIPHERS | OPENSSL_INIT_ADD_ALL_DIGESTS,
				nullptr);
#endif
		}

		~Initialize_SSL()
		{
#if OPENSSL_VERSION_NUMBER < 0x1010001fL
			ERR_free_strings();
#endif
		}

#if OPENSSL_VERSION_NUMBER < 0x10100000L
	private:
		ThreadSafe_SSL thread_setup;
#endif
	};

	///Single program-wide instance (C++17 inline variable, not one per translation unit)
	inline Initialize_SSL InitializeSSL;
#endif

	///Class that represent a socket
	template <protocol sock_proto>
	class socket
	{
		///Represent a number of bytes with a status information. Some of the methods of this class returns this.
		using bytes_with_status = std::tuple<size_t, socket_status>;

		///OS specific stuff. payload we have to hold onto for RAII management of the Operating System's socket library (e.g. Windows Socket API WinSock2)
		KISSNET_OS_SPECIFIC;

		///operating-system type for a socket object
		SOCKET sock = INVALID_SOCKET;

#ifdef KISSNET_USE_OPENSSL
		SSL* pSSL = nullptr;
		SSL_CTX* pContext = nullptr;
#endif

		///Location where this socket is bound
		endpoint bind_loc = {};

		///Address information structures
		addrinfo getaddrinfo_hints = {};
		addrinfo* getaddrinfo_results = nullptr;
		addrinfo* socket_addrinfo = nullptr;

		///Source address of the last datagram received (UDP)
		sockaddr_storage socket_input = {};
		socklen_t socket_input_socklen = 0;

		///Fill getaddrinfo_hints for this socket's protocol (stream/TCP or datagram/UDP)
		void initialize_addrinfo()
		{
			int type {};
			int iprotocol {};
			if constexpr (sock_proto == protocol::tcp || sock_proto == protocol::tcp_ssl)
			{
				type = SOCK_STREAM;
				iprotocol = IPPROTO_TCP;
			}

			else if constexpr (sock_proto == protocol::udp)
			{
				type = SOCK_DGRAM;
				iprotocol = IPPROTO_UDP;
			}

			getaddrinfo_hints = {};
			getaddrinfo_hints.ai_family = AF_UNSPEC;
			getaddrinfo_hints.ai_socktype = type;
			getaddrinfo_hints.ai_protocol = iprotocol;
			getaddrinfo_hints.ai_flags = AI_ADDRCONFIG;
		}

		///Free the getaddrinfo() list owned by this socket, if any
		void release_addrinfo() noexcept
		{
			if (getaddrinfo_results)
				freeaddrinfo(getaddrinfo_results);
			getaddrinfo_results = nullptr;
			socket_addrinfo = nullptr;
		}

#ifdef KISSNET_USE_OPENSSL
		///Tear down and free the TLS session and context, if any
		void release_ssl() noexcept
		{
			if (pSSL)
			{
				//Mark both directions as already shut down: no close_notify is sent
				SSL_set_shutdown(pSSL, SSL_RECEIVED_SHUTDOWN | SSL_SENT_SHUTDOWN);
				SSL_shutdown(pSSL);
				SSL_free(pSSL);
				pSSL = nullptr;
			}
			if (pContext)
			{
				SSL_CTX_free(pContext);
				pContext = nullptr;
			}
		}

		///Run the TLS client handshake on the already connected socket
		socket_status start_tls()
		{
			release_ssl();

			const SSL_METHOD* pMethod =
#if (OPENSSL_VERSION_NUMBER < 0x10100000L)
				TLSv1_2_client_method();
#else
				TLS_client_method();
#endif

			pContext = SSL_CTX_new(pMethod);
			if (!pContext)
				return socket_status::errored;

			pSSL = SSL_new(pContext);
			if (!pSSL)
				return socket_status::errored;

			if (SSL_set_fd(pSSL, static_cast<int>(sock)) != 1)
				return socket_status::errored;

			if (SSL_connect(pSSL) != 1)
				return socket_status::errored;

			return socket_status::valid;
		}
#endif

		///Convert a millisecond timeout to a timeval
		static timeval make_timeval(int64_t timeout_ms)
		{
			timeval tv {};
			tv.tv_sec = static_cast<decltype(tv.tv_sec)>(timeout_ms / 1000);
			tv.tv_usec = static_cast<decltype(tv.tv_usec)>(1000 * (timeout_ms % 1000));
			return tv;
		}

		///Wait (up to timeout ms) for an in-progress connect() to complete.
		///Returns 0 on success, otherwise an error code (ETIMEDOUT on timeout).
		int wait_for_connection(int64_t timeout) const
		{
			timeval tv = make_timeval(timeout);

			fd_set fd_write;
			fd_set fd_except;
			FD_ZERO(&fd_write);
			FD_SET(sock, &fd_write);
			FD_ZERO(&fd_except);
			FD_SET(sock, &fd_except);

			const int ready = syscall_select(static_cast<int>(sock) + 1, nullptr, &fd_write, &fd_except, &tv);
			if (ready == -1)
				return get_error_code();
			if (ready == 0)
				return ETIMEDOUT;

			//Stays non-zero if getsockopt fails and the fatal handler returns
			int error = EINPROGRESS;
			socklen_t errlen = sizeof(error);
			if (getsockopt(sock, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &errlen) != 0)
				kissnet_fatal_error("getting socket error returned an error");
			return error;
		}

		///Create (optionally) and connect the native socket to addr.
		///On failure the native socket is closed.
		socket_status connect(addrinfo* addr, int64_t timeout, bool createsocket)
		{
			if constexpr (sock_proto == protocol::tcp || sock_proto == protocol::tcp_ssl) //only TCP is a connected protocol
			{
				//Sockets built from a native handle have no addrinfo to connect to
				if (!addr)
					return socket_status::errored;

				if (createsocket)
				{
					close();
					socket_addrinfo = nullptr;
					sock = syscall_socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
				}

				if (sock == INVALID_SOCKET)
					return socket_status::errored;

				socket_addrinfo = addr;

				//A timeout requires a non blocking connect followed by select()
				if (timeout > 0)
					set_non_blocking(true);

				int error = syscall_connect(sock, addr->ai_addr, socklen_t(addr->ai_addrlen));
				if (error == SOCKET_ERROR)
				{
					error = get_error_code();
					if (error == EWOULDBLOCK || error == EAGAIN || error == EINPROGRESS)
						error = wait_for_connection(timeout);
				}

				if (timeout > 0)
					set_non_blocking(false);

				if (error == 0)
					return socket_status::valid;

				close();
				socket_addrinfo = nullptr;
				return socket_status::errored;
			}
			else
			{
				kissnet_fatal_error("connect called for non-tcp socket");
				return socket_status::errored;
			}
		}

		///Try the current addrinfo first, then every other resolved address. True on success.
		bool connect_any_address(int64_t timeout)
		{
			addrinfo* const first_try = socket_addrinfo;
			if (connect(first_try, timeout, false) == socket_status::valid)
				return true;

			// try to create/connect native socket for one of the other addrinfo, if any
			for (auto* addr = getaddrinfo_results; addr; addr = addr->ai_next)
			{
				if (addr == first_try)
					continue; // already checked

				if (connect(addr, timeout, true) == socket_status::valid)
					return true;
			}
			return false;
		}

		///Receive up to len bytes into dest and translate the result into a status.
		///tcp_flags is only used by plain TCP sockets.
		bytes_with_status receive_raw(char* dest, size_t len, [[maybe_unused]] int tcp_flags, addr_collection* addr_info)
		{
			std::ptrdiff_t received_bytes = 0;
			if constexpr (sock_proto == protocol::tcp)
			{
				received_bytes = syscall_recv(sock, dest, static_cast<buffsize_t>(len), tcp_flags);
			}
#ifdef KISSNET_USE_OPENSSL
			else if constexpr (sock_proto == protocol::tcp_ssl)
			{
				received_bytes = SSL_read(pSSL, dest, static_cast<int>(len));
			}
#endif
			else if constexpr (sock_proto == protocol::udp)
			{
				socket_input_socklen = sizeof socket_input;

				received_bytes = ::recvfrom(sock, dest, static_cast<buffsize_t>(len), 0, reinterpret_cast<SOCKADDR*>(&socket_input), &socket_input_socklen);
				if (addr_info)
				{
					addr_info->adrinf = socket_input;
					addr_info->sock_size = socket_input_socklen;
				}
			}

			if (received_bytes < 0)
			{
				const auto error = get_error_code();
				if (error == EWOULDBLOCK || error == EAGAIN)
					return { 0, socket_status::non_blocking_would_have_blocked };
				return { 0, socket_status::errored };
			}

			if (received_bytes == 0)
			{
				return { 0, socket_status::cleanly_disconnected };
			}

			return { static_cast<size_t>(received_bytes), socket_status::valid };
		}

	public:
		///Construct an invalid socket
		socket() :
		 KISSNET_OS_SPECIFIC_PAYLOAD_NAME {} // clockkit: for -Weffc++
		{
		}

		///socket<> isn't copyable
		socket(const socket&) = delete;

		///socket<> isn't copyable
		socket& operator=(const socket&) = delete;

		///Move constructor. socket<> isn't copyable; other is left invalid
		socket(socket&& other) noexcept :
		 KISSNET_OS_SPECIFIC_PAYLOAD_NAME { std::move(other.KISSNET_OS_SPECIFIC_PAYLOAD_NAME) },
		 sock { std::exchange(other.sock, INVALID_SOCKET) },
#ifdef KISSNET_USE_OPENSSL
		 pSSL { std::exchange(other.pSSL, nullptr) },
		 pContext { std::exchange(other.pContext, nullptr) },
#endif
		 bind_loc { std::move(other.bind_loc) },
		 getaddrinfo_hints { other.getaddrinfo_hints },
		 getaddrinfo_results { std::exchange(other.getaddrinfo_results, nullptr) },
		 socket_addrinfo { std::exchange(other.socket_addrinfo, nullptr) },
		 socket_input { other.socket_input },
		 socket_input_socklen { other.socket_input_socklen }
		{
		}

		///Move assign operation. Releases everything this socket owned before taking over other's resources
		socket& operator=(socket&& other) noexcept
		{
			if (this != &other)
			{
				//Release our own resources first (while our OS payload is still alive)
				close();
				release_addrinfo();

				KISSNET_OS_SPECIFIC_PAYLOAD_NAME = std::move(other.KISSNET_OS_SPECIFIC_PAYLOAD_NAME);
				bind_loc = std::move(other.bind_loc);
				getaddrinfo_hints = other.getaddrinfo_hints;
				sock = std::exchange(other.sock, INVALID_SOCKET);
				socket_input = other.socket_input;
				socket_input_socklen = other.socket_input_socklen;
				getaddrinfo_results = std::exchange(other.getaddrinfo_results, nullptr);
				socket_addrinfo = std::exchange(other.socket_addrinfo, nullptr);

#ifdef KISSNET_USE_OPENSSL
				pSSL = std::exchange(other.pSSL, nullptr);
				pContext = std::exchange(other.pContext, nullptr);
#endif
			}
			return *this;
		}

		///Return true if the underlying OS provided socket representation (file descriptor, handle...). Both socket are pointing to the same thing in this case
		bool operator==(const socket& other) const
		{
			return sock == other.sock;
		}

		///Return true if socket is valid. If this is false, you probably shouldn't attempt to send/receive anything, it will probably explode in your face!
		bool is_valid() const
		{
			return sock != INVALID_SOCKET;
		}

		inline operator bool() const
		{
			return is_valid();
		}

		///Resolve the endpoint and create a socket for the first usable address (does not connect)
		socket(endpoint bind_to) :
		 KISSNET_OS_SPECIFIC_PAYLOAD_NAME {}, // clockkit, for -Weffc++
		 bind_loc { std::move(bind_to) }
		{
			//operating system related housekeeping
			KISSNET_OS_INIT;

			//Do we use streams or datagrams
			initialize_addrinfo();

			if (getaddrinfo(bind_loc.address.c_str(), std::to_string(bind_loc.port).c_str(), &getaddrinfo_hints, &getaddrinfo_results) != 0)
			{
				kissnet_fatal_error("getaddrinfo failed!");
			}

			for (auto* addr = getaddrinfo_results; addr; addr = addr->ai_next)
			{
				sock = syscall_socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
				if (sock != INVALID_SOCKET)
				{
					socket_addrinfo = addr;
					break;
				}
			}

			if (sock == INVALID_SOCKET)
			{
				//The destructor will not run if we throw: free the list now
				release_addrinfo();
				kissnet_fatal_error("unable to create socket!");
			}
		}

		///Construct a socket from an operating system socket, an additional endpoint to remember from where we are
		socket(SOCKET native_sock, endpoint bind_to) :
		 KISSNET_OS_SPECIFIC_PAYLOAD_NAME {}, // clockkit, for -Weffc++
		 sock { native_sock },
		 bind_loc(std::move(bind_to))
		{
			KISSNET_OS_INIT;

			initialize_addrinfo();
		}

		///Set the socket in non blocking mode
		/// \param state By default "true". If put to false, it will set the socket back into blocking, normal mode
		void set_non_blocking(bool state = true) const
		{
#ifdef _WIN32
			ioctl_setting set = state ? 1 : 0;
			if (ioctlsocket(sock, FIONBIO, &set) < 0)
				kissnet_fatal_error("setting socket to nonblock returned an error");
#else
			const int flags = fcntl(sock, F_GETFL, 0);
			if (flags < 0)
			{
				kissnet_fatal_error("setting socket to nonblock returned an error");
				return;
			}
			//Set or clear O_NONBLOCK explicitly (a XOR would toggle it instead)
			const int newflags = state ? (flags | O_NONBLOCK) : (flags & ~O_NONBLOCK);
			if (fcntl(sock, F_SETFL, newflags) < 0)
				kissnet_fatal_error("setting socket to nonblock returned an error");
#endif
		}

		///Set the socket option for broadcasts
		/// \param state By default "true". If put to false, it will disable broadcasts
		void set_broadcast(bool state = true) const
		{
			const int broadcast = state ? 1 : 0;
			if (setsockopt(sock, SOL_SOCKET, SO_BROADCAST, reinterpret_cast<const char*>(&broadcast), sizeof(broadcast)) != 0)
				kissnet_fatal_error("setting socket broadcast mode returned an error");
		}

		/// Set the socket option for TCPNoDelay (TCP and TLS-over-TCP sockets; no-op for UDP)
		/// \param state By default "true". If put to false, it will disable TCPNoDelay
		void set_tcp_no_delay(bool state = true) const
		{
			if constexpr (sock_proto == protocol::tcp || sock_proto == protocol::tcp_ssl)
			{
				const int tcpnodelay = state ? 1 : 0;
				if (setsockopt(sock, SOL_TCP, TCP_NODELAY, reinterpret_cast<const char*>(&tcpnodelay), sizeof(tcpnodelay)) != 0)
					kissnet_fatal_error("setting socket tcpnodelay mode returned an error");
			}
		}

		/// Get socket status: errored if the socket has a pending error (SO_ERROR != 0).
		/// Note that reading SO_ERROR clears the pending error.
		socket_status get_status() const
		{
			int sockerror = 0;
			socklen_t errlen = sizeof(sockerror);
			if (getsockopt(sock, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&sockerror), &errlen) != 0)
			{
				kissnet_fatal_error("getting socket error returned an error");
				return socket_status::errored;
			}

			//SO_ERROR yields an errno value (0 when there is none), never SOCKET_ERROR
			return sockerror != 0 ? socket_status::errored : socket_status::valid;
		}

		///Bind socket locally using the address and port of the endpoint
		void bind()
		{
			if (!socket_addrinfo)
			{
				kissnet_fatal_error("bind() failed\n");
				return;
			}

			if (syscall_bind(sock, socket_addrinfo->ai_addr, socklen_t(socket_addrinfo->ai_addrlen)) == SOCKET_ERROR)
			{
				kissnet_fatal_error("bind() failed\n");
			}
		}

		///(For TCP) connect to the endpoint as client, trying every resolved address.
		/// \param timeout connection timeout in milliseconds, 0 for a blocking connect
		/// \return valid on success; errored for non connected protocols (UDP)
		socket_status connect(int64_t timeout = 0)
		{
			if constexpr (sock_proto == protocol::tcp) //only TCP is a connected protocol
			{
				if (!connect_any_address(timeout))
				{
					kissnet_fatal_error("unable to create connectable socket!");
					return socket_status::errored;
				}
				return socket_status::valid;
			}
#ifdef KISSNET_USE_OPENSSL
			else if constexpr (sock_proto == protocol::tcp_ssl) //only TCP is a connected protocol
			{
				if (!connect_any_address(timeout))
				{
					kissnet_fatal_error("unable to create connectable socket!");
					return socket_status::errored;
				}
				return start_tls();
			}
#endif
			else
			{
				return socket_status::errored;
			}
		}

		///(for TCP) setup socket to listen to connection. Need to be called on binded socket, before being able to accept()
		void listen()
		{
			if constexpr (sock_proto == protocol::tcp)
			{
				if (syscall_listen(sock, SOMAXCONN) == SOCKET_ERROR)
				{
					kissnet_fatal_error("listen failed\n");
				}
			}
		}

		///(for TCP) Wait for incoming connection, return socket connect to the client. Blocking.
		///Returns an invalid socket if the call would have blocked or was interrupted.
		socket<sock_proto> accept()
		{
			if constexpr (sock_proto != protocol::tcp)
			{
				return { INVALID_SOCKET, {} };
			}
			else
			{
				sockaddr_storage socket_address {};
				socklen_t size = sizeof socket_address;

				const SOCKET s = syscall_accept(sock, reinterpret_cast<SOCKADDR*>(&socket_address), &size);
				if (s == INVALID_SOCKET)
				{
					const auto error = get_error_code();
					//"would have blocked" or interrupted blocking call: not an error
					if (error == EWOULDBLOCK || error == EAGAIN || error == EINTR)
						return {};

					kissnet_fatal_error("accept() returned an invalid socket\n");
					return {};
				}

				return { s, endpoint(reinterpret_cast<SOCKADDR*>(&socket_address)) };
			}
		}

		///Close the native socket (and any TLS session). Safe to call more than once
		void close()
		{
#ifdef KISSNET_USE_OPENSSL
			if constexpr (sock_proto == protocol::tcp_ssl)
			{
				release_ssl();
			}
#endif
			if (sock != INVALID_SOCKET)
			{
				closesocket(sock);
			}

			sock = INVALID_SOCKET;
		}

		///Shut down both directions of the connection (the socket stays open)
		void shutdown()
		{
			if (sock != INVALID_SOCKET)
			{
				syscall_shutdown(sock);
			}
		}

		///Close socket on destruction
		~socket()
		{
			close();
			release_addrinfo();
		}

		///Select socket with timeout
		/// \param fds combination of fds_read, fds_write and fds_except
		/// \param timeout in milliseconds
		socket_status select(int fds, int64_t timeout)
		{
			fd_set fd_read;
			fd_set fd_write;
			fd_set fd_except;

			FD_ZERO(&fd_read);
			FD_ZERO(&fd_write);
			FD_ZERO(&fd_except);
			if (fds & fds_read)
				FD_SET(sock, &fd_read);
			if (fds & fds_write)
				FD_SET(sock, &fd_write);
			if (fds & fds_except)
				FD_SET(sock, &fd_except);

			timeval tv = make_timeval(timeout);

			const int ret = syscall_select(static_cast<int>(sock) + 1,
										   (fds & fds_read) ? &fd_read : nullptr,
										   (fds & fds_write) ? &fd_write : nullptr,
										   (fds & fds_except) ? &fd_except : nullptr,
										   &tv);
			if (ret == -1)
				return socket_status::errored;
			if (ret == 0)
				return socket_status::timed_out;
			return socket_status::valid;
		}

		///Send the first `length` bytes of a buffer
		template <size_t buff_size>
		bytes_with_status send(const buffer<buff_size>& buff, const size_t length = buff_size, addr_collection* addr = nullptr)
		{
			assert(buff_size >= length);
			//Never read past the end of the buffer, even with assertions disabled
			if (length > buff_size)
				return { 0, socket_status::errored };
			return send(buff.data(), length, addr);
		}

		///Send some bytes through the pipe.
		///For UDP, addr selects the destination; otherwise the resolved endpoint is used.
		bytes_with_status send(const std::byte* read_buff, size_t length, addr_collection* addr = nullptr)
		{
			const char* const data = reinterpret_cast<const char*>(read_buff);
			std::ptrdiff_t sent_bytes = 0;

			if constexpr (sock_proto == protocol::tcp)
			{
				sent_bytes = syscall_send(sock, data, static_cast<buffsize_t>(length), 0);
			}
#ifdef KISSNET_USE_OPENSSL
			else if constexpr (sock_proto == protocol::tcp_ssl)
			{
				sent_bytes = SSL_write(pSSL, data, static_cast<int>(length));
			}
#endif
			else if constexpr (sock_proto == protocol::udp)
			{
				if (addr)
				{
					sent_bytes = sendto(sock, data, static_cast<buffsize_t>(length), 0, reinterpret_cast<const SOCKADDR*>(&addr->adrinf), addr->sock_size);
				}
				else if (socket_addrinfo)
				{
					sent_bytes = sendto(sock, data, static_cast<buffsize_t>(length), 0, socket_addrinfo->ai_addr, socklen_t(socket_addrinfo->ai_addrlen));
				}
				else
				{
					//No destination known (e.g. socket built from a native handle)
					return { 0, socket_status::errored };
				}
			}

			if (sent_bytes < 0)
			{
				const auto error = get_error_code();
				if (error == EWOULDBLOCK || error == EAGAIN)
				{
					return { 0, socket_status::non_blocking_would_have_blocked };
				}

				return { 0, socket_status::errored };
			}

			return { static_cast<size_t>(sent_bytes), socket_status::valid };
		}

		///receive bytes inside the buffer, return the number of bytes you got. You can choose to write inside the buffer at a specific start offset (in number of bytes)
		///An offset past the end of the buffer returns errored.
		template <size_t buff_size>
		bytes_with_status recv(buffer<buff_size>& write_buff, size_t start_offset = 0, addr_collection* addr_info = nullptr)
		{
			//buff_size - start_offset would underflow and overflow the buffer
			if (start_offset > buff_size)
				return { 0, socket_status::errored };

			return receive_raw(reinterpret_cast<char*>(write_buff.data()) + start_offset, buff_size - start_offset, 0, addr_info);
		}

		///receive up-to len bytes inside the memory location pointed by buffer
		///For TCP, wait=true waits for all len bytes (MSG_WAITALL), wait=false does not block.
		bytes_with_status recv(std::byte* buffer, size_t len, bool wait = true, addr_collection* addr_info = nullptr)
		{
			char* const dest = reinterpret_cast<char*>(buffer);

			if constexpr (sock_proto == protocol::tcp)
			{
				if (wait)
					return receive_raw(dest, len, MSG_WAITALL, addr_info);

#ifdef _WIN32
				// MSG_DONTWAIT not avail on windows, need to make socket nonblocking to emulate
				set_non_blocking(true);
				auto result = receive_raw(dest, len, 0, addr_info);
				set_non_blocking(false);
				return result;
#else
				return receive_raw(dest, len, MSG_DONTWAIT, addr_info);
#endif
			}
			else
			{
				return receive_raw(dest, len, 0, addr_info);
			}
		}

		///Return the endpoint where this socket is talking to
		endpoint get_bind_loc() const
		{
			return bind_loc;
		}

		///Return an endpoint that originated the data in the last recv (UDP); the bound endpoint otherwise
		endpoint get_recv_endpoint() const
		{
			if constexpr (sock_proto == protocol::udp)
			{
				return { reinterpret_cast<const SOCKADDR*>(&socket_input) };
			}
			else
			{
				return get_bind_loc();
			}
		}

		///Return the number of bytes available inside the socket
		size_t bytes_available() const
		{
			//Per call (not static): a shared static would race between sockets/threads
			ioctl_setting size = 0;
			const auto status = ioctlsocket(sock, FIONREAD, &size);

			if (status < 0)
			{
				kissnet_fatal_error("ioctlsocket status is negative when getting FIONREAD\n");
			}

			return size > 0 ? static_cast<size_t>(size) : 0;
		}

		///Return the protocol used by this socket
		static protocol get_protocol()
		{
			return sock_proto;
		}
	};

	///Alias for socket<protocol::tcp>
	using tcp_socket = socket<protocol::tcp>;
#ifdef KISSNET_USE_OPENSSL
	///Alias for socket<protocol::tcp_ssl>
	using tcp_ssl_socket = socket<protocol::tcp_ssl>;
#endif //KISSNET_USE_OPENSSL
	///Alias for socket<protocol::udp>
	using udp_socket = socket<protocol::udp>;
}

//cleanup preprocessor macros
#undef KISSNET_OS_SPECIFIC_PAYLOAD_NAME
#undef KISSNET_OS_SPECIFIC
#undef KISSNET_OS_INIT
#undef kissnet_fatal_error

#endif //KISS_NET