From 7b93e6586ca5af61282687851c875d56a6a71c5f Mon Sep 17 00:00:00 2001 From: Pierre Wielders Date: Sun, 8 Dec 2024 00:26:42 +0100 Subject: [PATCH 1/2] [SSL] Integrate OpenSSL behind C++ wrappers for Key/Certificate and a Certificate Store. Client code: class WebClient : public Web::WebLinkType&> { ... } WebClient webConnector(Core::NodeId("catfact.ninja", 443)); webConnector.Link().Root(Crypto::CertificateStore::Default()); uint32_t result = webConnector.Open(10000); if (result != Core::ERROR_NONE) { printf("Could not open the connection, error: %d\n", result); } else { printf("Waiting for a number of seconds for the result!\n"); SleepMs(5000); printf("The answer should be in by now, lets close it!\n"); webConnector.Close(100) } Server/Client code: void OpeningSecuredServerPort() { Crypto::Certificate certificate(_T("D:/domotica/src/ca_certficates/onsite.crt")); Crypto::Key key(_T("D:/domotica/src/ca_certficates/onsite.key"), _T("")); Crypto::Certificate CA(_T("D:/domotica/src/ca_certficates/rootCA.pem")); Crypto::CertificateStore store; store.Add(CA); const Core::NodeId localNode(localHostName, tcpServerPort, Core::NodeId::TYPE_IPV4); // This is a listening socket as result of using SocketServerType which enables listening Crypto::SecureSocketServerType> server(certificate, key, localNode /* listening node*/); if (server.Open(maxWaitTimeMs) == Core::ERROR_NONE) { // Time to open a Client, see if I can get some data :-) WebSocketClient client(webSocketURIPath, webSocketProtocol, webSocketURIQuery, webSocketOrigin, false, true, rawSocket, localNode.AnyInterface(), localNode, sendBufferSize, receiveBufferSize, "WebSocketClient"); client.Link().Root(store); if (client.Open(3000) == Core::ERROR_NONE) { std::basic_string message(request, sizeof(request)); // Seems we have connections, now exchange a message client.Submit(message); // Sleep for some time so we can send and receive it :-) SleepMs(1000); WebSocketClient::Message response = client.Response(); printf("%s\n\n", response.c_str()); client.Close(Core::infinite); } } } --- Source/core/SocketServer.h | 45 ++- Source/cryptalgo/SecureSocketPort.cpp | 457 +++++++++++++++++++------- Source/cryptalgo/SecureSocketPort.h | 243 ++++++++++---- 3 files changed, 550 insertions(+), 195 deletions(-) diff --git a/Source/core/SocketServer.h b/Source/core/SocketServer.h index 0b439618c..2a2230e4a 100644 --- a/Source/core/SocketServer.h +++ b/Source/core/SocketServer.h @@ -30,7 +30,7 @@ namespace Core { template class SocketServerType { private: - typedef std::map> ClientMap; + using ClientMap = std::map>; public: template @@ -69,9 +69,7 @@ namespace Core { { move._atHead = true; } - ~IteratorType() - { - } + ~IteratorType() = default; IteratorType& operator=(const IteratorType& RHS) { @@ -130,38 +128,35 @@ namespace Core { typename std::list::iterator _iterator; }; - typedef IteratorType> Iterator; + using Iterator = IteratorType>; private: template class SocketHandler : public SocketListner { - private: + public: SocketHandler() = delete; + SocketHandler(SocketHandler&&) = delete; SocketHandler(const SocketHandler&) = delete; + SocketHandler& operator=(SocketHandler&&) = delete; SocketHandler& operator=(const SocketHandler&) = delete; - public: - SocketHandler(SocketServerType* parent) + SocketHandler(SocketServerType& parent) : SocketListner() , _nextClient(1) , _lock() , _clients() - , _parent(*parent) + , _parent(parent) { - - ASSERT(parent != nullptr); } - SocketHandler(const NodeId& listenNode, SocketServerType* parent) + SocketHandler(const NodeId& listenNode, SocketServerType& parent) : SocketListner(listenNode) , _nextClient(1) , _lock() , _clients() - , _parent(*parent) + , _parent(parent) { - - ASSERT(parent != nullptr); } - ~SocketHandler() + ~SocketHandler() override { SocketListner::Close(Core::infinite); CloseClients(0); @@ -267,7 +262,7 @@ namespace Core { // Do not change the Close() duration to a value >0. We should just test, but not wait for a statechange. // Waiting for a Statwchange might require, in the SocketPort imlementation of Close, WaitForCloseure with // parameter Core::infinite in case we have a faulthy socket. This call will than only return if the - // ResourceMonitor thread does report on CLosure of the socket. However, the ResourceMonitor thread might + // ResourceMonitor thread does report on Closure of the socket. However, the ResourceMonitor thread might // also be calling into here for an Accept. // In that case, the Accept will block on the _lock from this object as it is taken by this Cleanup call // running on a different thread but also this lock will not be freed as this cleanup thread is waiting @@ -353,23 +348,23 @@ namespace Core { SocketServerType& _parent; }; + public: + SocketServerType(SocketServerType&&) = delete; SocketServerType(const SocketServerType&) = delete; + SocketServerType& operator=(SocketServerType&&) = delete; SocketServerType& operator=(const SocketServerType&) = delete; - public: -PUSH_WARNING(DISABLE_WARNING_THIS_IN_MEMBER_INITIALIZER_LIST) + PUSH_WARNING(DISABLE_WARNING_THIS_IN_MEMBER_INITIALIZER_LIST) SocketServerType() - : _handler(this) + : _handler(*this) { } SocketServerType(const NodeId& listeningNode) - : _handler(listeningNode, this) - { - } -POP_WARNING() - ~SocketServerType() + : _handler(listeningNode, *this) { } + POP_WARNING() + ~SocketServerType() = default; public: inline uint32_t Open(const uint32_t waitTime) diff --git a/Source/cryptalgo/SecureSocketPort.cpp b/Source/cryptalgo/SecureSocketPort.cpp index f177b86c1..ac241e1b0 100644 --- a/Source/cryptalgo/SecureSocketPort.cpp +++ b/Source/cryptalgo/SecureSocketPort.cpp @@ -21,6 +21,22 @@ #include #include +#include +#include +#include +#include +#include +#include + +#ifdef __WINDOWS__ +#include +#include +#include +#include + +#pragma comment (lib, "crypt32.lib") +#pragma comment (lib, "cryptui.lib") +#endif #ifndef __WINDOWS__ namespace { @@ -51,38 +67,74 @@ namespace Thunder { namespace Crypto { - static Core::Time ASN1_ToTime(const ASN1_TIME* input) - { - Core::Time result; +static Core::Time ASN1_ToTime(const ASN1_TIME* input) +{ + Core::Time result; - if (input != nullptr) { - uint16_t year = 0; - const char* textVersion = reinterpret_cast(input->data); + if (input != nullptr) { + uint16_t year = 0; + const char* textVersion = reinterpret_cast(input->data); - if (input->type == V_ASN1_UTCTIME) - { - year = (textVersion[0] - '0') * 10 + (textVersion[1] - '0'); - year += (year < 70 ? 2000 : 1900); - textVersion = &textVersion[2]; - } - else if (input->type == V_ASN1_GENERALIZEDTIME) - { - year = (textVersion[0] - '0') * 1000 + (textVersion[1] - '0') * 100 + (textVersion[2] - '0') * 10 + (textVersion[3] - '0'); - textVersion = &textVersion[4]; - } - uint8_t month = ((textVersion[0] - '0') * 10 + (textVersion[1] - '0')) - 1; - uint8_t day = (textVersion[2] - '0') * 10 + (textVersion[3] - '0'); - uint8_t hour = (textVersion[4] - '0') * 10 + (textVersion[5] - '0'); - uint8_t minutes = (textVersion[6] - '0') * 10 + (textVersion[7] - '0'); - uint8_t seconds = (textVersion[8] - '0') * 10 + (textVersion[9] - '0'); - - /* Note: we did not adjust the time based on time zone information */ - result = Core::Time(year, month, day, hour, minutes, seconds, 0, false); + if (input->type == V_ASN1_UTCTIME) + { + year = (textVersion[0] - '0') * 10 + (textVersion[1] - '0'); + year += (year < 70 ? 2000 : 1900); + textVersion = &textVersion[2]; + } + else if (input->type == V_ASN1_GENERALIZEDTIME) + { + year = (textVersion[0] - '0') * 1000 + (textVersion[1] - '0') * 100 + (textVersion[2] - '0') * 10 + (textVersion[3] - '0'); + textVersion = &textVersion[4]; } - return (result); + uint8_t month = ((textVersion[0] - '0') * 10 + (textVersion[1] - '0')) - 1; + uint8_t day = (textVersion[2] - '0') * 10 + (textVersion[3] - '0'); + uint8_t hour = (textVersion[4] - '0') * 10 + (textVersion[5] - '0'); + uint8_t minutes = (textVersion[6] - '0') * 10 + (textVersion[7] - '0'); + uint8_t seconds = (textVersion[8] - '0') * 10 + (textVersion[9] - '0'); + + /* Note: we did not adjust the time based on time zone information */ + result = Core::Time(year, month, day, hour, minutes, seconds, 0, false); } + return (result); +} + +// ----------------------------------------------------------------------------- +// class Certificate +// ----------------------------------------------------------------------------- +Certificate::Certificate(const x509_st* certificate) + : _certificate(certificate) { + if (certificate != nullptr) { + X509_up_ref(const_cast(_certificate)); + } +} -string SecureSocketPort::Certificate::Issuer() const { +Certificate::Certificate(const TCHAR fileName[]) { + X509* cert = X509_new(); + BIO* bio_cert = BIO_new_file(fileName, "rb"); + PEM_read_bio_X509(bio_cert, &cert, NULL, NULL); + _certificate = cert; +} + +Certificate::Certificate(Certificate&& certificate) noexcept + : _certificate(certificate._certificate) { + certificate._certificate = nullptr; +} + +Certificate::Certificate(const Certificate& certificate) + : _certificate(certificate._certificate) { + if (_certificate != nullptr) { + X509_up_ref(const_cast(_certificate)); + } +} + +Certificate::~Certificate() +{ + if (_certificate != nullptr) { + X509_free(const_cast(_certificate)); + } +} + +string Certificate::Issuer() const { char buffer[1024]; buffer[0] = '\0'; X509_NAME_oneline(X509_get_issuer_name(_certificate), buffer, sizeof(buffer)); @@ -90,7 +142,7 @@ string SecureSocketPort::Certificate::Issuer() const { return (string(buffer)); } -string SecureSocketPort::Certificate::Subject() const { +string Certificate::Subject() const { char buffer[1024]; buffer[0] = '\0'; X509_NAME_oneline(X509_get_subject_name(_certificate), buffer, sizeof(buffer)); @@ -98,85 +150,166 @@ string SecureSocketPort::Certificate::Subject() const { return (string(buffer)); } -Core::Time SecureSocketPort::Certificate::ValidFrom() const { +Core::Time Certificate::ValidFrom() const { return(ASN1_ToTime(X509_get0_notBefore(_certificate))); } -Core::Time SecureSocketPort::Certificate::ValidTill() const { +Core::Time Certificate::ValidTill() const { return(ASN1_ToTime(X509_get0_notAfter(_certificate))); } -bool SecureSocketPort::Certificate::ValidHostname(const string& expectedHostname) const { - return (X509_check_host(_certificate, expectedHostname.data(), expectedHostname.size(), 0, nullptr) == 1); +bool Certificate::ValidHostname(const string& expectedHostname) const { + return (X509_check_host(const_cast(_certificate), expectedHostname.data(), expectedHostname.size(), 0, nullptr) == 1); } -bool SecureSocketPort::Certificate::Verify(string& errorMsg) const { - long error = SSL_get_verify_result(_context); - - if (error != X509_V_OK) { - errorMsg = X509_verify_cert_error_string(error); +// ----------------------------------------------------------------------------- +// class Key +// ----------------------------------------------------------------------------- +Key::Key(const evp_pkey_st* key) + : _key(key) { + if (_key != nullptr) { + EVP_PKEY_up_ref(const_cast(_key)); } +} - return error == X509_V_OK; +Key::Key(Key&& key) noexcept + : _key(key._key) { + key._key = nullptr; } +Key::Key(const Key& key) + : _key(key._key) { + if (_key != nullptr) { + EVP_PKEY_up_ref(const_cast(_key)); + } +} -SecureSocketPort::Handler::~Handler() { - ASSERT(IsClosed() == true); - Close(0); +Key::Key(const string& fileName) + : _key(nullptr) { + + BIO* bio_key = BIO_new_file(fileName.c_str(), "rb"); + + if (bio_key != nullptr) { + _key = PEM_read_bio_PUBKEY(bio_key, NULL, NULL, NULL); + + BIO_free(bio_key); + } } -uint32_t SecureSocketPort::Handler::Initialize() { - uint32_t success = Core::ERROR_NONE; +static int passwd_callback(char* buffer, int size, int /* flags */, void* password) +{ + int copied = std::min(static_cast(strlen(static_cast(password))), size); + memcpy(buffer, password, copied); + return copied; +} - if (IsOpen() == true) { - _context = SSL_CTX_new(TLS_server_method()); - _handShaking = ACCEPTING; +Key::Key(const string& fileName, const string& password) + : _key(nullptr) { + BIO* bio_key = BIO_new_file(fileName.c_str(), "rb"); + FILE* file = ::fopen(fileName.c_str(), "rt"); + + if (bio_key != nullptr) { + _key = PEM_read_bio_PrivateKey(bio_key, NULL, passwd_callback, const_cast(static_cast(password.c_str()))); + BIO_free(bio_key); } - else { - _context = SSL_CTX_new(TLS_method()); - _handShaking = CONNECTING; +} + +Key::~Key() +{ + if (_key != nullptr) { + EVP_PKEY_free(const_cast(_key)); } +} + +// ----------------------------------------------------------------------------- +// class CertificateStore +// ----------------------------------------------------------------------------- + +#ifdef __WINDOWS__ +static struct x509_store_st* CreateDefaultStore() +{ + HCERTSTORE hStore; + PCCERT_CONTEXT pContext = nullptr; + X509_STORE* store = X509_STORE_new(); - _ssl = SSL_new(static_cast(_context)); - SSL_set_fd(static_cast(_ssl), static_cast(*this).Descriptor()); - SSL_CTX_set_options(static_cast(_context), SSL_OP_ALL | SSL_OP_NO_SSLv2); + hStore = CertOpenSystemStore(NULL, _T("ROOT")); - // Trust the same certificates as any other application - if (SSL_CTX_set_default_verify_paths(static_cast(_context)) == 1) { - success = Core::SocketPort::Initialize(); + if (hStore != nullptr) { + while (pContext = CertEnumCertificatesInStore(hStore, pContext)) { + X509* x509 = d2i_X509(nullptr, (const unsigned char**)&pContext->pbCertEncoded, pContext->cbCertEncoded); - if (success == Core::ERROR_NONE) { - SSL_set_tlsext_host_name(static_cast(_ssl), RemoteNode().HostName().c_str()); + if (x509 != nullptr) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + } } + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); } - else { - TRACE_L1("OpenSSL failed to load certificate store"); - success = Core::ERROR_GENERAL; + + return (store); +} +#else +static struct x509_store_st* CreateDefaultStore() +{ + X509_STORE* store = X509_STORE_new(); + + const char* dir = getenv(X509_get_default_cert_dir_env()); + + if (dir == nullptr) { + dir = X509_get_default_cert_dir(); } - return success; + X509_STORE_load_path(store, dir); + + return (store); } +#endif -int32_t SecureSocketPort::Handler::Read(uint8_t buffer[], const uint16_t length) const { - if (_handShaking != CONNECTED) { - const_cast(*this).Update(); +/* static */ struct x509_store_st* CertificateStore::_default = CreateDefaultStore(); + +CertificateStore::CertificateStore() + : _store(X509_STORE_new()) { +} + +CertificateStore::CertificateStore(CertificateStore&& move) noexcept + : _store(move._store) { + move._store = nullptr; +} + +CertificateStore::CertificateStore(const CertificateStore& copy) + : _store(copy._store) { + if (_store != nullptr) { + X509_STORE_up_ref(const_cast(_store)); } - return (SSL_read(static_cast(_ssl), buffer, length)); } -int32_t SecureSocketPort::Handler::Write(const uint8_t buffer[], const uint16_t length) { - return (SSL_write(static_cast(_ssl), buffer, length)); +CertificateStore::CertificateStore(struct x509_store_st* store) + : _store(store) { + if (_store != nullptr) { + X509_STORE_up_ref(const_cast(_store)); + } } +CertificateStore::~CertificateStore() { + if (_store != nullptr) { + X509_STORE_free(const_cast(_store)); + } +} -uint32_t SecureSocketPort::Handler::Open(const uint32_t waitTime) { - return (Core::SocketPort::Open(waitTime)); +void CertificateStore::Add(const Certificate& certificate) { + const struct x509_st* cert = certificate; + X509_STORE_add_cert(_store, const_cast(cert)); } -uint32_t SecureSocketPort::Handler::Close(const uint32_t waitTime) { +// ----------------------------------------------------------------------------- +// class SecureSocketPort::Handler +// ----------------------------------------------------------------------------- +SecureSocketPort::Handler::~Handler() { + ASSERT(IsClosed() == true); + Close(0); + if (_ssl != nullptr) { - SSL_shutdown(static_cast(_ssl)); SSL_free(static_cast(_ssl)); _ssl = nullptr; } @@ -184,43 +317,146 @@ uint32_t SecureSocketPort::Handler::Close(const uint32_t waitTime) { SSL_CTX_free(static_cast(_context)); _context = nullptr; } +} + +void SecureSocketPort::Handler::CreateContext(const bool server) { + _context = SSL_CTX_new(server ? TLS_server_method() : TLS_method()); + if (_context != nullptr) { + _ssl = SSL_new(_context); + + if (_ssl == nullptr) { + SSL_CTX_free(_context); + _context = nullptr; + } + else { + constexpr unsigned long options = SSL_OP_ALL | SSL_OP_NO_SSLv2; + + VARIABLE_IS_NOT_USED unsigned long bitmask = SSL_CTX_set_options(_context, options); + + ASSERT((bitmask & options) == options); + + if (server == true) { + SSL_set_accept_state(_ssl); + } + else { + SSL_set_connect_state(_ssl); + } + } + } +} + +uint32_t SecureSocketPort::Handler::Initialize() { + bool initialized = false; + + ASSERT(_context != nullptr); + ASSERT(_ssl != nullptr); + + if (SSL_set_fd(static_cast(_ssl), static_cast(*this).Descriptor()) == 1) { + SSL_set_tlsext_host_name(_ssl, RemoteNode().HostName().c_str()); + initialized = Core::SocketPort::Initialize(); + } + + return (initialized); +} + +int32_t SecureSocketPort::Handler::Read(uint8_t buffer[], const uint16_t length) const { + + ASSERT(_handShaking != ERROR); + + if (_handShaking != OPEN) { + const_cast(*this).Update(); + } + + return (SSL_read(static_cast(_ssl), buffer, length)); +} + +int32_t SecureSocketPort::Handler::Write(const uint8_t buffer[], const uint16_t length) { + + ASSERT(_handShaking != ERROR); + + if (_handShaking != OPEN) { + Update(); + } + + return (SSL_write(_ssl, buffer, length)); +} + +uint32_t SecureSocketPort::Handler::Open(const uint32_t waitTime) { + return (Core::SocketPort::Open(waitTime)); +} + +uint32_t SecureSocketPort::Handler::Close(const uint32_t waitTime) { + ASSERT(_ssl != nullptr); + SSL_shutdown(static_cast(_ssl)); return(Core::SocketPort::Close(waitTime)); } +uint32_t SecureSocketPort::Handler::Certificate(const Crypto::Certificate& certificate, const Crypto::Key& key) { + // Load server certificate and private key + const struct x509_st* cert = certificate; + const struct evp_pkey_st* base_key = key; + uint32_t result = Core::ERROR_BAD_REQUEST; + + if (SSL_CTX_use_certificate(_context, const_cast(cert)) == 1) { + result = Core::ERROR_UNKNOWN_KEY; + if (SSL_CTX_use_PrivateKey(_context, const_cast(base_key)) == 1) { + result = Core::ERROR_NONE; + } + } + + return (result); +} + +uint32_t SecureSocketPort::Handler::Root(const CertificateStore& certStore) { + const struct x509_store_st* store = certStore; + + SSL_CTX_set_cert_store(_context, const_cast(store)); + + return (Core::ERROR_NONE); +} + void SecureSocketPort::Handler::ValidateHandShake() { - // Step 1: verify a server certificate was presented during the negotiation - X509* x509cert = SSL_get_peer_certificate(static_cast(_ssl)); - if (x509cert != nullptr) { - Core::SocketPort::Lock(); + // Step 1: verify a certificate was presented during the negotiation + X509* x509cert = SSL_get_peer_certificate(_ssl); - Certificate certificate(x509cert, static_cast(_ssl)); + if (x509cert == nullptr) { + _handShaking = ERROR; + SetError(); + _parent.StateChange(); + } + else { + long error; + string validationError; + Crypto::Certificate certificate(x509cert); // Step 2: Validate certificate - use custom IValidator instance if available or if self signed // certificates are needed :-) - string validationError; - if (_callback && _callback->Validate(certificate) == true) { - _handShaking = CONNECTED; - Core::SocketPort::Unlock(); - _parent.StateChange(); - } else if (certificate.Verify(validationError) && certificate.ValidHostname(RemoteNode().HostName())) { - _handShaking = CONNECTED; - Core::SocketPort::Unlock();\ - _parent.StateChange(); - } else { - if (!validationError.empty()) { - TRACE_L1("OpenSSL certificate validation error for %s: %s", certificate.Subject().c_str(), validationError.c_str()); + if (_callback != nullptr) { + if (_callback->Validate(certificate) == false) { + _handShaking = ERROR; + SetError(); + _parent.StateChange(); } + else { + _handShaking = OPEN; + _parent.StateChange(); + } + } + // SSL handshake does an implicit verification, its result is: + else if ((error = SSL_get_verify_result(_ssl)) != X509_V_OK) { + // string errorMsg = X509_verify_cert_error_string(error); _handShaking = ERROR; - Core::SocketPort::Unlock(); SetError(); + _parent.StateChange(); + } + else { + _handShaking = OPEN; + _parent.StateChange(); } X509_free(x509cert); - } else { - _handShaking = ERROR; - SetError(); - } + } } void SecureSocketPort::Handler::Update() { @@ -230,36 +466,29 @@ void SecureSocketPort::Handler::Update() { ASSERT(_ssl != nullptr); - if (_handShaking == CONNECTING) { - if ((result = SSL_connect(static_cast(_ssl))) == 1) { - _handShaking = EXCHANGE; - } - } - else if (_handShaking == ACCEPTING) { - if ((result = SSL_accept(static_cast(_ssl))) == 1) { - _handShaking = EXCHANGE; - } - } - if (_handShaking == EXCHANGE) { - if ((result = SSL_do_handshake(static_cast(_ssl))) == 1) { + if ((result = SSL_do_handshake(_ssl)) == 1) { ValidateHandShake(); } - } - - if (result != 1) { - result = SSL_get_error(static_cast(_ssl), result); - if ((result != SSL_ERROR_WANT_READ) && (result != SSL_ERROR_WANT_WRITE)) { - _handShaking = ERROR; - } - else if (result == SSL_ERROR_WANT_WRITE) { - Trigger(); + else { + result = SSL_get_error(_ssl, result); + + if (result == SSL_ERROR_WANT_WRITE) { + Trigger(); + } + else if (result != SSL_ERROR_WANT_READ) { + _handShaking = ERROR; + } } } } else { + _handShaking = EXCHANGE; _parent.StateChange(); } } +SecureSocketPort::~SecureSocketPort() { +} + } } // namespace Thunder::Crypto diff --git a/Source/cryptalgo/SecureSocketPort.h b/Source/cryptalgo/SecureSocketPort.h index 5fd037b40..1482219bf 100644 --- a/Source/cryptalgo/SecureSocketPort.h +++ b/Source/cryptalgo/SecureSocketPort.h @@ -21,69 +21,163 @@ #include "Module.h" -struct x509_store_ctx_st; -struct x509_st; struct ssl_st; +struct ssl_ctx_st; +struct x509_st; +struct evp_pkey_st; +struct x509_store_st; namespace Thunder { namespace Crypto { - class EXTERNAL SecureSocketPort : public Core::IResource { + class EXTERNAL Certificate { public: - class EXTERNAL Certificate { - public: - Certificate() = delete; - Certificate(Certificate&&) = delete; - Certificate(const Certificate&) = delete; + Certificate() = delete; + Certificate& operator=(Certificate&&) = delete; + Certificate& operator=(const Certificate&) = delete; - Certificate(x509_st* certificate, const ssl_st* context) - : _certificate(certificate) - , _context(context) { - } - ~Certificate() = default; + Certificate(const x509_st* certificate); + Certificate(const TCHAR fileName[]); + Certificate(Certificate&& move) noexcept; + Certificate(const Certificate& copy); + ~Certificate(); - public: - string Issuer() const; - string Subject() const; - Core::Time ValidFrom() const; - Core::Time ValidTill() const; - bool ValidHostname(const string& expectedHostname) const; - bool Verify(string& errorMsg) const; + public: + string Issuer() const; + string Subject() const; + Core::Time ValidFrom() const; + Core::Time ValidTill() const; + bool ValidHostname(const string& expectedHostname) const; - private: - x509_st* _certificate; - const ssl_st* _context; - }; - struct IValidator { - virtual ~IValidator() = default; + inline operator const struct x509_st* () const { + return (_certificate); + } + + private: + const x509_st* _certificate; + }; + class EXTERNAL Key { + public: + Key() = delete; + Key& operator=(Key&&) = delete; + Key& operator=(const Key&) = delete; + + Key(const evp_pkey_st* key); + Key(const string& fileName); + Key(const string& fileName, const string& password); + Key(Key&& move) noexcept; + Key(const Key& copy); + ~Key(); + + public: + inline operator const evp_pkey_st* () const { + return (_key); + } + + private: + const evp_pkey_st* _key; + }; + class EXTERNAL CertificateStore { + public: + CertificateStore& operator=(CertificateStore&&) = delete; + CertificateStore& operator=(const CertificateStore&) = delete; + + CertificateStore(); + CertificateStore(CertificateStore&&) noexcept; + CertificateStore(const CertificateStore&); + CertificateStore(struct x509_store_st*); + ~CertificateStore(); + + public: + static CertificateStore& Default() { + static CertificateStore defaultStore(_default); + return (defaultStore); + } + void Add(const Certificate& cert); + inline operator const x509_store_st* () const { + return (_store); + } + + + private: + struct x509_store_st* _store; + static struct x509_store_st* _default; + }; + class EXTERNAL SecureSocketPort : public Core::IResource { + public: + struct EXTERNAL IValidate { + virtual ~IValidate() = default; - virtual bool Validate(const Certificate& certificate) const = 0; + // Client part, override custom validation + virtual bool Validate(const Certificate&) const = 0; }; private: class EXTERNAL Handler : public Core::SocketPort { private: enum state : uint8_t { - ACCEPTING, - CONNECTING, EXCHANGE, - CONNECTED, + OPEN, ERROR }; public: Handler(Handler&&) = delete; Handler(const Handler&) = delete; + Handler& operator=(Handler&&) = delete; Handler& operator=(const Handler&) = delete; - template - Handler(SecureSocketPort& parent, Args&&... args) - : Core::SocketPort(args...) + Handler(SecureSocketPort& parent, + const enumType socketType, + const Core::NodeId& localNode, + const Core::NodeId& remoteNode, + const uint16_t sendBufferSize, + const uint16_t receiveBufferSize) + : SocketPort(socketType, localNode, remoteNode, sendBufferSize, receiveBufferSize) + , _parent(parent) + , _callback(nullptr) + , _handShaking(EXCHANGE) { + CreateContext(false); + } + Handler(SecureSocketPort& parent, + const enumType socketType, + const Core::NodeId& localNode, + const Core::NodeId& remoteNode, + const uint16_t sendBufferSize, + const uint16_t receiveBufferSize, + const uint32_t socketSendBufferSize, + const uint32_t socketReceiveBufferSize) + : SocketPort(socketType, localNode, remoteNode, sendBufferSize, receiveBufferSize, socketSendBufferSize, socketReceiveBufferSize) , _parent(parent) - , _context(nullptr) - , _ssl(nullptr) , _callback(nullptr) - , _handShaking(CONNECTING) { + , _handShaking(EXCHANGE) { + CreateContext(false); + } + Handler(SecureSocketPort& parent, + const enumType socketType, + const SOCKET& connector, + const Core::NodeId& remoteNode, + const uint16_t sendBufferSize, + const uint16_t receiveBufferSize) + : SocketPort(socketType, connector, remoteNode, sendBufferSize, receiveBufferSize) + , _parent(parent) + , _callback(nullptr) + , _handShaking(EXCHANGE) { + CreateContext(true); + } + Handler(SecureSocketPort& parent, + const enumType socketType, + const SOCKET& connector, + const Core::NodeId& remoteNode, + const uint16_t sendBufferSize, + const uint16_t receiveBufferSize, + const uint32_t socketSendBufferSize, + const uint32_t socketReceiveBufferSize) + : SocketPort(socketType, connector, remoteNode, sendBufferSize, receiveBufferSize, socketSendBufferSize, socketReceiveBufferSize) + , _parent(parent) + , _callback(nullptr) + , _handShaking(EXCHANGE) { + CreateContext(true); } ~Handler(); @@ -97,57 +191,53 @@ namespace Crypto { uint32_t Close(const uint32_t waitTime); // Methods to extract and insert data into the socket buffers - uint16_t SendData(uint8_t* dataFrame, const uint16_t maxSendSize) override { + inline uint16_t SendData(uint8_t* dataFrame, const uint16_t maxSendSize) override { return (_parent.SendData(dataFrame, maxSendSize)); } - uint16_t ReceiveData(uint8_t* dataFrame, const uint16_t receivedSize) override { + inline uint16_t ReceiveData(uint8_t* dataFrame, const uint16_t receivedSize) override { return (_parent.ReceiveData(dataFrame, receivedSize)); } // Signal a state change, Opened, Closed or Accepted - void StateChange() override { + inline void StateChange() override { Update(); }; - inline uint32_t Callback(IValidator* callback) { - uint32_t result = Core::ERROR_ILLEGAL_STATE; - + inline void Validate(const IValidate* callback) { Core::SocketPort::Lock(); - ASSERT((callback == nullptr) || (_callback == nullptr)); + ASSERT((callback == nullptr) ^ (_callback == nullptr)); - if ((callback == nullptr) || (_callback == nullptr)) { - _callback = callback; - result = Core::ERROR_NONE; - } + _callback = callback; Core::SocketPort::Unlock(); - - return (result); } + uint32_t Certificate(const Crypto::Certificate& certificate, const Crypto::Key& key); + uint32_t Root(const CertificateStore& store); private: void Update(); void ValidateHandShake(); + void CreateContext(const bool server); private: SecureSocketPort& _parent; - void* _context; - void* _ssl; - IValidator* _callback; + struct ssl_ctx_st* _context; + struct ssl_st* _ssl; + const IValidate* _callback; mutable state _handShaking; }; public: SecureSocketPort(SecureSocketPort&&) = delete; SecureSocketPort(const SecureSocketPort&) = delete; + SecureSocketPort& operator=(SecureSocketPort&&) = delete; SecureSocketPort& operator=(const SecureSocketPort&) = delete; template SecureSocketPort(Args&&... args) : _handler(*this, args...) { } - ~SecureSocketPort() override { - } + ~SecureSocketPort() override; public: inline bool IsOpen() const @@ -195,10 +285,17 @@ namespace Crypto { inline void Trigger() { _handler.Trigger(); } - inline uint32_t Callback(IValidator* callback) { - return (_handler.Callback(callback)); + inline void Validate(const IValidate* callback) { + _handler.Validate(callback); + } + inline uint32_t Certificate(const Crypto::Certificate& certificate, const Crypto::Key& key) { + return (_handler.Certificate(certificate, key)); + } + inline uint32_t Root(const CertificateStore& store) { + return (_handler.Root(store)); } + // // Core::IResource interface // ------------------------------------------------------------------------ @@ -225,5 +322,39 @@ namespace Crypto { private: Handler _handler; }; + + template + class SecureSocketServerType : public Core::SocketServerType { + public: + SecureSocketServerType() = delete; + SecureSocketServerType(SecureSocketServerType&&) = delete; + SecureSocketServerType(const SecureSocketServerType&) = delete; + SecureSocketServerType& operator=(SecureSocketServerType&&) = delete; + SecureSocketServerType& operator=(const SecureSocketServerType&) = delete; + + SecureSocketServerType(const Certificate& certificate, const Key& key) + : Core::SocketServerType() + , _certificate(certificate) + , _key(key) { + } + SecureSocketServerType(const Certificate& certificate, const Key& key, const Core::NodeId& serverNode) + : Core::SocketServerType(serverNode) + , _certificate(certificate) + , _key(key) { + } + ~SecureSocketServerType() = default; + + public: + const Crypto::Certificate& Certificate() const { + return (_certificate); + } + const Crypto::Key& Key() const { + return (_key); + } + + private: + Crypto::Certificate _certificate; + Crypto::Key _key; + }; } } From 6e92ffd95cab814aee1488ddf91792b0a5925b1d Mon Sep 17 00:00:00 2001 From: Pierre Wielders Date: Sun, 8 Dec 2024 11:56:58 +0100 Subject: [PATCH 2/2] [REFACTOR] The constructors are fixed, use them fixed. --- Source/core/SocketServer.h | 2 +- Source/cryptalgo/SecureSocketPort.cpp | 62 ++++++++++++++----- Source/cryptalgo/SecureSocketPort.h | 87 +++++++++++++-------------- 3 files changed, 90 insertions(+), 61 deletions(-) diff --git a/Source/core/SocketServer.h b/Source/core/SocketServer.h index 2a2230e4a..b76926190 100644 --- a/Source/core/SocketServer.h +++ b/Source/core/SocketServer.h @@ -281,7 +281,7 @@ namespace Core { _lock.Unlock(); } - virtual void Accept(SOCKET& newClient, const NodeId& remoteId) + void Accept(SOCKET& newClient, const NodeId& remoteId) override { ProxyType client = ProxyType::Create(newClient, remoteId, &_parent); diff --git a/Source/cryptalgo/SecureSocketPort.cpp b/Source/cryptalgo/SecureSocketPort.cpp index ac241e1b0..0b5379583 100644 --- a/Source/cryptalgo/SecureSocketPort.cpp +++ b/Source/cryptalgo/SecureSocketPort.cpp @@ -305,22 +305,52 @@ void CertificateStore::Add(const Certificate& certificate) { // ----------------------------------------------------------------------------- // class SecureSocketPort::Handler // ----------------------------------------------------------------------------- +SecureSocketPort::Handler::Handler(SecureSocketPort& parent, + const enumType socketType, + const Core::NodeId& localNode, + const Core::NodeId& remoteNode, + const uint16_t sendBufferSize, + const uint16_t receiveBufferSize, + const uint32_t socketSendBufferSize, + const uint32_t socketReceiveBufferSize) + : SocketPort(socketType, localNode, remoteNode, sendBufferSize, receiveBufferSize, socketSendBufferSize, socketReceiveBufferSize) + , _parent(parent) + , _callback(nullptr) + , _handShaking(EXCHANGE) { + CreateContext(TLS_method()); +} + +SecureSocketPort::Handler::Handler(SecureSocketPort& parent, + const enumType socketType, + const SOCKET& connector, + const Core::NodeId& remoteNode, + const uint16_t sendBufferSize, + const uint16_t receiveBufferSize, + const uint32_t socketSendBufferSize, + const uint32_t socketReceiveBufferSize) + : SocketPort(socketType, connector, remoteNode, sendBufferSize, receiveBufferSize, socketSendBufferSize, socketReceiveBufferSize) + , _parent(parent) + , _callback(nullptr) + , _handShaking(EXCHANGE) { + CreateContext(TLS_server_method()); +} + SecureSocketPort::Handler::~Handler() { ASSERT(IsClosed() == true); Close(0); if (_ssl != nullptr) { - SSL_free(static_cast(_ssl)); + SSL_free(_ssl); _ssl = nullptr; } if (_context != nullptr) { - SSL_CTX_free(static_cast(_context)); + SSL_CTX_free(_context); _context = nullptr; } } -void SecureSocketPort::Handler::CreateContext(const bool server) { - _context = SSL_CTX_new(server ? TLS_server_method() : TLS_method()); +void SecureSocketPort::Handler::CreateContext(const struct ssl_method_st* method) { + _context = SSL_CTX_new(method); if (_context != nullptr) { _ssl = SSL_new(_context); @@ -334,13 +364,6 @@ void SecureSocketPort::Handler::CreateContext(const bool server) { VARIABLE_IS_NOT_USED unsigned long bitmask = SSL_CTX_set_options(_context, options); ASSERT((bitmask & options) == options); - - if (server == true) { - SSL_set_accept_state(_ssl); - } - else { - SSL_set_connect_state(_ssl); - } } } } @@ -351,8 +374,15 @@ uint32_t SecureSocketPort::Handler::Initialize() { ASSERT(_context != nullptr); ASSERT(_ssl != nullptr); - if (SSL_set_fd(static_cast(_ssl), static_cast(*this).Descriptor()) == 1) { + if (SSL_set_fd(_ssl, static_cast(*this).Descriptor()) == 1) { SSL_set_tlsext_host_name(_ssl, RemoteNode().HostName().c_str()); + if (IsOpen() == true) { + SSL_set_accept_state(_ssl); + } + else { + SSL_set_connect_state(_ssl); + } + initialized = Core::SocketPort::Initialize(); } @@ -367,18 +397,20 @@ int32_t SecureSocketPort::Handler::Read(uint8_t buffer[], const uint16_t length) const_cast(*this).Update(); } - return (SSL_read(static_cast(_ssl), buffer, length)); + return (SSL_read(_ssl, buffer, length)); } int32_t SecureSocketPort::Handler::Write(const uint8_t buffer[], const uint16_t length) { ASSERT(_handShaking != ERROR); + uint32_t result = SSL_write(_ssl, buffer, length); + if (_handShaking != OPEN) { Update(); } - return (SSL_write(_ssl, buffer, length)); + return (result); } uint32_t SecureSocketPort::Handler::Open(const uint32_t waitTime) { @@ -387,7 +419,7 @@ uint32_t SecureSocketPort::Handler::Open(const uint32_t waitTime) { uint32_t SecureSocketPort::Handler::Close(const uint32_t waitTime) { ASSERT(_ssl != nullptr); - SSL_shutdown(static_cast(_ssl)); + SSL_shutdown(_ssl); return(Core::SocketPort::Close(waitTime)); } diff --git a/Source/cryptalgo/SecureSocketPort.h b/Source/cryptalgo/SecureSocketPort.h index 1482219bf..c1ff30fb8 100644 --- a/Source/cryptalgo/SecureSocketPort.h +++ b/Source/cryptalgo/SecureSocketPort.h @@ -23,6 +23,7 @@ struct ssl_st; struct ssl_ctx_st; +struct ssl_method_st; struct x509_st; struct evp_pkey_st; struct x509_store_st; @@ -89,9 +90,8 @@ namespace Crypto { ~CertificateStore(); public: - static CertificateStore& Default() { - static CertificateStore defaultStore(_default); - return (defaultStore); + static CertificateStore Default() { + return (CertificateStore(_default)); } void Add(const Certificate& cert); inline operator const x509_store_st* () const { @@ -128,43 +128,13 @@ namespace Crypto { Handler& operator=(const Handler&) = delete; Handler(SecureSocketPort& parent, - const enumType socketType, - const Core::NodeId& localNode, - const Core::NodeId& remoteNode, - const uint16_t sendBufferSize, - const uint16_t receiveBufferSize) - : SocketPort(socketType, localNode, remoteNode, sendBufferSize, receiveBufferSize) - , _parent(parent) - , _callback(nullptr) - , _handShaking(EXCHANGE) { - CreateContext(false); - } - Handler(SecureSocketPort& parent, const enumType socketType, const Core::NodeId& localNode, const Core::NodeId& remoteNode, const uint16_t sendBufferSize, const uint16_t receiveBufferSize, const uint32_t socketSendBufferSize, - const uint32_t socketReceiveBufferSize) - : SocketPort(socketType, localNode, remoteNode, sendBufferSize, receiveBufferSize, socketSendBufferSize, socketReceiveBufferSize) - , _parent(parent) - , _callback(nullptr) - , _handShaking(EXCHANGE) { - CreateContext(false); - } - Handler(SecureSocketPort& parent, - const enumType socketType, - const SOCKET& connector, - const Core::NodeId& remoteNode, - const uint16_t sendBufferSize, - const uint16_t receiveBufferSize) - : SocketPort(socketType, connector, remoteNode, sendBufferSize, receiveBufferSize) - , _parent(parent) - , _callback(nullptr) - , _handShaking(EXCHANGE) { - CreateContext(true); - } + const uint32_t socketReceiveBufferSize); Handler(SecureSocketPort& parent, const enumType socketType, const SOCKET& connector, @@ -172,13 +142,7 @@ namespace Crypto { const uint16_t sendBufferSize, const uint16_t receiveBufferSize, const uint32_t socketSendBufferSize, - const uint32_t socketReceiveBufferSize) - : SocketPort(socketType, connector, remoteNode, sendBufferSize, receiveBufferSize, socketSendBufferSize, socketReceiveBufferSize) - , _parent(parent) - , _callback(nullptr) - , _handShaking(EXCHANGE) { - CreateContext(true); - } + const uint32_t socketReceiveBufferSize); ~Handler(); public: @@ -217,7 +181,7 @@ namespace Crypto { private: void Update(); void ValidateHandShake(); - void CreateContext(const bool server); + void CreateContext(const struct ssl_method_st* method); private: SecureSocketPort& _parent; @@ -233,10 +197,43 @@ namespace Crypto { SecureSocketPort& operator=(SecureSocketPort&&) = delete; SecureSocketPort& operator=(const SecureSocketPort&) = delete; - template - SecureSocketPort(Args&&... args) - : _handler(*this, args...) { + SecureSocketPort( + const Core::SocketPort::enumType socketType, + const Core::NodeId& localNode, + const Core::NodeId& remoteNode, + const uint16_t sendBufferSize, + const uint16_t receiveBufferSize) + : _handler(*this, socketType, localNode, remoteNode, sendBufferSize, receiveBufferSize, sendBufferSize, receiveBufferSize) { } + SecureSocketPort( + const Core::SocketPort::enumType socketType, + const Core::NodeId& localNode, + const Core::NodeId& remoteNode, + const uint16_t sendBufferSize, + const uint16_t receiveBufferSize, + const uint32_t socketSendBufferSize, + const uint32_t socketReceiveBufferSize) + : _handler(*this, socketType, localNode, remoteNode, sendBufferSize, receiveBufferSize, socketSendBufferSize, socketReceiveBufferSize) { + } + SecureSocketPort( + const Core::SocketPort::enumType socketType, + const SOCKET& connector, + const Core::NodeId& remoteNode, + const uint16_t sendBufferSize, + const uint16_t receiveBufferSize) + : _handler(*this, socketType, connector, remoteNode, sendBufferSize, receiveBufferSize, sendBufferSize, receiveBufferSize) { + } + SecureSocketPort( + const Core::SocketPort::enumType socketType, + const SOCKET& connector, + const Core::NodeId& remoteNode, + const uint16_t sendBufferSize, + const uint16_t receiveBufferSize, + const uint32_t socketSendBufferSize, + const uint32_t socketReceiveBufferSize) + : _handler(*this, socketType, connector, remoteNode, sendBufferSize, receiveBufferSize, socketSendBufferSize, socketReceiveBufferSize) { + } + ~SecureSocketPort() override; public: