From 0fecb95833539c3af854e255ef05ff9e77deff86 Mon Sep 17 00:00:00 2001 From: Daniele Rogora Date: Thu, 14 Sep 2023 10:37:00 +0200 Subject: [PATCH 1/4] Refactor network sockets and introduce Sock class This commit follows Bitcoin PR20788 (and probably several followups), and introduces a Sock class to handle the lifetime of network sockets and associated SSL connection handlers. --- src/Makefile.am | 2 + src/gtest/test_asyncproofverifier.cpp | 2 +- src/gtest/test_mempool.cpp | 2 +- src/metrics.cpp | 2 +- src/net.cpp | 231 +++++++++++------------ src/net.h | 19 +- src/netbase.cpp | 226 +++++++++------------- src/netbase.h | 20 +- src/test/DoS_tests.cpp | 8 +- src/util/sock.cpp | 260 ++++++++++++++++++++++++++ src/util/sock.h | 122 ++++++++++++ src/zen/tlsmanager.cpp | 70 ++++--- src/zen/tlsmanager.h | 14 +- 13 files changed, 653 insertions(+), 325 deletions(-) create mode 100644 src/util/sock.cpp create mode 100644 src/util/sock.h diff --git a/src/Makefile.am b/src/Makefile.am index 4d202d64f1..4467342663 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -212,6 +212,7 @@ BITCOIN_CORE_H = \ undo.h \ util.h \ utilmoneystr.h \ + util/sock.h \ utilstrencodings.h \ utiltime.h \ validationinterface.h \ @@ -397,6 +398,7 @@ libbitcoin_common_a_SOURCES = \ script/script_error.cpp \ script/sign.cpp \ script/standard.cpp \ + util/sock.cpp \ $(BITCOIN_CORE_H) \ $(LIBZCASH_H) diff --git a/src/gtest/test_asyncproofverifier.cpp b/src/gtest/test_asyncproofverifier.cpp index 3bbc87600e..2a64c2dc65 100644 --- a/src/gtest/test_asyncproofverifier.cpp +++ b/src/gtest/test_asyncproofverifier.cpp @@ -27,7 +27,7 @@ class AsyncProofVerifierTestSuite : public ::testing::Test mempool.reset(new CTxMemPool(::minRelayTxFee, DEFAULT_MAX_MEMPOOL_SIZE_MB * 1000000)); connman.reset(new CConnman()); - dummyNode.reset(new CNode(INVALID_SOCKET, CAddress(), "", true)); + dummyNode.reset(new CNode(nullptr, CAddress(), "", true)); dummyNode->id = 7; sidechain.creationBlockHeight = 100; diff --git a/src/gtest/test_mempool.cpp b/src/gtest/test_mempool.cpp index 630a68e4ca..7b99a17536 100644 --- a/src/gtest/test_mempool.cpp +++ b/src/gtest/test_mempool.cpp @@ -430,7 +430,7 @@ class CNodeExt : public CNode } CNodeExt(): - CNode(INVALID_SOCKET, CAddress(ip(0xa0b0c002)), "", true) + CNode(nullptr, CAddress(ip(0xa0b0c002)), "", true) { } diff --git a/src/metrics.cpp b/src/metrics.cpp index 070d407d24..f1356f5807 100644 --- a/src/metrics.cpp +++ b/src/metrics.cpp @@ -223,7 +223,7 @@ int printStats(bool mining) { LOCK2(cs_main, connman->cs_vNodes); connections = connman->vNodes.size(); - tlsConnections = std::count_if(connman->vNodes.begin(), connman->vNodes.end(), [](CNode* n) {return n->ssl != NULL;}); + tlsConnections = std::count_if(connman->vNodes.begin(), connman->vNodes.end(), [](CNode* n) {return n->GetSSL() != nullptr;}); } unsigned long mempool_count = mempool->size(); /* diff --git a/src/net.cpp b/src/net.cpp index 70ffe9a892..7b78d2cb74 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -55,6 +55,7 @@ #error "ERROR: Your OpenSSL version does not support TLS v1.2" #endif +using namespace std; // // Global state variables // @@ -65,7 +66,6 @@ map mapLocalHost; static bool vfLimited[NET_MAX] = {}; uint64_t nLocalHostNonce = 0; //// This is part of CNode CAddrMan addrman; -TLSManager tlsmanager = TLSManager(); std::map mapRelay; std::deque > vRelayExpiration; CCriticalSection cs_mapRelay; @@ -354,20 +354,23 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest) pszDest ? 0.0 : (double)(GetTime() - addrConnect.nTime)/3600.0); // Connect - SOCKET hSocket; + std::unique_ptr sock = CreateSock(addrConnect); + if (!sock) { + return nullptr; + } bool proxyConnectionFailed = false; - if (pszDest ? ConnectSocketByName(addrConnect, hSocket, pszDest, Params().GetDefaultPort(), nConnectTimeout, &proxyConnectionFailed) : - ConnectSocket(addrConnect, hSocket, nConnectTimeout, &proxyConnectionFailed)) + if ( + pszDest ? ConnectSocketByName(addrConnect, *sock, pszDest, Params().GetDefaultPort(), nConnectTimeout, &proxyConnectionFailed) : + ConnectSocket(addrConnect, *sock, nConnectTimeout, &proxyConnectionFailed)) { - if (!IsSelectableSocket(hSocket)) { + if (!sock->IsSelectable()) { LogPrintf("Cannot create connection: non-selectable socket created (fd >= FD_SETSIZE ?)\n"); - CloseSocket(hSocket); return NULL; } addrman.Attempt(addrConnect); - SSL *ssl = NULL; + SSL *ssl = nullptr; #ifdef USE_TLS /* TCP connection is ready. Do client side SSL. */ @@ -386,7 +389,8 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest) unsigned long err_code = 0; if (bUseTLS) { - ssl = tlsmanager.connect(hSocket, addrConnect, err_code); + ssl = TLSManager::connect(*sock, addrConnect, err_code); + assert(ssl == sock->GetSSL()); if (!ssl) { if (err_code == TLSManager::SELECT_TIMEDOUT) @@ -403,7 +407,6 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest) LogPrint("tls", "%s():%d - err_code %x, adding connection to %s vNonTLSNodesOutbound list (sz=%d)\n", __func__, __LINE__, err_code, addrConnect.ToStringIP(), vNonTLSNodesOutbound.size()); } - CloseSocket(hSocket); return NULL; } } @@ -423,12 +426,11 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest) else { unsigned long err_code = 0; - ssl = tlsmanager.connect(hSocket, addrConnect, err_code); - if(!ssl) + ssl = TLSManager::connect(*sock, addrConnect, err_code); + if (!ssl) { LogPrint("tls", "%s():%d - err_code %x, connection to %s failed)\n", __func__, __LINE__, err_code, addrConnect.ToStringIP()); - CloseSocket(hSocket); return NULL; } } @@ -440,16 +442,15 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest) { LogPrintf ("TLS: ERROR: Wrong server certificate from %s. Connection will be closed.\n", addrConnect.ToString()); - SSL_shutdown(ssl); - CloseSocket(hSocket); - SSL_free(ssl); + //SSL_shutdown(ssl); + //SSL_free(ssl); return NULL; } } #endif // USE_TLS // Add node - CNode* pnode = new CNode(hSocket, addrConnect, pszDest ? pszDest : "", false, ssl); + CNode* pnode = new CNode(std::move(sock), addrConnect, pszDest ? pszDest : "", false); pnode->AddRef(); { @@ -466,7 +467,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest) addrman.Attempt(addrConnect); } - return NULL; + return nullptr; } void CNode::CloseSocketDisconnect() @@ -476,7 +477,7 @@ void CNode::CloseSocketDisconnect() { LOCK(cs_hSocket); - if (hSocket != INVALID_SOCKET) + if (hSocket) { try { @@ -489,14 +490,13 @@ void CNode::CloseSocketDisconnect() LogPrintf("(node is probably shutting down) disconnecting peer=%d\n", id); } - if (ssl) + if (hSocket->GetSSL()) { unsigned long err_code = 0; - tlsmanager.waitFor(SSL_SHUTDOWN, addr, ssl, 100 /*double of avg roundtrip on decent connection*/, err_code); - SSL_free(ssl); - ssl = NULL; + TLSManager::waitFor(SSL_SHUTDOWN, addr, *hSocket, 100 /*double of avg roundtrip on decent connection*/, err_code); } - CloseSocket(hSocket); + + hSocket.reset(); } } @@ -656,8 +656,9 @@ void CNode::copyStats(CNodeStats &stats) // If ssl != NULL it means TLS connection was established successfully { LOCK(cs_hSocket); - stats.fTLSEstablished = (ssl != NULL) && (SSL_get_state(ssl) == TLS_ST_OK); - stats.fTLSVerified = (ssl != NULL) && ValidatePeerCertificate(ssl); + SSL* ssl = hSocket->GetSSL(); + stats.fTLSEstablished = (ssl != nullptr) && (SSL_get_state(ssl) == TLS_ST_OK); + stats.fTLSVerified = (ssl != nullptr) && ValidatePeerCertificate(ssl); } } #undef X @@ -825,18 +826,23 @@ bool CConnman::StopNode() void CConnman::NetCleanup() { // Close sockets - BOOST_FOREACH(CNode* pnode, vNodes) + for (CNode* pnode: vNodes) { pnode->CloseSocketDisconnect(); - BOOST_FOREACH(ListenSocket& hListenSocket, vhListenSocket) - if (hListenSocket.socket != INVALID_SOCKET) - if (!CloseSocket(hListenSocket.socket)) - LogPrintf("CloseSocket(hListenSocket) failed with error %s\n", NetworkErrorString(WSAGetLastError())); + } + + for (ListenSocket& hListenSocket: vhListenSocket) { + if (!hListenSocket.sock->Reset()) { + LogPrintf("CloseSocket(hListenSocket) failed with error %s\n", NetworkErrorString(WSAGetLastError())); + } + } // clean up some globals (to help leak detection) - BOOST_FOREACH(CNode *pnode, vNodes) + for (CNode *pnode: vNodes) { delete pnode; - BOOST_FOREACH(CNode *pnode, vNodesDisconnected) + } + for (CNode *pnode: vNodesDisconnected) { delete pnode; + } vNodes.clear(); vNodesDisconnected.clear(); vhListenSocket.clear(); @@ -859,30 +865,17 @@ void CConnman::SocketSendData(CNode *pnode) const CSerializeData &data = *it; assert(data.size() > pnode->nSendOffset); - bool bIsSSL = false; - int nBytes = 0, nRet = 0; + int nBytes = 0; { LOCK(pnode->cs_hSocket); - if (pnode->hSocket == INVALID_SOCKET) + if (!pnode->hSocket) { LogPrint("net", "Send: connection with %s is already closed\n", pnode->addr.ToString()); break; } - bIsSSL = (pnode->ssl != NULL); - - if (bIsSSL) - { - ERR_clear_error(); // clear the error queue, otherwise we may be reading an old error that occurred previously in the current thread - nBytes = SSL_write(pnode->ssl, &data[pnode->nSendOffset], data.size() - pnode->nSendOffset); - nRet = SSL_get_error(pnode->ssl, nBytes); - } - else - { - nBytes = send(pnode->hSocket, &data[pnode->nSendOffset], data.size() - pnode->nSendOffset, MSG_NOSIGNAL | MSG_DONTWAIT); - nRet = WSAGetLastError(); - } + nBytes = pnode->hSocket->Send(&data[pnode->nSendOffset], data.size() - pnode->nSendOffset, MSG_NOSIGNAL | MSG_DONTWAIT); } if (nBytes > 0) { @@ -909,8 +902,9 @@ void CConnman::SocketSendData(CNode *pnode) { // error // - if (bIsSSL) + if (pnode->GetSSL() != nullptr) { + const int nRet = SSL_get_error(pnode->GetSSL(), nBytes); if (nRet != SSL_ERROR_WANT_READ && nRet != SSL_ERROR_WANT_WRITE) { LogPrintf("ERROR: SSL_write %s; closing connection\n", ERR_error_string(nRet, NULL)); @@ -925,6 +919,7 @@ void CConnman::SocketSendData(CNode *pnode) } else { + const int nRet = WSAGetLastError(); if (nRet != WSAEWOULDBLOCK && nRet != WSAEMSGSIZE && nRet != WSAEINTR && nRet != WSAEINPROGRESS) { LogPrintf("ERROR: send %s; closing connection\n", NetworkErrorString(nRet)); @@ -1101,19 +1096,28 @@ bool CConnman::AttemptToEvictConnection(bool fPreferNewConnection) { } -void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { +void CConnman::AcceptConnection(ListenSocket& hListenSocket) { struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); - SOCKET hSocket = accept(hListenSocket.socket, (struct sockaddr*)&sockaddr, &len); + //SOCKET hSocket = accept(hListenSocket.sock->Get(), (struct sockaddr*)&sockaddr, &len); + //hListenSocket.sock.reset(new Sock(accept(hListenSocket.sock->Get(), (struct sockaddr*)&sockaddr, &len))); + std::unique_ptr sock = hListenSocket.sock->Accept((struct sockaddr*)&sockaddr, &len); CAddress addr; int nInbound = 0; int nMaxInbound = nMaxConnections - MAX_OUTBOUND_CONNECTIONS; - if (hSocket != INVALID_SOCKET) + if (!sock) + { + int nErr = WSAGetLastError(); + if (nErr != WSAEWOULDBLOCK) + LogPrintf("socket error accept failed: %s\n", NetworkErrorString(nErr)); + return; + } + + if (sock) if (!addr.SetSockAddr((const struct sockaddr*)&sockaddr)) LogPrintf("Warning: Unknown socket family\n"); - bool whitelisted = hListenSocket.whitelisted || CConnman::IsWhitelistedRange(addr); { LOCK(cs_vNodes); BOOST_FOREACH(CNode* pnode, vNodes) @@ -1121,25 +1125,16 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { nInbound++; } - if (hSocket == INVALID_SOCKET) - { - int nErr = WSAGetLastError(); - if (nErr != WSAEWOULDBLOCK) - LogPrintf("socket error accept failed: %s\n", NetworkErrorString(nErr)); - return; - } - - if (!IsSelectableSocket(hSocket)) + if (!sock->IsSelectable()) { LogPrintf("connection from %s dropped: non-selectable socket\n", addr.ToString()); - CloseSocket(hSocket); return; } + bool whitelisted = hListenSocket.whitelisted || CConnman::IsWhitelistedRange(addr); if (CNode::IsBanned(addr) && !whitelisted) { LogPrintf("connection from %s dropped (banned)\n", addr.ToString()); - CloseSocket(hSocket); return; } @@ -1148,7 +1143,6 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { if (!AttemptToEvictConnection(whitelisted)) { // No connection to evict, disconnect the new connection LogPrint("net", "failed to find an eviction candidate - connection dropped (full)\n"); - CloseSocket(hSocket); return; } } @@ -1157,15 +1151,15 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { // on all platforms. Set it again here just to be sure. int set = 1; #ifdef WIN32 - setsockopt(hSocket, IPPROTO_TCP, TCP_NODELAY, (const char*)&set, sizeof(int)); + sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (const char*)&set, sizeof(int)); #else - setsockopt(hSocket, IPPROTO_TCP, TCP_NODELAY, (void*)&set, sizeof(int)); + sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (void*)&set, sizeof(int)); #endif - SSL *ssl = NULL; + SSL *ssl = nullptr; - SetSocketNonBlocking(hSocket, true); + sock->SetNonBlocking(); #ifdef USE_TLS /* TCP connection is ready. Do server side SSL. */ @@ -1183,7 +1177,7 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { unsigned long err_code = 0; if (bUseTLS) { - ssl = tlsmanager.accept( hSocket, addr, err_code); + ssl = TLSManager::accept(*sock, addr, err_code); if(!ssl) { if (err_code == TLSManager::SELECT_TIMEDOUT) @@ -1199,7 +1193,6 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { LogPrint("tls", "%s():%d - err_code %x, adding connection from %s vNonTLSNodesInbound list (sz=%d)\n", __func__, __LINE__, err_code, addr.ToStringIP(), vNonTLSNodesInbound.size()); } - CloseSocket(hSocket); return; } } @@ -1219,12 +1212,11 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { else { unsigned long err_code = 0; - ssl = tlsmanager.accept( hSocket, addr, err_code); - if(!ssl) + ssl = TLSManager::accept(*sock, addr, err_code); + if (!ssl) { LogPrint("tls", "%s():%d - err_code %x, failure accepting connection from %s\n", __func__, __LINE__, err_code, addr.ToStringIP()); - CloseSocket(hSocket); return; } } @@ -1237,14 +1229,12 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { LogPrintf ("TLS: ERROR: Wrong client certificate from %s. Connection will be closed.\n", addr.ToString()); SSL_shutdown(ssl); - CloseSocket(hSocket); - SSL_free(ssl); return; } } #endif // USE_TLS - CNode* pnode = new CNode(hSocket, addr, "", true, ssl); + CNode* pnode = new CNode(std::move(sock), addr, "", true); pnode->AddRef(); pnode->fWhitelisted = whitelisted; @@ -1259,8 +1249,8 @@ void CConnman::ThreadNonTLSPoolsCleaner() { while (!interruptNet) { - tlsmanager.cleanNonTLSPool(vNonTLSNodesInbound, cs_vNonTLSNodesInbound); - tlsmanager.cleanNonTLSPool(vNonTLSNodesOutbound, cs_vNonTLSNodesOutbound); + TLSManager::cleanNonTLSPool(connman->vNonTLSNodesInbound, connman->cs_vNonTLSNodesInbound); + TLSManager::cleanNonTLSPool(connman->vNonTLSNodesOutbound, connman->cs_vNonTLSNodesOutbound); if (!interruptNet.sleep_for(std::chrono::milliseconds(DEFAULT_CONNECT_TIMEOUT))) return; } @@ -1345,9 +1335,9 @@ void CConnman::ThreadSocketHandler() SOCKET hSocketMax = 0; bool have_fds = false; - BOOST_FOREACH(const ListenSocket& hListenSocket, vhListenSocket) { - FD_SET(hListenSocket.socket, &fdsetRecv); - hSocketMax = max(hSocketMax, hListenSocket.socket); + BOOST_FOREACH(const ListenSocket& hListenSocket, connman->vhListenSocket) { + FD_SET(hListenSocket.sock->Get(), &fdsetRecv); + hSocketMax = max(hSocketMax, hListenSocket.sock->Get()); have_fds = true; } @@ -1357,11 +1347,11 @@ void CConnman::ThreadSocketHandler() { LOCK(pnode->cs_hSocket); - if (pnode->hSocket == INVALID_SOCKET) + SOCKET socket = pnode->GetSocketFd(); + if (socket == INVALID_SOCKET) continue; - - FD_SET(pnode->hSocket, &fdsetError); - hSocketMax = max(hSocketMax, pnode->hSocket); + FD_SET(socket, &fdsetError); + hSocketMax = max(hSocketMax, socket); have_fds = true; // Implement the following logic: @@ -1383,7 +1373,7 @@ void CConnman::ThreadSocketHandler() { TRY_LOCK(pnode->cs_vSend, lockSend); if (lockSend && !pnode->vSendMsg.empty()) { - FD_SET(pnode->hSocket, &fdsetSend); + FD_SET(socket, &fdsetSend); continue; } } @@ -1392,7 +1382,7 @@ void CConnman::ThreadSocketHandler() if (lockRecv && ( pnode->vRecvMsg.empty() || !pnode->vRecvMsg.front().complete() || pnode->GetTotalRecvSize() <= GetReceiveFloodSize())) - FD_SET(pnode->hSocket, &fdsetRecv); + FD_SET(socket, &fdsetRecv); } } } @@ -1422,7 +1412,7 @@ void CConnman::ThreadSocketHandler() // BOOST_FOREACH(const ListenSocket& hListenSocket, vhListenSocket) { - if (hListenSocket.socket != INVALID_SOCKET && FD_ISSET(hListenSocket.socket, &fdsetRecv)) + if (hListenSocket.sock->Get() != INVALID_SOCKET && FD_ISSET(hListenSocket.sock->Get(), &fdsetRecv)) { AcceptConnection(hListenSocket); } @@ -1443,7 +1433,7 @@ void CConnman::ThreadSocketHandler() if (interruptNet) return; - if (tlsmanager.threadSocketHandler(pnode,fdsetRecv,fdsetSend,fdsetError)==-1){ + if (TLSManager::threadSocketHandler(pnode,fdsetRecv,fdsetSend,fdsetError)==-1){ continue; } @@ -1775,7 +1765,7 @@ bool CConnman::OpenNetworkConnection(const CAddress& addrConnect, CSemaphoreGran else SplitHostPort(string(pszDest), port, strDest); - if (tlsmanager.isNonTLSAddr(strDest, vNonTLSNodesOutbound, cs_vNonTLSNodesOutbound)) + if (TLSManager::isNonTLSAddr(strDest, vNonTLSNodesOutbound, cs_vNonTLSNodesOutbound)) { // Attempt to reconnect in non-TLS mode pnode = ConnectNode(addrConnect, pszDest); @@ -1799,7 +1789,6 @@ bool CConnman::OpenNetworkConnection(const CAddress& addrConnect, CSemaphoreGran return true; } - void CConnman::ThreadMessageHandler() { SetThreadPriority(THREAD_PRIORITY_BELOW_NORMAL); @@ -1899,14 +1888,15 @@ bool CConnman::BindListenPort(const CService &addrBind, string& strError, bool f return false; } - SOCKET hListenSocket = socket(((struct sockaddr*)&sockaddr)->sa_family, SOCK_STREAM, IPPROTO_TCP); - if (hListenSocket == INVALID_SOCKET) + //SOCKET hListenSocket = socket(((struct sockaddr*)&sockaddr)->sa_family, SOCK_STREAM, IPPROTO_TCP); + std::unique_ptr sock = CreateSock(addrBind); + if (!sock) { strError = strprintf("Error: Couldn't open socket for incoming connections (socket returned error %s)", NetworkErrorString(WSAGetLastError())); LogPrintf("%s\n", strError); return false; } - if (!IsSelectableSocket(hListenSocket)) + if (!sock->IsSelectable()) { strError = "Error: Couldn't create a listenable socket for incoming connections"; LogPrintf("%s\n", strError); @@ -1916,16 +1906,16 @@ bool CConnman::BindListenPort(const CService &addrBind, string& strError, bool f #ifndef WIN32 #ifdef SO_NOSIGPIPE // Different way of disabling SIGPIPE on BSD - setsockopt(hListenSocket, SOL_SOCKET, SO_NOSIGPIPE, (void*)&nOne, sizeof(int)); + sock->SetSockOpt(SOL_SOCKET, SO_NOSIGPIPE, (void*)&nOne, sizeof(int)); #endif // Allow binding if the port is still in TIME_WAIT state after // the program was closed and restarted. - setsockopt(hListenSocket, SOL_SOCKET, SO_REUSEADDR, (void*)&nOne, sizeof(int)); + sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (void*)&nOne, sizeof(int)); // Disable Nagle's algorithm - setsockopt(hListenSocket, IPPROTO_TCP, TCP_NODELAY, (void*)&nOne, sizeof(int)); + sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (void*)&nOne, sizeof(int)); #else - setsockopt(hListenSocket, SOL_SOCKET, SO_REUSEADDR, (const char*)&nOne, sizeof(int)); - setsockopt(hListenSocket, IPPROTO_TCP, TCP_NODELAY, (const char*)&nOne, sizeof(int)); + sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (const char*)&nOne, sizeof(int)); + sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (const char*)&nOne, sizeof(int)); #endif // Set to non-blocking, incoming connections will also inherit this @@ -1934,7 +1924,7 @@ bool CConnman::BindListenPort(const CService &addrBind, string& strError, bool f // On Linux, the new socket returned by accept() does not inherit file // status flags such as O_NONBLOCK and O_ASYNC from the listening // socket. http://man7.org/linux/man-pages/man2/accept.2.html - if (!SetSocketNonBlocking(hListenSocket, true)) { + if (!sock->SetNonBlocking()) { strError = strprintf("BindListenPort: Setting listening socket to non-blocking failed, error %s\n", NetworkErrorString(WSAGetLastError())); LogPrintf("%s\n", strError); return false; @@ -1945,18 +1935,18 @@ bool CConnman::BindListenPort(const CService &addrBind, string& strError, bool f if (addrBind.IsIPv6()) { #ifdef IPV6_V6ONLY #ifdef WIN32 - setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&nOne, sizeof(int)); + sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&nOne, sizeof(int)); #else - setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_V6ONLY, (void*)&nOne, sizeof(int)); + sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (void*)&nOne, sizeof(int)); #endif #endif #ifdef WIN32 int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED; - setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)); + sock->SetSockOpt(IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)); #endif } - if (::bind(hListenSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) + if (sock->Bind((struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) { int nErr = WSAGetLastError(); if (nErr == WSAEADDRINUSE) @@ -1964,21 +1954,19 @@ bool CConnman::BindListenPort(const CService &addrBind, string& strError, bool f else strError = strprintf(_("Unable to bind to %s on this computer (bind returned error %s)"), addrBind.ToString(), NetworkErrorString(nErr)); LogPrintf("%s\n", strError); - CloseSocket(hListenSocket); return false; } - LogPrintf("Bound to %s\n", addrBind.ToString()); + LogPrintf("Bound to %s on sock %d\n", addrBind.ToString(), sock->Get()); // Listen for incoming connections - if (listen(hListenSocket, SOMAXCONN) == SOCKET_ERROR) + if (sock->Listen(SOMAXCONN) == SOCKET_ERROR) { strError = strprintf(_("Error: Listening for incoming connections failed (listen returned error %s)"), NetworkErrorString(WSAGetLastError())); LogPrintf("%s\n", strError); - CloseSocket(hListenSocket); return false; } - vhListenSocket.push_back(ListenSocket(hListenSocket, fWhitelisted)); + vhListenSocket.emplace_back(std::move(sock), fWhitelisted); if (addrBind.IsRoutable() && fDiscover && !fWhitelisted) AddLocal(addrBind, LOCAL_BIND); @@ -2062,19 +2050,19 @@ void CConnman::StartNode(CScheduler& scheduler, const Options& connOptions) } if (pnodeLocalHost == nullptr) - pnodeLocalHost = std::make_unique(INVALID_SOCKET, CAddress(CService("127.0.0.1", 0), nLocalServices)); + pnodeLocalHost = std::make_unique(nullptr, CAddress(CService("127.0.0.1", 0), nLocalServices)); Discover(); #ifdef USE_TLS - if (!tlsmanager.prepareCredentials()) + if (!TLSManager::prepareCredentials()) { LogPrintf("TLS: ERROR: %s: %s: Credentials weren't loaded. Node can't be started.\n", __FILE__, __func__); return; } - if (!tlsmanager.initialize()) + if (!TLSManager::initialize()) { LogPrintf("TLS: ERROR: %s: %s: TLS initialization failed. Node can't be started.\n", __FILE__, __func__); return; @@ -2346,13 +2334,12 @@ NodeId CConnman::GetNewNodeId() return nLastNodeId.fetch_add(1, std::memory_order_relaxed); } -CNode::CNode(SOCKET hSocketIn, const CAddress& addrIn, const std::string& addrNameIn, bool fInboundIn, SSL *sslIn) : +CNode::CNode(std::unique_ptr&& sock, const CAddress& addrIn, const std::string& addrNameIn, bool fInboundIn) : ssSend{SER_NETWORK, INIT_PROTO_VERSION}, addrKnown{5000, 0.001}, setInventoryKnown{connman->GetSendBufferSize() / 1000}, - hSocket{hSocketIn} + hSocket{std::move(sock)} { - ssl = sslIn; nServices = 0; nRecvVersion = INIT_PROTO_VERSION; nLastSend = 0; @@ -2401,7 +2388,7 @@ CNode::CNode(SOCKET hSocketIn, const CAddress& addrIn, const std::string& addrNa } // Be shy and don't send version until we hear - if (hSocket != INVALID_SOCKET && !fInbound) + if (hSocket && !fInbound) PushVersion(); GetNodeSignals().InitializeNode(GetId(), this); @@ -2454,17 +2441,13 @@ CNode::~CNode() // No need to make a lock on cs_hSocket, because before deletion CNode object is removed from the vNodes vector, so any other thread hasn't access to it. // Removal is synchronized with read and write routines, so all of them will be completed to this moment. - if (hSocket != INVALID_SOCKET) + if (hSocket) { - if (ssl) + if (GetSSL()) { unsigned long err_code = 0; - tlsmanager.waitFor(SSL_SHUTDOWN, addr, ssl, 0 /*no retries here make no sense on destructor*/, err_code); - SSL_free(ssl); - ssl = NULL; + TLSManager::waitFor(SSL_SHUTDOWN, addr, *hSocket, 0 /*no retries here make no sense on destructor*/, err_code); } - - CloseSocket(hSocket); } if (pfilter) diff --git a/src/net.h b/src/net.h index 10e2b8587d..0cef8fc5e9 100644 --- a/src/net.h +++ b/src/net.h @@ -255,12 +255,10 @@ class CNetMessage { class CNode { public: - // OpenSSL - SSL *ssl; // socket + std::unique_ptr hSocket; uint64_t nServices; - SOCKET hSocket; CCriticalSection cs_hSocket; CDataStream ssSend; size_t nSendSize; // total size of all vSendMsg entries @@ -364,7 +362,7 @@ class CNode // Whether a ping is requested. bool fPingQueued; - CNode(SOCKET hSocketIn, const CAddress &addrIn, const std::string &addrNameIn = "", bool fInboundIn = false, SSL *sslIn = NULL); + CNode(std::unique_ptr &&sock, const CAddress &addrIn, const std::string &addrNameIn = "", bool fInboundIn = false); ~CNode(); CNode(CNode&&) = delete; @@ -378,6 +376,13 @@ class CNode public: + SOCKET GetSocketFd() const { + return hSocket->Get(); + } + SSL* GetSSL() const { + return hSocket->GetSSL(); + } + NodeId GetId() const { return id; } @@ -707,10 +712,10 @@ class CAddrDB //// This definition can be moved into CConnman after boost::thread refactoring struct ListenSocket { - SOCKET socket; + std::shared_ptr sock; bool whitelisted; - ListenSocket(SOCKET socket, bool whitelisted) : socket(socket), whitelisted(whitelisted) {} + ListenSocket(std::shared_ptr sock, bool whitelisted) : sock(sock), whitelisted(whitelisted) {} }; /** Used to pass flags to the Bind() function */ @@ -770,7 +775,7 @@ class CConnman { void ProcessOneShot(); bool OpenNetworkConnection(const CAddress& addrConnect, CSemaphoreGrant *grantOutbound = NULL, const char *strDest = NULL, bool fOneShot = false); - void AcceptConnection(const ListenSocket& hListenSocket); + void AcceptConnection(ListenSocket& hListenSocket); CNode* FindNode(const CNetAddr& ip); CNode* FindNode(const CSubNet& subNet); CNode* FindNode(const std::string& addrName); diff --git a/src/netbase.cpp b/src/netbase.cpp index 9245b77510..b900eb5c85 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -256,7 +256,7 @@ struct timeval MillisToTimeval(int64_t nTimeout) * * @note This function requires that hSocket is in non-blocking mode. */ -bool static InterruptibleRecv(uint8_t* data, size_t len, int timeout, SOCKET& hSocket) +bool static InterruptibleRecv(uint8_t* data, size_t len, int timeout, Sock& sock) { int64_t curTime = GetTimeMillis(); int64_t endTime = curTime + timeout; @@ -270,7 +270,8 @@ bool static InterruptibleRecv(uint8_t* data, size_t len, int timeout, SOCKET& hS // ssize_t recv(int sockfd, void *buf, size_t len, int flags); // However Windows explicitly requires a char *buf: // int recv(SOCKET s, char *buf, int len, int flags); - ssize_t ret = recv(hSocket, reinterpret_cast(data), len, 0); + //ssize_t ret = recv(hSocket, reinterpret_cast(data), len, 0); + ssize_t ret = sock.Recv(reinterpret_cast(data), len, 0); if (ret > 0) { len -= ret; data += ret; @@ -279,14 +280,10 @@ bool static InterruptibleRecv(uint8_t* data, size_t len, int timeout, SOCKET& hS } else { // Other error or blocking int nErr = WSAGetLastError(); if (nErr == WSAEINPROGRESS || nErr == WSAEWOULDBLOCK || nErr == WSAEINVAL) { - if (!IsSelectableSocket(hSocket)) { + if (!sock.IsSelectable()) { return false; } - struct timeval tval = MillisToTimeval(std::min(endTime - curTime, maxWait)); - fd_set fdset; - FD_ZERO(&fdset); - FD_SET(hSocket, &fdset); - int nRet = select(hSocket + 1, &fdset, NULL, NULL, &tval); + int nRet = sock.Wait(std::min(endTime - curTime, maxWait), Sock::RECV); if (nRet == SOCKET_ERROR) { return false; } @@ -308,11 +305,10 @@ struct ProxyCredentials }; /** Connect using SOCKS5 (as described in RFC1928) */ -static bool Socks5(const std::string& strDest, int port, const ProxyCredentials *auth, SOCKET& hSocket) +static bool Socks5(const std::string& strDest, int port, const ProxyCredentials *auth, Sock& hSocket) { LogPrintf("SOCKS5 connecting %s\n", strDest); if (strDest.size() > 255) { - CloseSocket(hSocket); return error("Hostname too long"); } // Accepted authentication methods @@ -326,18 +322,16 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials vSocks5Init.push_back(0x01); // # METHODS vSocks5Init.push_back(0x00); // X'00' NO AUTHENTICATION REQUIRED } - ssize_t ret = send(hSocket, reinterpret_cast(begin_ptr(vSocks5Init)), vSocks5Init.size(), MSG_NOSIGNAL); + //ssize_t ret = send(hSocket, reinterpret_cast(begin_ptr(vSocks5Init)), vSocks5Init.size(), MSG_NOSIGNAL); + ssize_t ret = hSocket.Send(reinterpret_cast(begin_ptr(vSocks5Init)), vSocks5Init.size(), MSG_NOSIGNAL); if (ret != (ssize_t)vSocks5Init.size()) { - CloseSocket(hSocket); return error("Error sending to proxy"); } uint8_t pchRet1[2]; if (!InterruptibleRecv(pchRet1, 2, SOCKS5_RECV_TIMEOUT, hSocket)) { - CloseSocket(hSocket); return error("Error reading proxy response"); } if (pchRet1[0] != 0x05) { - CloseSocket(hSocket); return error("Proxy failed to initialize"); } if (pchRet1[1] == 0x02 && auth) { @@ -350,25 +344,22 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials vAuth.insert(vAuth.end(), auth->username.begin(), auth->username.end()); vAuth.push_back(auth->password.size()); vAuth.insert(vAuth.end(), auth->password.begin(), auth->password.end()); - ret = send(hSocket, reinterpret_cast(begin_ptr(vAuth)), vAuth.size(), MSG_NOSIGNAL); + //ret = send(hSocket, reinterpret_cast(begin_ptr(vAuth)), vAuth.size(), MSG_NOSIGNAL); + ret = hSocket.Send(reinterpret_cast(begin_ptr(vAuth)), vAuth.size(), MSG_NOSIGNAL); if (ret != (ssize_t)vAuth.size()) { - CloseSocket(hSocket); return error("Error sending authentication to proxy"); } LogPrint("proxy", "SOCKS5 sending proxy authentication %s:%s\n", auth->username, auth->password); uint8_t pchRetA[2]; if (!InterruptibleRecv(pchRetA, 2, SOCKS5_RECV_TIMEOUT, hSocket)) { - CloseSocket(hSocket); return error("Error reading proxy authentication response"); } if (pchRetA[0] != 0x01 || pchRetA[1] != 0x00) { - CloseSocket(hSocket); return error("Proxy authentication unsuccessful"); } } else if (pchRet1[1] == 0x00) { // Perform no authentication } else { - CloseSocket(hSocket); return error("Proxy requested wrong authentication method %02x", pchRet1[1]); } std::vector vSocks5; @@ -380,22 +371,19 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials vSocks5.insert(vSocks5.end(), strDest.begin(), strDest.end()); vSocks5.push_back((port >> 8) & 0xFF); vSocks5.push_back((port >> 0) & 0xFF); - ret = send(hSocket, reinterpret_cast(begin_ptr(vSocks5)), vSocks5.size(), MSG_NOSIGNAL); + //ret = send(hSocket, reinterpret_cast(begin_ptr(vSocks5)), vSocks5.size(), MSG_NOSIGNAL); + ret = hSocket.Send(reinterpret_cast(begin_ptr(vSocks5)), vSocks5.size(), MSG_NOSIGNAL); if (ret != (ssize_t)vSocks5.size()) { - CloseSocket(hSocket); return error("Error sending to proxy"); } uint8_t pchRet2[4]; if (!InterruptibleRecv(pchRet2, 4, SOCKS5_RECV_TIMEOUT, hSocket)) { - CloseSocket(hSocket); return error("Error reading proxy response"); } if (pchRet2[0] != 0x05) { - CloseSocket(hSocket); return error("Proxy failed to accept request"); } if (pchRet2[1] != 0x00) { - CloseSocket(hSocket); switch (pchRet2[1]) { case 0x01: return error("Proxy error: general failure"); @@ -410,7 +398,6 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials } } if (pchRet2[2] != 0x00) { - CloseSocket(hSocket); return error("Error: malformed proxy response"); } uint8_t pchRet3[256]; @@ -422,31 +409,78 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials { ret = InterruptibleRecv(pchRet3, 1, SOCKS5_RECV_TIMEOUT, hSocket); if (!ret) { - CloseSocket(hSocket); return error("Error reading from proxy"); } size_t nRecv = pchRet3[0]; ret = InterruptibleRecv(pchRet3, nRecv, SOCKS5_RECV_TIMEOUT, hSocket); break; } - default: CloseSocket(hSocket); return error("Error: malformed proxy response"); + default: return error("Error: malformed proxy response"); } if (!ret) { - CloseSocket(hSocket); return error("Error reading from proxy"); } if (!InterruptibleRecv(pchRet3, 2, SOCKS5_RECV_TIMEOUT, hSocket)) { - CloseSocket(hSocket); return error("Error reading from proxy"); } LogPrintf("SOCKS5 connected %s\n", strDest); return true; } -bool static ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocketRet, int nTimeout) +std::unique_ptr CreateSockTCP(const CService& address_family) { - hSocketRet = INVALID_SOCKET; + // Create a sockaddr from the specified service. + struct sockaddr_storage sockaddr; + socklen_t len = sizeof(sockaddr); + if (!address_family.GetSockAddr((struct sockaddr*)&sockaddr, &len)) { + LogPrintf("Cannot create socket for %s: unsupported network\n", address_family.ToStringPort()); + return nullptr; + } + + // Create a TCP socket in the address family of the specified service. + SOCKET hSocket = socket(((struct sockaddr*)&sockaddr)->sa_family, SOCK_STREAM, IPPROTO_TCP); + if (hSocket == INVALID_SOCKET) { + return nullptr; + } + + auto sock = std::make_unique(hSocket); + + // Ensure that waiting for I/O on this socket won't result in undefined + // behavior. + if (!sock->IsSelectable()) { + LogPrintf("Cannot create connection: non-selectable socket created (fd >= FD_SETSIZE ?)\n"); + return nullptr; + } + +#ifdef SO_NOSIGPIPE + int set = 1; + // Set the no-sigpipe option on the socket for BSD systems, other UNIXes + // should use the MSG_NOSIGNAL flag for every send. + if (sock->SetSockOpt(SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(int)) == SOCKET_ERROR) { + LogPrintf("Error setting SO_NOSIGPIPE on socket: %s, continuing anyway\n", + NetworkErrorString(WSAGetLastError())); + } +#endif + + // Set the no-delay option (disable Nagle's algorithm) on the TCP socket. + const int on{1}; + if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) { + LogPrint("net", "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n"); + } + // Set the non-blocking option on the socket. + if (!sock->SetNonBlocking()) { + LogPrintf("Error setting socket to non-blocking: %s\n", NetworkErrorString(WSAGetLastError())); + return nullptr; + } + return sock; +} + +std::function(const CService&)> CreateSock = CreateSockTCP; + + +bool static ConnectSocketDirectly(const CService &addrConnect, Sock& sock, int nTimeout) +{ struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); if (!addrConnect.GetSockAddr(reinterpret_cast(&sockaddr), &len)) { @@ -454,65 +488,57 @@ bool static ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocketRe return false; } - SOCKET hSocket = socket(((struct sockaddr*)&sockaddr)->sa_family, SOCK_STREAM, IPPROTO_TCP); - if (hSocket == INVALID_SOCKET) + if (sock.Get() == INVALID_SOCKET) return false; int set = 1; #ifdef SO_NOSIGPIPE // Different way of disabling SIGPIPE on BSD - setsockopt(hSocket, SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(int)); + sock.SetSockOpt(SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(int)); #endif //Disable Nagle's algorithm #ifdef WIN32 - setsockopt(hSocket, IPPROTO_TCP, TCP_NODELAY, (const char*)&set, sizeof(int)); + sock.SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (const char*)&set, sizeof(int)); #else - setsockopt(hSocket, IPPROTO_TCP, TCP_NODELAY, (void*)&set, sizeof(int)); + sock.SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (void*)&set, sizeof(int)); #endif // Set to non-blocking - if (!SetSocketNonBlocking(hSocket, true)) + if (!sock.SetNonBlocking()) return error("ConnectSocketDirectly: Setting socket to non-blocking failed, error %s\n", NetworkErrorString(WSAGetLastError())); - if (connect(hSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) + if (sock.Connect((struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) { int nErr = WSAGetLastError(); // WSAEINVAL is here because some legacy version of winsock uses it if (nErr == WSAEINPROGRESS || nErr == WSAEWOULDBLOCK || nErr == WSAEINVAL) { - struct timeval timeout = MillisToTimeval(nTimeout); - fd_set fdset; - FD_ZERO(&fdset); - FD_SET(hSocket, &fdset); - int nRet = select(hSocket + 1, NULL, &fdset, NULL, &timeout); + //int nRet = select(hSocket + 1, NULL, &fdset, NULL, &timeout); + int nRet = sock.Wait(nTimeout, Sock::SEND); if (nRet == 0) { LogPrint("net", "connection to %s timeout\n", addrConnect.ToString()); - CloseSocket(hSocket); return false; } if (nRet == SOCKET_ERROR) { LogPrintf("select() for %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError())); - CloseSocket(hSocket); return false; } socklen_t nRetSize = sizeof(nRet); #ifdef WIN32 - if (getsockopt(hSocket, SOL_SOCKET, SO_ERROR, (char*)(&nRet), &nRetSize) == SOCKET_ERROR) + if (sock.GetSockOpt(SOL_SOCKET, SO_ERROR, (char*)(&nRet), &nRetSize) == SOCKET_ERROR) #else - if (getsockopt(hSocket, SOL_SOCKET, SO_ERROR, &nRet, &nRetSize) == SOCKET_ERROR) + if (sock.GetSockOpt(SOL_SOCKET, SO_ERROR, &nRet, &nRetSize) == SOCKET_ERROR) #endif { LogPrintf("getsockopt() for %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError())); - CloseSocket(hSocket); return false; } if (nRet != 0) { LogPrintf("connect() to %s failed after select(): %s\n", addrConnect.ToString(), NetworkErrorString(nRet)); - CloseSocket(hSocket); return false; } } @@ -523,12 +549,10 @@ bool static ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocketRe #endif { LogPrintf("connect() to %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError())); - CloseSocket(hSocket); return false; } } - hSocketRet = hSocket; return true; } @@ -580,11 +604,10 @@ bool IsProxy(const CNetAddr &addr) { return false; } -static bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, SOCKET& hSocketRet, int nTimeout, bool *outProxyConnectionFailed) +static bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, Sock& sock, int nTimeout, bool *outProxyConnectionFailed) { - SOCKET hSocket = INVALID_SOCKET; // first connect to proxy server - if (!ConnectSocketDirectly(proxy.proxy, hSocket, nTimeout)) { + if (!ConnectSocketDirectly(proxy.proxy, sock, nTimeout)) { if (outProxyConnectionFailed) *outProxyConnectionFailed = true; return false; @@ -594,30 +617,29 @@ static bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDe ProxyCredentials random_auth; random_auth.username = strprintf("%i", insecure_rand()); random_auth.password = strprintf("%i", insecure_rand()); - if (!Socks5(strDest, (unsigned short)port, &random_auth, hSocket)) + if (!Socks5(strDest, (unsigned short)port, &random_auth, sock)) return false; } else { - if (!Socks5(strDest, (unsigned short)port, 0, hSocket)) + if (!Socks5(strDest, (unsigned short)port, 0, sock)) return false; } - hSocketRet = hSocket; return true; } -bool ConnectSocket(const CService &addrDest, SOCKET& hSocketRet, int nTimeout, bool *outProxyConnectionFailed) +bool ConnectSocket(const CService &addrDest, Sock& sock, int nTimeout, bool *outProxyConnectionFailed) { proxyType proxy; if (outProxyConnectionFailed) *outProxyConnectionFailed = false; if (GetProxy(addrDest.GetNetwork(), proxy)) - return ConnectThroughProxy(proxy, addrDest.ToStringIP(), addrDest.GetPort(), hSocketRet, nTimeout, outProxyConnectionFailed); + return ConnectThroughProxy(proxy, addrDest.ToStringIP(), addrDest.GetPort(), sock, nTimeout, outProxyConnectionFailed); else // no proxy needed (none set for target network) - return ConnectSocketDirectly(addrDest, hSocketRet, nTimeout); + return ConnectSocketDirectly(addrDest, sock, nTimeout); } -bool ConnectSocketByName(CService &addr, SOCKET& hSocketRet, const char *pszDest, int portDefault, int nTimeout, bool *outProxyConnectionFailed) +bool ConnectSocketByName(CService &addr, Sock& sock, const char *pszDest, int portDefault, int nTimeout, bool *outProxyConnectionFailed) { std::string strDest; int port = portDefault; @@ -633,14 +655,15 @@ bool ConnectSocketByName(CService &addr, SOCKET& hSocketRet, const char *pszDest CService addrResolved(CNetAddr(strDest, fNameLookup && !HaveNameProxy()), port); if (addrResolved.IsValid()) { addr = addrResolved; - return ConnectSocket(addr, hSocketRet, nTimeout); + sock = std::move(*CreateSock(addr)); // TODO: address this in net.cpp + return ConnectSocket(addr, sock, nTimeout); } addr = CService("0.0.0.0:0"); if (!HaveNameProxy()) return false; - return ConnectThroughProxy(nameProxy, strDest, port, hSocketRet, nTimeout, outProxyConnectionFailed); + return ConnectThroughProxy(nameProxy, strDest, port, sock, nTimeout, outProxyConnectionFailed); } void CNetAddr::Init() @@ -1359,82 +1382,6 @@ bool operator<(const CSubNet& a, const CSubNet& b) return (a.network < b.network || (a.network == b.network && memcmp(a.netmask, b.netmask, 16) < 0)); } -#ifdef WIN32 -std::string NetworkErrorString(int err) -{ - char buf[256]; - buf[0] = 0; - if(FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS | FORMAT_MESSAGE_MAX_WIDTH_MASK, - NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - buf, sizeof(buf), NULL)) - { - return strprintf("%s (%d)", buf, err); - } - else - { - return strprintf("Unknown error (%d)", err); - } -} -#else -std::string NetworkErrorString(int err) -{ - char buf[256]; - const char *s = buf; - buf[0] = 0; - /* Too bad there are two incompatible implementations of the - * thread-safe strerror. */ -#ifdef STRERROR_R_CHAR_P /* GNU variant can return a pointer outside the passed buffer */ - s = strerror_r(err, buf, sizeof(buf)); -#else /* POSIX variant always returns message in buffer */ - if (strerror_r(err, buf, sizeof(buf))) - buf[0] = 0; -#endif - return strprintf("%s (%d)", s, err); -} -#endif - -bool CloseSocket(SOCKET& hSocket) -{ - if (hSocket == INVALID_SOCKET) - return false; -#ifdef WIN32 - int ret = closesocket(hSocket); -#else - int ret = close(hSocket); -#endif - hSocket = INVALID_SOCKET; - return ret != SOCKET_ERROR; -} - -bool SetSocketNonBlocking(SOCKET& hSocket, bool fNonBlocking) -{ - if (fNonBlocking) { -#ifdef WIN32 - u_long nOne = 1; - if (ioctlsocket(hSocket, FIONBIO, &nOne) == SOCKET_ERROR) { -#else - int fFlags = fcntl(hSocket, F_GETFL, 0); - if (fcntl(hSocket, F_SETFL, fFlags | O_NONBLOCK) == SOCKET_ERROR) { -#endif - CloseSocket(hSocket); - return false; - } - } else { -#ifdef WIN32 - u_long nZero = 0; - if (ioctlsocket(hSocket, FIONBIO, &nZero) == SOCKET_ERROR) { -#else - int fFlags = fcntl(hSocket, F_GETFL, 0); - if (fcntl(hSocket, F_SETFL, fFlags & ~O_NONBLOCK) == SOCKET_ERROR) { -#endif - CloseSocket(hSocket); - return false; - } - } - - return true; -} - void InterruptSocks5(bool interrupt) { interruptSocks5Recv = interrupt; } @@ -1442,3 +1389,4 @@ void InterruptSocks5(bool interrupt) { void InterruptLookup(bool interrupt) { interruptLookupRecv = interrupt; } + diff --git a/src/netbase.h b/src/netbase.h index 6cc46fad52..b62bdb7aa9 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -11,6 +11,7 @@ #include "compat.h" #include "serialize.h" +#include "util/sock.h" #include #include @@ -195,12 +196,23 @@ bool LookupHost(const char *pszName, std::vector& vIP, unsigned int nM bool Lookup(const char *pszName, CService& addr, int portDefault = 0, bool fAllowLookup = true); bool Lookup(const char *pszName, std::vector& vAddr, int portDefault = 0, bool fAllowLookup = true, unsigned int nMaxSolutions = 0); bool LookupNumeric(const char *pszName, CService& addr, int portDefault = 0); -bool ConnectSocket(const CService &addr, SOCKET& hSocketRet, int nTimeout, bool *outProxyConnectionFailed = 0); -bool ConnectSocketByName(CService &addr, SOCKET& hSocketRet, const char *pszDest, int portDefault, int nTimeout, bool *outProxyConnectionFailed = 0); + +/** + * Create a TCP socket in the given address family. + * @param[in] address_family The socket is created in the same address family as this address. + * @return pointer to the created Sock object or unique_ptr that owns nothing in case of failure + */ +std::unique_ptr CreateSockTCP(const CService& address_family); + +/** + * Socket factory. Defaults to `CreateSockTCP()`, but can be overridden by unit tests. + */ +extern std::function(const CService&)> CreateSock; + +bool ConnectSocket(const CService &addr, Sock& sock, int nTimeout, bool *outProxyConnectionFailed = 0); +bool ConnectSocketByName(CService &addr, Sock& sock, const char *pszDest, int portDefault, int nTimeout, bool *outProxyConnectionFailed = 0); /** Return readable error string for a network error code */ std::string NetworkErrorString(int err); -/** Close socket and set hSocket to INVALID_SOCKET */ -bool CloseSocket(SOCKET& hSocket); /** Disable or enable blocking-mode for a socket */ bool SetSocketNonBlocking(SOCKET& hSocket, bool fNonBlocking); /** diff --git a/src/test/DoS_tests.cpp b/src/test/DoS_tests.cpp index 77331fea2d..68f2aa2a03 100644 --- a/src/test/DoS_tests.cpp +++ b/src/test/DoS_tests.cpp @@ -50,7 +50,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) connman.reset(new CConnman()); CNode::ClearBanned(); CAddress addr1(ip(0xa0b0c001)); - CNode dummyNode1(INVALID_SOCKET, addr1, "", true); + CNode dummyNode1(nullptr, addr1, "", true); dummyNode1.nVersion = 1; Misbehaving(dummyNode1.GetId(), 100); // Should get banned SendMessages(&dummyNode1, false, interruptDummy); @@ -58,7 +58,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) BOOST_CHECK(!CNode::IsBanned(ip(0xa0b0c001|0x0000ff00))); // Different IP, not banned CAddress addr2(ip(0xa0b0c002)); - CNode dummyNode2(INVALID_SOCKET, addr2, "", true); + CNode dummyNode2(nullptr, addr2, "", true); dummyNode2.nVersion = 1; Misbehaving(dummyNode2.GetId(), 50); SendMessages(&dummyNode2, false, interruptDummy); @@ -76,7 +76,7 @@ BOOST_AUTO_TEST_CASE(DoS_banscore) CNode::ClearBanned(); mapArgs["-banscore"] = "111"; // because 11 is my favorite number CAddress addr1(ip(0xa0b0c001)); - CNode dummyNode1(INVALID_SOCKET, addr1, "", true); + CNode dummyNode1(nullptr, addr1, "", true); dummyNode1.nVersion = 1; Misbehaving(dummyNode1.GetId(), 100); SendMessages(&dummyNode1, false, interruptDummy); @@ -98,7 +98,7 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) int64_t nStartTime = GetTime(); SetMockTime(nStartTime); // Overrides future calls to GetTime() CAddress addr(ip(0xa0b0c001)); - CNode dummyNode(INVALID_SOCKET, addr, "", true); + CNode dummyNode(nullptr, addr, "", true); dummyNode.nVersion = 1; Misbehaving(dummyNode.GetId(), 100); SendMessages(&dummyNode, false, interruptDummy); diff --git a/src/util/sock.cpp b/src/util/sock.cpp new file mode 100644 index 0000000000..be0fb1a8a6 --- /dev/null +++ b/src/util/sock.cpp @@ -0,0 +1,260 @@ +#include "compat.h" +#include "tinyformat.h" +#include "sock.h" +#include "util.h" + +#include +#include +#include +#include + +#ifdef USE_POLL +#include +#endif + +#include +#include + +Sock::Sock() : m_socket(INVALID_SOCKET), m_ssl(nullptr) {} + +Sock::Sock(SOCKET s, SSL* ssl) : m_socket(s), m_ssl(ssl) {} + +Sock::Sock(Sock&& other) +{ + m_socket = other.m_socket; + m_ssl = other.m_ssl; + other.m_socket = INVALID_SOCKET; + other.m_ssl = nullptr; +} + +Sock::~Sock() { Reset(); } + +Sock& Sock::operator=(Sock&& other) +{ + Reset(); + m_socket = other.m_socket; + m_ssl = other.m_ssl; + other.m_socket = INVALID_SOCKET; + other.m_ssl = nullptr; + return *this; +} + +SOCKET Sock::Get() const { return m_socket; } + +SSL* Sock::GetSSL() const { return m_ssl; } +bool Sock::SetSSL(SSL* ssl) { + if (m_ssl) { + SSL_free(m_ssl); + } + + m_ssl = ssl; + if (!m_ssl) { + return false; + } + return SSL_set_fd(m_ssl, m_socket); +} + +/* +SOCKET Sock::Release() +{ + const SOCKET s = m_socket; + m_socket = INVALID_SOCKET; + return s; +} +*/ + +bool Sock::Reset() { return Close(); } + +ssize_t Sock::Send(const void* data, size_t len, int flags) const +{ + if (m_ssl) { + ERR_clear_error(); // clear the error queue + return SSL_write(m_ssl, static_cast(data), len); + } + return send(m_socket, static_cast(data), len, flags); +} + +ssize_t Sock::Recv(void* buf, size_t len, int flags) const +{ + if (m_ssl) { + ERR_clear_error(); // clear the error queue + return SSL_read(m_ssl, static_cast(buf), len); + } + return recv(m_socket, static_cast(buf), len, flags); +} + +int Sock::Wait(int64_t timeout, Event requested) const +{ +#ifdef USE_POLL + pollfd fd; + fd.fd = m_socket; + fd.events = 0; + if (requested & RECV) { + fd.events |= POLLIN; + } + if (requested & SEND) { + fd.events |= POLLOUT; + } + + return poll(&fd, 1, count_milliseconds(timeout)); +#else + if (!IsSelectable()) { + return -1; + } + + fd_set fdset_recv; + fd_set fdset_send; + FD_ZERO(&fdset_recv); + FD_ZERO(&fdset_send); + + if (requested & RECV) { + FD_SET(m_socket, &fdset_recv); + } + + if (requested & SEND) { + FD_SET(m_socket, &fdset_send); + } + + timeval timeout_struct = MillisToTimeval(timeout); + + return select(m_socket + 1, &fdset_recv, &fdset_send, nullptr, &timeout_struct); +#endif /* USE_POLL */ +} + + +#ifdef WIN32 +std::string NetworkErrorString(int err) +{ + wchar_t buf[256]; + buf[0] = 0; + if(FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS | FORMAT_MESSAGE_MAX_WIDTH_MASK, + nullptr, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + buf, ARRAYSIZE(buf), nullptr)) + { + return strprintf("%s (%d)", std::wstring_convert,wchar_t>().to_bytes(buf), err); + } + else + { + return strprintf("Unknown error (%d)", err); + } +} +#else +std::string NetworkErrorString(int err) +{ + char buf[256]; + buf[0] = 0; + /* Too bad there are two incompatible implementations of the + * thread-safe strerror. */ + const char *s; +#ifdef STRERROR_R_CHAR_P /* GNU variant can return a pointer outside the passed buffer */ + s = strerror_r(err, buf, sizeof(buf)); +#else /* POSIX variant always returns message in buffer */ + s = buf; + if (strerror_r(err, buf, sizeof(buf))) + buf[0] = 0; +#endif + return strprintf("%s (%d)", s, err); +} +#endif + +bool Sock::Close() +{ + LogPrintf("CLosing socket: %d. Error: %s\n", m_socket, NetworkErrorString(WSAGetLastError())); + if (m_ssl) { + SSL_free(m_ssl); + m_ssl = nullptr; + } + + if (m_socket == INVALID_SOCKET) + return false; +#ifdef WIN32 + int ret = closesocket(m_socket); +#else + int ret = close(m_socket); +#endif + if (ret) { + LogPrintf("Socket close failed: %d. Error: %s\n", m_socket, NetworkErrorString(WSAGetLastError())); + } + m_socket = INVALID_SOCKET; + return ret != SOCKET_ERROR; +} + +int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const +{ + return getsockopt(m_socket, level, opt_name, static_cast(opt_val), opt_len); +} + +int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const +{ + return setsockopt(m_socket, level, opt_name, static_cast(opt_val), opt_len); +} + +bool Sock::SetNonBlocking() const +{ +#ifdef WIN32 + u_long on{1}; + if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) { + return false; + } +#else + const int flags{fcntl(m_socket, F_GETFL, 0)}; + if (flags == SOCKET_ERROR) { + return false; + } + if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) { + return false; + } +#endif + return true; +} + +bool Sock::IsSelectable() const +{ +#if defined(USE_POLL) || defined(WIN32) + return true; +#else + return m_socket < FD_SETSIZE; +#endif +} + +int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const +{ + return connect(m_socket, addr, addr_len); +} + +std::unique_ptr Sock::Accept(sockaddr* addr, socklen_t* addr_len) const +{ +#ifdef WIN32 + static constexpr auto ERR = INVALID_SOCKET; +#else + static constexpr auto ERR = SOCKET_ERROR; +#endif + + std::unique_ptr sock; + + const SOCKET socket = accept(m_socket, addr, addr_len); + if (socket != ERR) { + try { + sock = std::make_unique(socket); + } catch (const std::exception&) { +#ifdef WIN32 + closesocket(socket); +#else + close(socket); +#endif + } + } + + return sock; +} + +int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const +{ + return bind(m_socket, addr, addr_len); +} + +int Sock::Listen(int backlog) const +{ + return listen(m_socket, backlog); +} + diff --git a/src/util/sock.h b/src/util/sock.h new file mode 100644 index 0000000000..e3b757ee61 --- /dev/null +++ b/src/util/sock.h @@ -0,0 +1,122 @@ +#pragma once + +#include + +#include "compat.h" + +typedef struct ssl_st SSL; + +class Sock +{ +public: + /** + * Default constructor, creates an empty object that does nothing when destroyed. + */ + Sock(); + + /** + * Take ownership of an existent socket. + */ + explicit Sock(SOCKET s, SSL* ssl = nullptr); + + /** + * Copy constructor, disabled because closing the same socket twice is undesirable. + */ + Sock(const Sock&) = delete; + + /** + * Move constructor, grab the socket from another object and close ours (if set). + */ + Sock(Sock&& other); + + /** + * Destructor, close the socket or do nothing if empty. + */ + virtual ~Sock(); + + /** + * Copy assignment operator, disabled because closing the same socket twice is undesirable. + */ + Sock& operator=(const Sock&) = delete; + + /** + * Move assignment operator, grab the socket from another object and close ours (if set). + */ + virtual Sock& operator=(Sock&& other); + + /** + * Get the value of the contained socket. + * @return socket or INVALID_SOCKET if empty + */ + virtual SOCKET Get() const; + + virtual SSL* GetSSL() const; + virtual bool SetSSL(SSL* ssl); + + /** + * Get the value of the contained socket and drop ownership. It will not be closed by the + * destructor after this call. + * @return socket or INVALID_SOCKET if empty + */ + //virtual SOCKET Release(); + + /** + * Close if non-empty. + */ + virtual bool Reset(); + + /** + * send(2) wrapper. Equivalent to `send(this->Get(), data, len, flags);`. Code that uses this + * wrapper can be unit-tested if this method is overridden by a mock Sock implementation. + */ + virtual ssize_t Send(const void* data, size_t len, int flags) const; + + /** + * recv(2) wrapper. Equivalent to `recv(this->Get(), buf, len, flags);`. Code that uses this + * wrapper can be unit-tested if this method is overridden by a mock Sock implementation. + */ + virtual ssize_t Recv(void* buf, size_t len, int flags) const; + + using Event = uint8_t; + + /** + * If passed to `Wait()`, then it will wait for readiness to read from the socket. + */ + static constexpr Event RECV = 0b01; + + /** + * If passed to `Wait()`, then it will wait for readiness to send to the socket. + */ + static constexpr Event SEND = 0b10; + + /** + * Wait for readiness for input (recv) or output (send). + * @param[in] timeout Wait this much for at least one of the requested events to occur. + * @param[in] requested Wait for those events, bitwise-or of `RECV` and `SEND`. + * @return true on success and false otherwise + */ + virtual int Wait(int64_t timeout, Event requested) const; + int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const; + int SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const; + bool SetNonBlocking() const; + bool IsSelectable() const; + int Connect(const sockaddr* addr, socklen_t addr_len) const; + std::unique_ptr Accept(sockaddr* addr, socklen_t* addr_len) const; + int Bind(const sockaddr* addr, socklen_t addr_len) const; + int Listen(int backlog) const; + +private: + /** + * Contained socket. `INVALID_SOCKET` designates the object is empty. + */ + SOCKET m_socket; + SSL* m_ssl; // TODO: remember to free this!!! + + /** Close socket and set hSocket to INVALID_SOCKET */ + bool Close(); +}; + +/** Return readable error string for a network error code */ +std::string NetworkErrorString(int err); + +struct timeval MillisToTimeval(int64_t nTimeout); diff --git a/src/zen/tlsmanager.cpp b/src/zen/tlsmanager.cpp index d2cdd8de9e..c82be78eda 100644 --- a/src/zen/tlsmanager.cpp +++ b/src/zen/tlsmanager.cpp @@ -183,11 +183,14 @@ int tlsCertVerificationCallback(int preverify_ok, X509_STORE_CTX* chainContext) * @param timeoutSec timeout in seconds. * @return int returns nError corresponding to the connection event. */ -int TLSManager::waitFor(SSLConnectionRoutine eRoutine, const CAddress& peerAddress, SSL* ssl, int timeoutMilliSec, unsigned long& err_code) +int TLSManager::waitFor(SSLConnectionRoutine eRoutine, const CAddress& peerAddress, Sock& sock, int timeoutMilliSec, unsigned long& err_code) { std::string eRoutine_str{}; int retOp{0}; - const SOCKET hSocket = SSL_get_fd(ssl); + const SOCKET hSocket = sock.Get(); + SSL* ssl = sock.GetSSL(); + assert(ssl); + assert(SSL_get_fd(ssl) == hSocket); err_code = 0; @@ -252,15 +255,6 @@ int TLSManager::waitFor(SSLConnectionRoutine eRoutine, const CAddress& peerAddre std::string ssl_error_str{}; int result{0}; - // select() modifies its arguments, so these must be reinitialized on each iteration - fd_set socketSet; - struct timeval timeout { - timeoutMilliSec / 1000, (timeoutMilliSec % 1000) * 1000 - }; - - FD_ZERO(&socketSet); - FD_SET(hSocket, &socketSet); - switch (sslErr) { case SSL_ERROR_SSL: // - case for shutdown sent while the peer still sending data after we've sent close_notify @@ -288,11 +282,15 @@ int TLSManager::waitFor(SSLConnectionRoutine eRoutine, const CAddress& peerAddre [[fallthrough]]; // Need to read more case SSL_ERROR_WANT_READ: ssl_error_str = "SSL_ERROR_WANT_READ"; - result = select(hSocket + 1, &socketSet, NULL, NULL, &timeout); + LogPrint("tls", "TLS: %s: %s: %s peer=%s want read more\n", + __FILE__, __func__, eRoutine_str, peerAddress.ToString()); + result = sock.Wait(timeoutMilliSec, Sock::RECV); break; case SSL_ERROR_WANT_WRITE: ssl_error_str = "SSL_ERROR_WANT_WRITE"; - result = select(hSocket + 1, NULL, &socketSet, NULL, &timeout); + LogPrint("tls", "TLS: %s: %s: %s peer=%s want send more\n", + __FILE__, __func__, eRoutine_str, peerAddress.ToString()); + result = sock.Wait(timeoutMilliSec, Sock::SEND); break; default: // For all othe errors we intentionally do fail (no retries) @@ -332,7 +330,7 @@ int TLSManager::waitFor(SSLConnectionRoutine eRoutine, const CAddress& peerAddre * @param tls_ctx_client TLS Client context * @return SSL* returns a ssl* if successful, otherwise returns NULL. */ -SSL* TLSManager::connect(SOCKET hSocket, const CAddress& addrConnect, unsigned long& err_code) +SSL* TLSManager::connect(Sock& sock, const CAddress& addrConnect, unsigned long& err_code) { LogPrint("tls", "TLS: establishing connection (tid = %X), (peerid = %s)\n", pthread_self(), addrConnect.ToString()); @@ -341,8 +339,9 @@ SSL* TLSManager::connect(SOCKET hSocket, const CAddress& addrConnect, unsigned l bool bConnectedTLS = false; if ((ssl = SSL_new(tls_ctx_client))) { - if (SSL_set_fd(ssl, hSocket)) { - int ret = TLSManager::waitFor(SSL_CONNECT, addrConnect, ssl, DEFAULT_CONNECT_TIMEOUT, err_code); + if (SSL_set_fd(ssl, sock.Get())) { + sock.SetSSL(ssl); + int ret = TLSManager::waitFor(SSL_CONNECT, addrConnect, sock, DEFAULT_CONNECT_TIMEOUT, err_code); if (ret == 1) { bConnectedTLS = true; @@ -367,7 +366,8 @@ SSL* TLSManager::connect(SOCKET hSocket, const CAddress& addrConnect, unsigned l if (ssl) { SSL_free(ssl); - ssl = NULL; + ssl = nullptr; + sock.SetSSL(ssl); } } return ssl; @@ -523,7 +523,7 @@ bool TLSManager::prepareCredentials() * @param tls_ctx_server TLS server context. * @return SSL* returns pointer to the ssl object if successful, otherwise returns NULL */ -SSL* TLSManager::accept(SOCKET hSocket, const CAddress& addr, unsigned long& err_code) +SSL* TLSManager::accept(Sock& sock, const CAddress& addr, unsigned long& err_code) { LogPrint("tls", "TLS: accepting connection from %s (tid = %X)\n", addr.ToString(), pthread_self()); @@ -532,8 +532,9 @@ SSL* TLSManager::accept(SOCKET hSocket, const CAddress& addr, unsigned long& err bool bAcceptedTLS = false; if ((ssl = SSL_new(tls_ctx_server))) { - if (SSL_set_fd(ssl, hSocket)) { - bAcceptedTLS = (TLSManager::waitFor(SSL_ACCEPT, addr, ssl, DEFAULT_CONNECT_TIMEOUT, err_code) == 1); + if (SSL_set_fd(ssl, sock.Get())) { + sock.SetSSL(ssl); + bAcceptedTLS = (TLSManager::waitFor(SSL_ACCEPT, addr, sock, DEFAULT_CONNECT_TIMEOUT, err_code) == 1); } } else @@ -559,7 +560,8 @@ SSL* TLSManager::accept(SOCKET hSocket, const CAddress& addr, unsigned long& err if (ssl) { SSL_free(ssl); - ssl = NULL; + ssl = nullptr; + sock.SetSSL(ssl); } } @@ -627,12 +629,12 @@ int TLSManager::threadSocketHandler(CNode* pnode, fd_set& fdsetRecv, fd_set& fds { LOCK(pnode->cs_hSocket); - if (pnode->hSocket == INVALID_SOCKET) + if (pnode->hSocket->Get() == INVALID_SOCKET) return -1; - recvSet = FD_ISSET(pnode->hSocket, &fdsetRecv); - sendSet = FD_ISSET(pnode->hSocket, &fdsetSend); - errorSet = FD_ISSET(pnode->hSocket, &fdsetError); + recvSet = FD_ISSET(pnode->hSocket->Get(), &fdsetRecv); + sendSet = FD_ISSET(pnode->hSocket->Get(), &fdsetSend); + errorSet = FD_ISSET(pnode->hSocket->Get(), &fdsetError); } if (recvSet || errorSet) { @@ -643,26 +645,20 @@ int TLSManager::threadSocketHandler(CNode* pnode, fd_set& fdsetRecv, fd_set& fds // maximum record size is 16kB for SSL/TLS (still valid as of 1.1.1 version) char pchBuf[0x10000]; bool bIsSSL = false; - int nBytes = 0, nRet = 0; + int nBytes = 0; { LOCK(pnode->cs_hSocket); - if (pnode->hSocket == INVALID_SOCKET) { + if (pnode->hSocket->Get() == INVALID_SOCKET) { LogPrint("tls", "Receive: connection with %s is already closed\n", pnode->addr.ToString()); return -1; } - bIsSSL = (pnode->ssl != NULL); + nBytes = pnode->hSocket->Recv(pchBuf, sizeof(pchBuf), MSG_DONTWAIT); + + bIsSSL = (pnode->hSocket->GetSSL()); - if (bIsSSL) { - ERR_clear_error(); // clear the error queue, otherwise we may be reading an old error that occurred previously in the current thread - nBytes = SSL_read(pnode->ssl, pchBuf, sizeof(pchBuf)); - nRet = SSL_get_error(pnode->ssl, nBytes); - } else { - nBytes = recv(pnode->hSocket, pchBuf, sizeof(pchBuf), MSG_DONTWAIT); - nRet = WSAGetLastError(); - } } if (nBytes > 0) { @@ -690,6 +686,7 @@ int TLSManager::threadSocketHandler(CNode* pnode, fd_set& fdsetRecv, fd_set& fds // error // if (bIsSSL) { + int nRet = SSL_get_error(pnode->hSocket->GetSSL(), nBytes); if (nRet != SSL_ERROR_WANT_READ && nRet != SSL_ERROR_WANT_WRITE) // SSL_read() operation has to be repeated because of SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE (https://wiki.openssl.org/index.php/Manual:SSL_read(3)#NOTES) { if (!pnode->fDisconnect) @@ -707,6 +704,7 @@ int TLSManager::threadSocketHandler(CNode* pnode, fd_set& fdsetRecv, fd_set& fds MilliSleep(1); // 1 msec } } else { + int nRet = WSAGetLastError(); if (nRet != WSAEWOULDBLOCK && nRet != WSAEMSGSIZE && nRet != WSAEINTR && nRet != WSAEINPROGRESS) { if (!pnode->fDisconnect) LogPrintf("TLS: ERROR: socket recv %s\n", NetworkErrorString(nRet)); diff --git a/src/zen/tlsmanager.h b/src/zen/tlsmanager.h index 72ae642fa5..f40bb97a36 100644 --- a/src/zen/tlsmanager.h +++ b/src/zen/tlsmanager.h @@ -16,8 +16,7 @@ #else #include #endif - -using namespace std; +#include "util/sock.h" extern std::unique_ptr connman; @@ -28,17 +27,16 @@ namespace zen * @brief A class to wrap some of zen specific TLS functionalities used in the net.cpp * */ -class TLSManager +namespace TLSManager { -public: /* This is set as a custom error number which is not an error in OpenSSL protocol. A true (not null) OpenSSL error returned by ERR_get_error() consists of a library number, function code and reason code. */ static const long SELECT_TIMEDOUT = 0xFFFFFFFF; - int waitFor(SSLConnectionRoutine eRoutine, const CAddress& peerAddress, SSL* ssl, int timeoutMilliSec, unsigned long& err_code); + int waitFor(SSLConnectionRoutine eRoutine, const CAddress& peerAddress, Sock& sock, int timeoutMilliSec, unsigned long& err_code); - SSL* connect(SOCKET hSocket, const CAddress& addrConnect, unsigned long& err_code); + SSL* connect(Sock& sock, const CAddress& addrConnect, unsigned long& err_code); SSL_CTX* initCtx( TLSContextType ctxType, const boost::filesystem::path& privateKeyFile, @@ -46,8 +44,8 @@ class TLSManager const std::vector& trustedDirs); bool prepareCredentials(); - SSL* accept(SOCKET hSocket, const CAddress& addr, unsigned long& err_code); - bool isNonTLSAddr(const string& strAddr, const vector& vPool, CCriticalSection& cs); + SSL* accept(Sock& sock, const CAddress& addr, unsigned long& err_code); + bool isNonTLSAddr(const std::string& strAddr, const std::vector& vPool, CCriticalSection& cs); void cleanNonTLSPool(std::vector& vPool, CCriticalSection& cs); int threadSocketHandler(CNode* pnode, fd_set& fdsetRecv, fd_set& fdsetSend, fd_set& fdsetError); bool initialize(); From ccdaeef2ac8a555cbef63a1fd6bfc100dec61216 Mon Sep 17 00:00:00 2001 From: Daniele Rogora Date: Fri, 22 Sep 2023 10:42:10 +0200 Subject: [PATCH 2/4] Move socket creation to simplify patterns --- src/net.cpp | 26 +++++---------- src/netbase.cpp | 75 ++++++++++++++++++++++-------------------- src/netbase.h | 4 +-- src/zen/tlsmanager.cpp | 22 +++---------- 4 files changed, 55 insertions(+), 72 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index 7b78d2cb74..d9725acc5e 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -354,14 +354,12 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest) pszDest ? 0.0 : (double)(GetTime() - addrConnect.nTime)/3600.0); // Connect - std::unique_ptr sock = CreateSock(addrConnect); - if (!sock) { - return nullptr; - } bool proxyConnectionFailed = false; - if ( - pszDest ? ConnectSocketByName(addrConnect, *sock, pszDest, Params().GetDefaultPort(), nConnectTimeout, &proxyConnectionFailed) : - ConnectSocket(addrConnect, *sock, nConnectTimeout, &proxyConnectionFailed)) + // sock is actually created only by ConnectSocketDirectly() + std::unique_ptr sock = + pszDest ? ConnectSocketByName(addrConnect, pszDest, Params().GetDefaultPort(), nConnectTimeout, &proxyConnectionFailed) : + ConnectSocket(addrConnect, nConnectTimeout, &proxyConnectionFailed); + if (sock) { if (!sock->IsSelectable()) { LogPrintf("Cannot create connection: non-selectable socket created (fd >= FD_SETSIZE ?)\n"); @@ -656,7 +654,7 @@ void CNode::copyStats(CNodeStats &stats) // If ssl != NULL it means TLS connection was established successfully { LOCK(cs_hSocket); - SSL* ssl = hSocket->GetSSL(); + SSL* ssl = hSocket ? hSocket->GetSSL() : nullptr; stats.fTLSEstablished = (ssl != nullptr) && (SSL_get_state(ssl) == TLS_ST_OK); stats.fTLSVerified = (ssl != nullptr) && ValidatePeerCertificate(ssl); } @@ -1099,8 +1097,6 @@ bool CConnman::AttemptToEvictConnection(bool fPreferNewConnection) { void CConnman::AcceptConnection(ListenSocket& hListenSocket) { struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); - //SOCKET hSocket = accept(hListenSocket.sock->Get(), (struct sockaddr*)&sockaddr, &len); - //hListenSocket.sock.reset(new Sock(accept(hListenSocket.sock->Get(), (struct sockaddr*)&sockaddr, &len))); std::unique_ptr sock = hListenSocket.sock->Accept((struct sockaddr*)&sockaddr, &len); CAddress addr; int nInbound = 0; @@ -1888,7 +1884,6 @@ bool CConnman::BindListenPort(const CService &addrBind, string& strError, bool f return false; } - //SOCKET hListenSocket = socket(((struct sockaddr*)&sockaddr)->sa_family, SOCK_STREAM, IPPROTO_TCP); std::unique_ptr sock = CreateSock(addrBind); if (!sock) { @@ -2441,13 +2436,10 @@ CNode::~CNode() // No need to make a lock on cs_hSocket, because before deletion CNode object is removed from the vNodes vector, so any other thread hasn't access to it. // Removal is synchronized with read and write routines, so all of them will be completed to this moment. - if (hSocket) + if (hSocket && GetSSL()) { - if (GetSSL()) - { - unsigned long err_code = 0; - TLSManager::waitFor(SSL_SHUTDOWN, addr, *hSocket, 0 /*no retries here make no sense on destructor*/, err_code); - } + unsigned long err_code = 0; + TLSManager::waitFor(SSL_SHUTDOWN, addr, *hSocket, 0 /*no retries here make no sense on destructor*/, err_code); } if (pfilter) diff --git a/src/netbase.cpp b/src/netbase.cpp index b900eb5c85..ca99de7ae3 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -479,67 +479,70 @@ std::unique_ptr CreateSockTCP(const CService& address_family) std::function(const CService&)> CreateSock = CreateSockTCP; -bool static ConnectSocketDirectly(const CService &addrConnect, Sock& sock, int nTimeout) +std::unique_ptr static ConnectSocketDirectly(const CService &addrConnect, int nTimeout) { struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); if (!addrConnect.GetSockAddr(reinterpret_cast(&sockaddr), &len)) { LogPrintf("Cannot connect to %s: unsupported network\n", addrConnect.ToString()); - return false; + return nullptr; } - if (sock.Get() == INVALID_SOCKET) - return false; + std::unique_ptr sock = CreateSock(addrConnect); + + if (!sock) + return nullptr; int set = 1; #ifdef SO_NOSIGPIPE // Different way of disabling SIGPIPE on BSD - sock.SetSockOpt(SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(int)); + sock->SetSockOpt(SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(int)); #endif //Disable Nagle's algorithm #ifdef WIN32 - sock.SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (const char*)&set, sizeof(int)); + sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (const char*)&set, sizeof(int)); #else - sock.SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (void*)&set, sizeof(int)); + sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (void*)&set, sizeof(int)); #endif // Set to non-blocking - if (!sock.SetNonBlocking()) - return error("ConnectSocketDirectly: Setting socket to non-blocking failed, error %s\n", NetworkErrorString(WSAGetLastError())); + if (!sock->SetNonBlocking()) { + LogPrint("net", "ConnectSocketDirectly: Setting socket to non-blocking failed, error %s\n", NetworkErrorString(WSAGetLastError())); + return nullptr; + } - if (sock.Connect((struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) + if (sock->Connect((struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) { int nErr = WSAGetLastError(); // WSAEINVAL is here because some legacy version of winsock uses it if (nErr == WSAEINPROGRESS || nErr == WSAEWOULDBLOCK || nErr == WSAEINVAL) { - //int nRet = select(hSocket + 1, NULL, &fdset, NULL, &timeout); - int nRet = sock.Wait(nTimeout, Sock::SEND); + int nRet = sock->Wait(nTimeout, Sock::SEND); if (nRet == 0) { LogPrint("net", "connection to %s timeout\n", addrConnect.ToString()); - return false; + return nullptr; } if (nRet == SOCKET_ERROR) { LogPrintf("select() for %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError())); - return false; + return nullptr; } socklen_t nRetSize = sizeof(nRet); #ifdef WIN32 - if (sock.GetSockOpt(SOL_SOCKET, SO_ERROR, (char*)(&nRet), &nRetSize) == SOCKET_ERROR) + if (sock->GetSockOpt(SOL_SOCKET, SO_ERROR, (char*)(&nRet), &nRetSize) == SOCKET_ERROR) #else - if (sock.GetSockOpt(SOL_SOCKET, SO_ERROR, &nRet, &nRetSize) == SOCKET_ERROR) + if (sock->GetSockOpt(SOL_SOCKET, SO_ERROR, &nRet, &nRetSize) == SOCKET_ERROR) #endif { LogPrintf("getsockopt() for %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError())); - return false; + return nullptr; } if (nRet != 0) { LogPrintf("connect() to %s failed after select(): %s\n", addrConnect.ToString(), NetworkErrorString(nRet)); - return false; + return nullptr; } } #ifdef WIN32 @@ -549,11 +552,11 @@ bool static ConnectSocketDirectly(const CService &addrConnect, Sock& sock, int n #endif { LogPrintf("connect() to %s failed: %s\n", addrConnect.ToString(), NetworkErrorString(WSAGetLastError())); - return false; + return nullptr; } } - return true; + return sock; } bool SetProxy(enum Network net, const proxyType &addrProxy) { @@ -604,42 +607,43 @@ bool IsProxy(const CNetAddr &addr) { return false; } -static bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, Sock& sock, int nTimeout, bool *outProxyConnectionFailed) +static std::unique_ptr ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, int nTimeout, bool *outProxyConnectionFailed) { + std::unique_ptr sock = ConnectSocketDirectly(proxy.proxy, nTimeout); // first connect to proxy server - if (!ConnectSocketDirectly(proxy.proxy, sock, nTimeout)) { + if (!sock) { if (outProxyConnectionFailed) *outProxyConnectionFailed = true; - return false; + return nullptr; } // do socks negotiation if (proxy.randomize_credentials) { ProxyCredentials random_auth; random_auth.username = strprintf("%i", insecure_rand()); random_auth.password = strprintf("%i", insecure_rand()); - if (!Socks5(strDest, (unsigned short)port, &random_auth, sock)) - return false; + if (!Socks5(strDest, (unsigned short)port, &random_auth, *sock)) + return nullptr; } else { - if (!Socks5(strDest, (unsigned short)port, 0, sock)) - return false; + if (!Socks5(strDest, (unsigned short)port, 0, *sock)) + return nullptr; } - return true; + return sock; } -bool ConnectSocket(const CService &addrDest, Sock& sock, int nTimeout, bool *outProxyConnectionFailed) +std::unique_ptr ConnectSocket(const CService &addrDest, int nTimeout, bool *outProxyConnectionFailed) { proxyType proxy; if (outProxyConnectionFailed) *outProxyConnectionFailed = false; if (GetProxy(addrDest.GetNetwork(), proxy)) - return ConnectThroughProxy(proxy, addrDest.ToStringIP(), addrDest.GetPort(), sock, nTimeout, outProxyConnectionFailed); + return ConnectThroughProxy(proxy, addrDest.ToStringIP(), addrDest.GetPort(), nTimeout, outProxyConnectionFailed); else // no proxy needed (none set for target network) - return ConnectSocketDirectly(addrDest, sock, nTimeout); + return ConnectSocketDirectly(addrDest, nTimeout); } -bool ConnectSocketByName(CService &addr, Sock& sock, const char *pszDest, int portDefault, int nTimeout, bool *outProxyConnectionFailed) +std::unique_ptr ConnectSocketByName(CService &addr, const char *pszDest, int portDefault, int nTimeout, bool *outProxyConnectionFailed) { std::string strDest; int port = portDefault; @@ -655,15 +659,14 @@ bool ConnectSocketByName(CService &addr, Sock& sock, const char *pszDest, int po CService addrResolved(CNetAddr(strDest, fNameLookup && !HaveNameProxy()), port); if (addrResolved.IsValid()) { addr = addrResolved; - sock = std::move(*CreateSock(addr)); // TODO: address this in net.cpp - return ConnectSocket(addr, sock, nTimeout); + return ConnectSocket(addr, nTimeout); } addr = CService("0.0.0.0:0"); if (!HaveNameProxy()) - return false; - return ConnectThroughProxy(nameProxy, strDest, port, sock, nTimeout, outProxyConnectionFailed); + return nullptr; + return ConnectThroughProxy(nameProxy, strDest, port, nTimeout, outProxyConnectionFailed); } void CNetAddr::Init() diff --git a/src/netbase.h b/src/netbase.h index b62bdb7aa9..46e94a4fa7 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -209,8 +209,8 @@ std::unique_ptr CreateSockTCP(const CService& address_family); */ extern std::function(const CService&)> CreateSock; -bool ConnectSocket(const CService &addr, Sock& sock, int nTimeout, bool *outProxyConnectionFailed = 0); -bool ConnectSocketByName(CService &addr, Sock& sock, const char *pszDest, int portDefault, int nTimeout, bool *outProxyConnectionFailed = 0); +std::unique_ptr ConnectSocket(const CService &addr, int nTimeout, bool *outProxyConnectionFailed = 0); +std::unique_ptr ConnectSocketByName(CService &addr, const char *pszDest, int portDefault, int nTimeout, bool *outProxyConnectionFailed = 0); /** Return readable error string for a network error code */ std::string NetworkErrorString(int err); /** Disable or enable blocking-mode for a socket */ diff --git a/src/zen/tlsmanager.cpp b/src/zen/tlsmanager.cpp index c82be78eda..2fbdb6200f 100644 --- a/src/zen/tlsmanager.cpp +++ b/src/zen/tlsmanager.cpp @@ -282,14 +282,10 @@ int TLSManager::waitFor(SSLConnectionRoutine eRoutine, const CAddress& peerAddre [[fallthrough]]; // Need to read more case SSL_ERROR_WANT_READ: ssl_error_str = "SSL_ERROR_WANT_READ"; - LogPrint("tls", "TLS: %s: %s: %s peer=%s want read more\n", - __FILE__, __func__, eRoutine_str, peerAddress.ToString()); result = sock.Wait(timeoutMilliSec, Sock::RECV); break; case SSL_ERROR_WANT_WRITE: ssl_error_str = "SSL_ERROR_WANT_WRITE"; - LogPrint("tls", "TLS: %s: %s: %s peer=%s want send more\n", - __FILE__, __func__, eRoutine_str, peerAddress.ToString()); result = sock.Wait(timeoutMilliSec, Sock::SEND); break; default: @@ -339,13 +335,8 @@ SSL* TLSManager::connect(Sock& sock, const CAddress& addrConnect, unsigned long& bool bConnectedTLS = false; if ((ssl = SSL_new(tls_ctx_client))) { - if (SSL_set_fd(ssl, sock.Get())) { - sock.SetSSL(ssl); - int ret = TLSManager::waitFor(SSL_CONNECT, addrConnect, sock, DEFAULT_CONNECT_TIMEOUT, err_code); - if (ret == 1) - { - bConnectedTLS = true; - } + if (sock.SetSSL(ssl)) { + bConnectedTLS = (TLSManager::waitFor(SSL_CONNECT, addrConnect, sock, DEFAULT_CONNECT_TIMEOUT, err_code) == 1); } } else @@ -365,9 +356,8 @@ SSL* TLSManager::connect(Sock& sock, const CAddress& addrConnect, unsigned long& __FILE__, __func__, __LINE__, addrConnect.ToString(), err_code); if (ssl) { - SSL_free(ssl); ssl = nullptr; - sock.SetSSL(ssl); + sock.SetSSL(nullptr); } } return ssl; @@ -532,8 +522,7 @@ SSL* TLSManager::accept(Sock& sock, const CAddress& addr, unsigned long& err_cod bool bAcceptedTLS = false; if ((ssl = SSL_new(tls_ctx_server))) { - if (SSL_set_fd(ssl, sock.Get())) { - sock.SetSSL(ssl); + if (sock.SetSSL(ssl)) { bAcceptedTLS = (TLSManager::waitFor(SSL_ACCEPT, addr, sock, DEFAULT_CONNECT_TIMEOUT, err_code) == 1); } } @@ -559,9 +548,8 @@ SSL* TLSManager::accept(Sock& sock, const CAddress& addr, unsigned long& err_cod __FILE__, __func__, __LINE__, addr.ToString(), err_code); if (ssl) { - SSL_free(ssl); ssl = nullptr; - sock.SetSSL(ssl); + sock.SetSSL(nullptr); } } From 113ed9ec30a93f4d153df68cc7f1eaa8ec4dd23c Mon Sep 17 00:00:00 2001 From: Daniele Rogora Date: Fri, 22 Sep 2023 18:43:20 +0200 Subject: [PATCH 3/4] Add unit tests for Sock --- src/Makefile.gtest.include | 1 + src/compat.h | 8 +++ src/gtest/test_sock.cpp | 120 +++++++++++++++++++++++++++++++++++++ src/net.cpp | 24 ++------ src/netbase.cpp | 8 +-- 5 files changed, 137 insertions(+), 24 deletions(-) create mode 100644 src/gtest/test_sock.cpp diff --git a/src/Makefile.gtest.include b/src/Makefile.gtest.include index 80269febee..96806bce7c 100644 --- a/src/Makefile.gtest.include +++ b/src/Makefile.gtest.include @@ -41,6 +41,7 @@ zen_gtest_SOURCES += \ gtest/test_pow.cpp \ gtest/test_random.cpp \ gtest/test_rpc.cpp \ + gtest/test_sock.cpp \ gtest/test_getblocktemplate.cpp \ gtest/test_timedata.cpp \ gtest/test_transaction.cpp \ diff --git a/src/compat.h b/src/compat.h index feaa544e25..5621c3d14d 100644 --- a/src/compat.h +++ b/src/compat.h @@ -89,6 +89,14 @@ typedef u_int SOCKET; #define THREAD_PRIORITY_ABOVE_NORMAL (-2) #endif +// The type of the option value passed to getsockopt & setsockopt +// differs between Windows and non-Windows. +#ifndef WIN32 +typedef void* sockopt_arg_type; +#else +typedef char* sockopt_arg_type; +#endif + #if HAVE_DECL_STRNLEN == 0 size_t strnlen( const char *start, size_t max_len); #endif // HAVE_DECL_STRNLEN diff --git a/src/gtest/test_sock.cpp b/src/gtest/test_sock.cpp new file mode 100644 index 0000000000..9fb3c0202b --- /dev/null +++ b/src/gtest/test_sock.cpp @@ -0,0 +1,120 @@ +#include +#include + +#include "compat.h" +#include "util/sock.h" + +static bool SocketIsClosed(const SOCKET& s) +{ + // Notice that if another thread is running and creates its own socket after `s` has been + // closed, it may be assigned the same file descriptor number. In this case, our test will + // wrongly pretend that the socket is not closed. + int type; + socklen_t len = sizeof(type); + return getsockopt(s, SOL_SOCKET, SO_TYPE, (sockopt_arg_type)&type, &len) == SOCKET_ERROR; +} + +static SOCKET CreateSocket() +{ + const SOCKET s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + assert(s != static_cast(SOCKET_ERROR)); + return s; +} + + +TEST(Sock, ConstructorDestructor) { + SOCKET s = CreateSocket(); + { + Sock sock(s); + ASSERT_EQ(sock.Get(), s); + ASSERT_FALSE(SocketIsClosed(s)); + } + ASSERT_TRUE(SocketIsClosed(s)); +} + +TEST(Sock, MoveConstructor) { + SOCKET s = CreateSocket(); + Sock sock(s); + ASSERT_EQ(sock.Get(), s); + ASSERT_FALSE(SocketIsClosed(s)); + + Sock sock2(std::move(sock)); + ASSERT_EQ(sock.Get(), INVALID_SOCKET); + ASSERT_EQ(sock2.Get(), s); + ASSERT_FALSE(SocketIsClosed(s)); +} + +TEST(Sock, MoveAssignment) { + SOCKET s = CreateSocket(); + Sock sock(s); + ASSERT_EQ(sock.Get(), s); + ASSERT_FALSE(SocketIsClosed(s)); + + Sock sock2 = std::move(sock); + ASSERT_EQ(sock.Get(), INVALID_SOCKET); + ASSERT_EQ(sock2.Get(), s); + ASSERT_FALSE(SocketIsClosed(s)); +} + +TEST(Sock, Reset) { + const SOCKET s = CreateSocket(); + Sock sock(s); + ASSERT_FALSE(SocketIsClosed(s)); + sock.Reset(); + ASSERT_TRUE(SocketIsClosed(s)); +} + +#ifndef WIN32 // Windows does not have socketpair(2). + +static void CreateSocketPair(int s[2]) { + assert(socketpair(AF_UNIX, SOCK_STREAM, 0, s) == 0); +} + +static void SendAndRecvMessage(const Sock& sender, const Sock& receiver) { + const char* msg = "abcd"; + constexpr size_t msg_len = 4; + char recv_buf[10]; + + ASSERT_EQ(sender.Send(msg, msg_len, 0), msg_len); + ASSERT_EQ(receiver.Recv(recv_buf, sizeof(recv_buf), 0), msg_len); + ASSERT_EQ(strncmp(msg, recv_buf, msg_len), 0); +} + +TEST(Sock, SendAndReceive) { + int s[2]; + CreateSocketPair(s); + + { + Sock sock0(s[0]); + Sock sock1(s[1]); + + SendAndRecvMessage(sock0, sock1); + + Sock sock0moved = std::move(sock0); + Sock sock1moved = std::move(sock1); + + SendAndRecvMessage(sock1moved, sock0moved); + } + + ASSERT_TRUE(SocketIsClosed(s[0])); + ASSERT_TRUE(SocketIsClosed(s[1])); +} + +TEST(Sock, Wait) +{ + int s[2]; + CreateSocketPair(s); + + Sock sock0(s[0]); + Sock sock1(s[1]); + + constexpr int64_t millis_in_day = 24 * 60 * 60 * 1000; + std::thread waiter([&sock0]() { ASSERT_EQ(sock0.Wait(millis_in_day, Sock::RECV), 1); }); + + ASSERT_EQ(sock1.Send("a", 1, 0), 1); + + waiter.join(); +} + +#endif /* WIN32 */ + diff --git a/src/net.cpp b/src/net.cpp index d9725acc5e..3589f9a2e6 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -12,6 +12,7 @@ #include "addrman.h" #include "chainparams.h" #include "clientversion.h" +#include "compat.h" #include "primitives/transaction.h" #include "scheduler.h" #include "ui_interface.h" @@ -1146,11 +1147,7 @@ void CConnman::AcceptConnection(ListenSocket& hListenSocket) { // According to the internet TCP_NODELAY is not carried into accepted sockets // on all platforms. Set it again here just to be sure. int set = 1; -#ifdef WIN32 - sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (const char*)&set, sizeof(int)); -#else - sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (void*)&set, sizeof(int)); -#endif + sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (sockopt_arg_type)&set, sizeof(int)); SSL *ssl = nullptr; @@ -1898,20 +1895,15 @@ bool CConnman::BindListenPort(const CService &addrBind, string& strError, bool f return false; } -#ifndef WIN32 #ifdef SO_NOSIGPIPE // Different way of disabling SIGPIPE on BSD sock->SetSockOpt(SOL_SOCKET, SO_NOSIGPIPE, (void*)&nOne, sizeof(int)); #endif // Allow binding if the port is still in TIME_WAIT state after // the program was closed and restarted. - sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (void*)&nOne, sizeof(int)); + sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)); // Disable Nagle's algorithm - sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (void*)&nOne, sizeof(int)); -#else - sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (const char*)&nOne, sizeof(int)); - sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (const char*)&nOne, sizeof(int)); -#endif + sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (sockopt_arg_type)&nOne, sizeof(int)); // Set to non-blocking, incoming connections will also inherit this // @@ -1929,15 +1921,11 @@ bool CConnman::BindListenPort(const CService &addrBind, string& strError, bool f // and enable it by default or not. Try to enable it, if possible. if (addrBind.IsIPv6()) { #ifdef IPV6_V6ONLY -#ifdef WIN32 - sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&nOne, sizeof(int)); -#else - sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (void*)&nOne, sizeof(int)); -#endif + sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)); #endif #ifdef WIN32 int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED; - sock->SetSockOpt(IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)); + sock->SetSockOpt(IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (sockopt_arg_type)&nProtLevel, sizeof(int)); #endif } diff --git a/src/netbase.cpp b/src/netbase.cpp index ca99de7ae3..8cd312f020 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -464,7 +464,7 @@ std::unique_ptr CreateSockTCP(const CService& address_family) // Set the no-delay option (disable Nagle's algorithm) on the TCP socket. const int on{1}; - if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) { + if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (sockopt_arg_type)&on, sizeof(on)) == SOCKET_ERROR) { LogPrint("net", "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n"); } @@ -500,11 +500,7 @@ std::unique_ptr static ConnectSocketDirectly(const CService &addrConnect, #endif //Disable Nagle's algorithm -#ifdef WIN32 - sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (const char*)&set, sizeof(int)); -#else - sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (void*)&set, sizeof(int)); -#endif + sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, (sockopt_arg_type)&set, sizeof(int)); // Set to non-blocking if (!sock->SetNonBlocking()) { From 03e38e677019a41460abde96c70a0fbee310481d Mon Sep 17 00:00:00 2001 From: Daniele Rogora Date: Tue, 26 Sep 2023 09:14:53 +0200 Subject: [PATCH 4/4] Use poll() on Linux --- src/compat.h | 7 +++ src/net.cpp | 40 ++++++----------- src/util/sock.cpp | 100 +++++++++++++++++++++++++++++++---------- src/util/sock.h | 21 +++++++-- src/zen/tlsmanager.cpp | 12 +++-- src/zen/tlsmanager.h | 2 +- 6 files changed, 124 insertions(+), 58 deletions(-) diff --git a/src/compat.h b/src/compat.h index 5621c3d14d..20c52e63a9 100644 --- a/src/compat.h +++ b/src/compat.h @@ -109,4 +109,11 @@ bool static inline IsSelectableSocket(SOCKET s) { #endif } +// Note these both should work with the current usage of poll, but best to be safe +// WIN32 poll is broken https://daniel.haxx.se/blog/2012/10/10/wsapoll-is-broken/ +// __APPLE__ poll is broke https://github.com/bitcoin/bitcoin/pull/14336#issuecomment-437384408 +#if defined(__linux__) +#define USE_POLL +#endif + #endif // BITCOIN_COMPAT_H diff --git a/src/net.cpp b/src/net.cpp index 3589f9a2e6..c7cbefda7e 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -19,6 +19,7 @@ #include "crypto/common.h" #include "zen/utiltls.h" #include "zen/tlsmanager.h" +#include "util/sock.h" #ifdef WIN32 #include @@ -1315,22 +1316,13 @@ void CConnman::ThreadSocketHandler() // // Find which sockets have data to receive // - struct timeval timeout; - timeout.tv_sec = 0; - timeout.tv_usec = 50000; // frequency to poll pnode->vSend - - fd_set fdsetRecv; - fd_set fdsetSend; - fd_set fdsetError; - FD_ZERO(&fdsetRecv); - FD_ZERO(&fdsetSend); - FD_ZERO(&fdsetError); - SOCKET hSocketMax = 0; + constexpr int64_t timeout_ms = 50; + + std::unordered_map ev_to_monitor; bool have_fds = false; BOOST_FOREACH(const ListenSocket& hListenSocket, connman->vhListenSocket) { - FD_SET(hListenSocket.sock->Get(), &fdsetRecv); - hSocketMax = max(hSocketMax, hListenSocket.sock->Get()); + ev_to_monitor[hListenSocket.sock->Get()].requested |= Sock::RECV; have_fds = true; } @@ -1343,8 +1335,7 @@ void CConnman::ThreadSocketHandler() SOCKET socket = pnode->GetSocketFd(); if (socket == INVALID_SOCKET) continue; - FD_SET(socket, &fdsetError); - hSocketMax = max(hSocketMax, socket); + ev_to_monitor[socket].requested |= Sock::ERR; have_fds = true; // Implement the following logic: @@ -1366,7 +1357,7 @@ void CConnman::ThreadSocketHandler() { TRY_LOCK(pnode->cs_vSend, lockSend); if (lockSend && !pnode->vSendMsg.empty()) { - FD_SET(socket, &fdsetSend); + ev_to_monitor[socket].requested |= Sock::SEND; continue; } } @@ -1375,13 +1366,12 @@ void CConnman::ThreadSocketHandler() if (lockRecv && ( pnode->vRecvMsg.empty() || !pnode->vRecvMsg.front().complete() || pnode->GetTotalRecvSize() <= GetReceiveFloodSize())) - FD_SET(socket, &fdsetRecv); + ev_to_monitor[socket].requested |= Sock::RECV; } } } - int nSelect = select(have_fds ? hSocketMax + 1 : 0, - &fdsetRecv, &fdsetSend, &fdsetError, &timeout); + int nSelect = Sock::WaitMany(timeout_ms, ev_to_monitor); if (interruptNet) return; @@ -1391,21 +1381,17 @@ void CConnman::ThreadSocketHandler() { int nErr = WSAGetLastError(); LogPrintf("socket select error %s\n", NetworkErrorString(nErr)); - for (unsigned int i = 0; i <= hSocketMax; i++) - FD_SET(i, &fdsetRecv); } - FD_ZERO(&fdsetSend); - FD_ZERO(&fdsetError); - if (!interruptNet.sleep_for(std::chrono::microseconds(timeout.tv_usec))) + if (!interruptNet.sleep_for(std::chrono::milliseconds(timeout_ms))) return; } // // Accept new connections // - BOOST_FOREACH(const ListenSocket& hListenSocket, vhListenSocket) + BOOST_FOREACH(ListenSocket& hListenSocket, vhListenSocket) { - if (hListenSocket.sock->Get() != INVALID_SOCKET && FD_ISSET(hListenSocket.sock->Get(), &fdsetRecv)) + if (hListenSocket.sock->Get() != INVALID_SOCKET && ev_to_monitor.at(hListenSocket.sock->Get()).occurred & Sock::RECV) { AcceptConnection(hListenSocket); } @@ -1426,7 +1412,7 @@ void CConnman::ThreadSocketHandler() if (interruptNet) return; - if (TLSManager::threadSocketHandler(pnode,fdsetRecv,fdsetSend,fdsetError)==-1){ + if (TLSManager::threadSocketHandler(pnode, ev_to_monitor) == -1){ continue; } diff --git a/src/util/sock.cpp b/src/util/sock.cpp index be0fb1a8a6..0c8882c38c 100644 --- a/src/util/sock.cpp +++ b/src/util/sock.cpp @@ -83,45 +83,100 @@ ssize_t Sock::Recv(void* buf, size_t len, int flags) const return recv(m_socket, static_cast(buf), len, flags); } -int Sock::Wait(int64_t timeout, Event requested) const + +int Sock::WaitMany(int64_t timeout, std::unordered_map& events_per_sock) { #ifdef USE_POLL - pollfd fd; - fd.fd = m_socket; - fd.events = 0; - if (requested & RECV) { - fd.events |= POLLIN; + std::vector pfds; + for (auto& [sock, events] : events_per_sock) { + pfds.emplace_back(); + auto& pfd = pfds.back(); + pfd.fd = sock; + if (events.requested & RECV) { + pfd.events |= POLLIN; + } + if (events.requested & SEND) { + pfd.events |= POLLOUT; + } } - if (requested & SEND) { - fd.events |= POLLOUT; + + int ret = poll(pfds.data(), pfds.size(), timeout); + if (ret == SOCKET_ERROR) { + return -1; } - return poll(&fd, 1, count_milliseconds(timeout)); + assert(pfds.size() == events_per_sock.size()); + size_t i{0}; + for (auto& [sock, events] : events_per_sock) { + assert(sock == static_cast(pfds[i].fd)); + events.occurred = 0; + if (pfds[i].revents & POLLIN) { + events.occurred |= RECV; + } + if (pfds[i].revents & POLLOUT) { + events.occurred |= SEND; + } + if (pfds[i].revents & (POLLERR | POLLHUP)) { + events.occurred |= ERR; + } + ++i; + } + return ret; #else - if (!IsSelectable()) { - return -1; + fd_set recv; + fd_set send; + fd_set err; + FD_ZERO(&recv); + FD_ZERO(&send); + FD_ZERO(&err); + SOCKET socket_max{0}; + + for (const auto& [s, events] : events_per_sock) { + if (s >= FD_SETSIZE) { + return false; + } + if (events.requested & RECV) { + FD_SET(s, &recv); + } + if (events.requested & SEND) { + FD_SET(s, &send); + } + FD_SET(s, &err); + socket_max = std::max(socket_max, s); } - fd_set fdset_recv; - fd_set fdset_send; - FD_ZERO(&fdset_recv); - FD_ZERO(&fdset_send); + timeval tv = MillisToTimeval(timeout); - if (requested & RECV) { - FD_SET(m_socket, &fdset_recv); + int ret = select(socket_max + 1, &recv, &send, &err, &tv); + if (ret == SOCKET_ERROR) { + return -1; } - if (requested & SEND) { - FD_SET(m_socket, &fdset_send); + for (auto& [s, events] : events_per_sock) { + events.occurred = 0; + if (FD_ISSET(s, &recv)) { + events.occurred |= RECV; + } + if (FD_ISSET(s, &send)) { + events.occurred |= SEND; + } + if (FD_ISSET(s, &err)) { + events.occurred |= ERR; + } } - timeval timeout_struct = MillisToTimeval(timeout); - - return select(m_socket + 1, &fdset_recv, &fdset_send, nullptr, &timeout_struct); + return ret; #endif /* USE_POLL */ } +int Sock::Wait(int64_t timeout, Event requested) const +{ + std::unordered_map events_per_sock = { {m_socket, Events(requested)} }; + return WaitMany(timeout, events_per_sock); +} + + #ifdef WIN32 std::string NetworkErrorString(int err) { @@ -159,7 +214,6 @@ std::string NetworkErrorString(int err) bool Sock::Close() { - LogPrintf("CLosing socket: %d. Error: %s\n", m_socket, NetworkErrorString(WSAGetLastError())); if (m_ssl) { SSL_free(m_ssl); m_ssl = nullptr; diff --git a/src/util/sock.h b/src/util/sock.h index e3b757ee61..df4b21a6a9 100644 --- a/src/util/sock.h +++ b/src/util/sock.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include "compat.h" @@ -82,12 +83,25 @@ class Sock /** * If passed to `Wait()`, then it will wait for readiness to read from the socket. */ - static constexpr Event RECV = 0b01; + static constexpr Event RECV = 0b001; /** * If passed to `Wait()`, then it will wait for readiness to send to the socket. */ - static constexpr Event SEND = 0b10; + static constexpr Event SEND = 0b010; + + /** + * Ignored if passed to `Wait()`, but could be set in the occurred events if an + * exceptional condition has occurred on the socket or if it has been disconnected. + */ + static constexpr Event ERR = 0b100; + + struct Events { + explicit Events() : requested{0} {} + explicit Events(Event req) : requested{req} {} + Event requested; + Event occurred{0}; + }; /** * Wait for readiness for input (recv) or output (send). @@ -96,6 +110,7 @@ class Sock * @return true on success and false otherwise */ virtual int Wait(int64_t timeout, Event requested) const; + static int WaitMany(int64_t timeout, std::unordered_map& events_per_sock); int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const; int SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const; bool SetNonBlocking() const; @@ -110,7 +125,7 @@ class Sock * Contained socket. `INVALID_SOCKET` designates the object is empty. */ SOCKET m_socket; - SSL* m_ssl; // TODO: remember to free this!!! + SSL* m_ssl; /** Close socket and set hSocket to INVALID_SOCKET */ bool Close(); diff --git a/src/zen/tlsmanager.cpp b/src/zen/tlsmanager.cpp index 2fbdb6200f..e260a8893b 100644 --- a/src/zen/tlsmanager.cpp +++ b/src/zen/tlsmanager.cpp @@ -607,7 +607,7 @@ void TLSManager::cleanNonTLSPool(std::vector& vPool, CCriticalSection * @param fdsetError * @return int returns -1 when socket is invalid. returns 0 otherwise. */ -int TLSManager::threadSocketHandler(CNode* pnode, fd_set& fdsetRecv, fd_set& fdsetSend, fd_set& fdsetError) +int TLSManager::threadSocketHandler(CNode* pnode, const std::unordered_map& events) { // // Receive @@ -620,9 +620,13 @@ int TLSManager::threadSocketHandler(CNode* pnode, fd_set& fdsetRecv, fd_set& fds if (pnode->hSocket->Get() == INVALID_SOCKET) return -1; - recvSet = FD_ISSET(pnode->hSocket->Get(), &fdsetRecv); - sendSet = FD_ISSET(pnode->hSocket->Get(), &fdsetSend); - errorSet = FD_ISSET(pnode->hSocket->Get(), &fdsetError); + const auto& it = events.find(pnode->hSocket->Get()); + if (it == events.end()) { + return 0; + } + recvSet = it->second.occurred & Sock::RECV; + sendSet = it->second.occurred & Sock::SEND; + errorSet = it->second.occurred & Sock::ERR; } if (recvSet || errorSet) { diff --git a/src/zen/tlsmanager.h b/src/zen/tlsmanager.h index f40bb97a36..50638f20f4 100644 --- a/src/zen/tlsmanager.h +++ b/src/zen/tlsmanager.h @@ -47,7 +47,7 @@ namespace TLSManager SSL* accept(Sock& sock, const CAddress& addr, unsigned long& err_code); bool isNonTLSAddr(const std::string& strAddr, const std::vector& vPool, CCriticalSection& cs); void cleanNonTLSPool(std::vector& vPool, CCriticalSection& cs); - int threadSocketHandler(CNode* pnode, fd_set& fdsetRecv, fd_set& fdsetSend, fd_set& fdsetError); + int threadSocketHandler(CNode* pnode, const std::unordered_map& events); bool initialize(); }; }