diff --git a/Source/core/SocketPort.cpp b/Source/core/SocketPort.cpp index 5a487e84f..c8e953396 100644 --- a/Source/core/SocketPort.cpp +++ b/Source/core/SocketPort.cpp @@ -496,9 +496,15 @@ namespace Thunder { m_SendOffset = 0; if ((m_State.load(Core::memory_order::memory_order_relaxed) & (SocketPort::LINK | SocketPort::OPEN | SocketPort::MONITOR)) == (SocketPort::LINK | SocketPort::OPEN)) { - // Open up an accepted socket, but not yet added to the monitor. - m_State.fetch_or(SocketPort::UPDATE, Core::memory_order::memory_order_relaxed); - nStatus = Core::ERROR_NONE; + + if (Initialize() != Core::ERROR_NONE) { + nStatus = Core::ERROR_ABORTED; + } + else { + // Open up an accepted socket, but not yet added to the monitor. + m_State.fetch_or(SocketPort::UPDATE, Core::memory_order::memory_order_relaxed); + nStatus = Core::ERROR_INPROGRESS; + } } else { ASSERT((m_Socket == INVALID_SOCKET) && (m_State.load(Core::memory_order::memory_order_relaxed) == 0)); diff --git a/Source/cryptalgo/SecureSocketPort.cpp b/Source/cryptalgo/SecureSocketPort.cpp index 764b94398..073414867 100644 --- a/Source/cryptalgo/SecureSocketPort.cpp +++ b/Source/cryptalgo/SecureSocketPort.cpp @@ -122,27 +122,35 @@ bool SecureSocketPort::Certificate::Verify(string& errorMsg) const { SecureSocketPort::Handler::~Handler() { - if(_ssl != nullptr) { - SSL_free(static_cast(_ssl)); - } - if(_context != nullptr) { - SSL_CTX_free(static_cast(_context)); - } + ASSERT(IsClosed() == true); + Close(0); } uint32_t SecureSocketPort::Handler::Initialize() { uint32_t success = Core::ERROR_NONE; - _context = SSL_CTX_new(TLS_method()); + if (IsOpen() == true) { + _context = SSL_CTX_new(TLS_server_method()); + _handShaking = ACCEPTING; + } + else { + _context = SSL_CTX_new(TLS_method()); + _handShaking = CONNECTING; + } _ssl = SSL_new(static_cast(_context)); SSL_set_fd(static_cast(_ssl), static_cast(*this).Descriptor()); - SSL_CTX_set_options(static_cast(_context), SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); + SSL_CTX_set_options(static_cast(_context), SSL_OP_ALL | SSL_OP_NO_SSLv2); // Trust the same certificates as any other application if (SSL_CTX_set_default_verify_paths(static_cast(_context)) == 1) { success = Core::SocketPort::Initialize(); - } else { + + if (success == Core::ERROR_NONE) { + SSL_set_tlsext_host_name(static_cast(_ssl), RemoteNode().HostName().c_str()); + } + } + else { TRACE_L1("OpenSSL failed to load certificate store"); success = Core::ERROR_GENERAL; } @@ -151,12 +159,10 @@ uint32_t SecureSocketPort::Handler::Initialize() { } int32_t SecureSocketPort::Handler::Read(uint8_t buffer[], const uint16_t length) const { - int32_t result = SSL_read(static_cast(_ssl), buffer, length); - if (_handShaking != CONNECTED) { const_cast(*this).Update(); } - return (result); + return (SSL_read(static_cast(_ssl), buffer, length)); } int32_t SecureSocketPort::Handler::Write(const uint8_t buffer[], const uint16_t length) { @@ -171,8 +177,14 @@ uint32_t SecureSocketPort::Handler::Open(const uint32_t waitTime) { uint32_t SecureSocketPort::Handler::Close(const uint32_t waitTime) { if (_ssl != nullptr) { SSL_shutdown(static_cast(_ssl)); + SSL_free(static_cast(_ssl)); + _ssl = nullptr; + } + if (_context != nullptr) { + SSL_CTX_free(static_cast(_context)); + _context = nullptr; } - _handShaking = IDLE; + return(Core::SocketPort::Close(waitTime)); } @@ -199,43 +211,53 @@ void SecureSocketPort::Handler::ValidateHandShake() { if (!validationError.empty()) { TRACE_L1("OpenSSL certificate validation error for %s: %s", certificate.Subject().c_str(), validationError.c_str()); } - _handShaking = IDLE; + _handShaking = ERROR; Core::SocketPort::Unlock(); SetError(); } X509_free(x509cert); } else { - _handShaking = IDLE; + _handShaking = ERROR; SetError(); } } void SecureSocketPort::Handler::Update() { + if (IsOpen() == true) { int result; - if (_handShaking == IDLE) { - SSL_set_tlsext_host_name(static_cast(_ssl), RemoteNode().HostName().c_str()); - result = SSL_connect(static_cast(_ssl)); - if (result == 1) { - ValidateHandShake(); + ASSERT(_ssl != nullptr); + + if (_handShaking == CONNECTING) { + if ((result = SSL_connect(static_cast(_ssl))) == 1) { + _handShaking = EXCHANGE; } - else { - result = SSL_get_error(static_cast(_ssl), result); - if ((result == SSL_ERROR_WANT_READ) || (result == SSL_ERROR_WANT_WRITE)) { - _handShaking = EXCHANGE; - } + } + else if (_handShaking == ACCEPTING) { + if ((result = SSL_accept(static_cast(_ssl))) == 1) { + _handShaking = EXCHANGE; } } - else if (_handShaking == EXCHANGE) { - if (SSL_do_handshake(static_cast(_ssl)) == 1) { + + if (_handShaking == EXCHANGE) { + if ((result = SSL_do_handshake(static_cast(_ssl))) == 1) { ValidateHandShake(); } } + + if (result != 1) { + result = SSL_get_error(static_cast(_ssl), result); + if ((result != SSL_ERROR_WANT_READ) && (result != SSL_ERROR_WANT_WRITE)) { + _handShaking = ERROR; + } + else if (result == SSL_ERROR_WANT_WRITE) { + Trigger(); + } + } } - else if (_ssl != nullptr) { - _handShaking = IDLE; + else { _parent.StateChange(); } } diff --git a/Source/cryptalgo/SecureSocketPort.h b/Source/cryptalgo/SecureSocketPort.h index 280a44ae3..5fd037b40 100644 --- a/Source/cryptalgo/SecureSocketPort.h +++ b/Source/cryptalgo/SecureSocketPort.h @@ -64,9 +64,11 @@ namespace Crypto { class EXTERNAL Handler : public Core::SocketPort { private: enum state : uint8_t { - IDLE, + ACCEPTING, + CONNECTING, EXCHANGE, - CONNECTED + CONNECTED, + ERROR }; public: @@ -81,7 +83,7 @@ namespace Crypto { , _context(nullptr) , _ssl(nullptr) , _callback(nullptr) - , _handShaking(IDLE) { + , _handShaking(CONNECTING) { } ~Handler(); @@ -105,8 +107,6 @@ namespace Crypto { // Signal a state change, Opened, Closed or Accepted void StateChange() override { - - ASSERT(_context != nullptr); Update(); }; inline uint32_t Callback(IValidator* callback) {