Skip to content

Commit

Permalink
Fix race issue with TcpConnectionService
Browse files Browse the repository at this point in the history
When connecting, a failure would not always generate an error and stop the TCP service.
Fix this by making the socket blocking again and using MSG_DONTWAIT only for the receive call.
  • Loading branch information
Oipo committed Oct 16, 2024
1 parent f02cafc commit e1f832d
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 82 deletions.
4 changes: 0 additions & 4 deletions include/ichor/services/network/tcp/TcpConnectionService.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ namespace Ichor {
* - "Socket" int - An existing socket to manage (required if Address/Port are not present)
* - "Priority" uint64_t - Which priority to use for inserted events (default INTERNAL_EVENT_PRIORITY)
* - "TimeoutSendUs" int64_t - Timeout in microseconds for send calls (default 250'000)
* - "TimeoutRecvUs" int64_t - Timeout in microseconds for recv calls (default 250'000)
*/
class TcpConnectionService final : public IConnectionService, public AdvancedService<TcpConnectionService> {
public:
Expand Down Expand Up @@ -47,13 +46,10 @@ namespace Ichor {

friend DependencyRegister;

static uint64_t tcpConnId;
int _socket;
uint64_t _id;
uint64_t _attempts;
uint64_t _priority;
int64_t _sendTimeout{250'000};
int64_t _recvTimeout{250'000};
bool _quit;
ILogger *_logger{};
ITimerFactory *_timerFactory{};
Expand Down
1 change: 0 additions & 1 deletion include/ichor/services/network/tcp/TcpHostService.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ namespace Ichor {
* - "Port" uint16_t - What port to bind to (required)
* - "Priority" uint64_t - Which priority to use for inserted events (default INTERNAL_EVENT_PRIORITY)
* - "TimeoutSendUs" int64_t - Timeout in microseconds for send calls (default 250'000)
* - "TimeoutRecvUs" int64_t - Timeout in microseconds for recv calls (default 250'000)
*/
class TcpHostService final : public IHostService, public AdvancedService<TcpHostService> {
public:
Expand Down
138 changes: 72 additions & 66 deletions src/services/network/tcp/TcpConnectionService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
#include <poll.h>
#include <thread>

uint64_t Ichor::TcpConnectionService::tcpConnId{};

Ichor::TcpConnectionService::TcpConnectionService(DependencyRegister &reg, Properties props) : AdvancedService(std::move(props)), _socket(-1), _id(tcpConnId++), _attempts(), _priority(INTERNAL_EVENT_PRIORITY), _quit() {
Ichor::TcpConnectionService::TcpConnectionService(DependencyRegister &reg, Properties props) : AdvancedService(std::move(props)), _socket(-1), _attempts(), _priority(INTERNAL_EVENT_PRIORITY), _quit() {
reg.registerDependency<ILogger>(this, DependencyFlags::NONE);
reg.registerDependency<ITimerFactory>(this, DependencyFlags::REQUIRED);
}
Expand All @@ -26,9 +24,6 @@ Ichor::Task<tl::expected<void, Ichor::StartError>> Ichor::TcpConnectionService::
if(auto propIt = getProperties().find("TimeoutSendUs"); propIt != getProperties().end()) {
_sendTimeout = Ichor::any_cast<int64_t>(propIt->second);
}
if(auto propIt = getProperties().find("TimeoutRecvUs"); propIt != getProperties().end()) {
_recvTimeout = Ichor::any_cast<int64_t>(propIt->second);
}

if(getProperties().contains("Socket")) {
if(auto propIt = getProperties().find("Socket"); propIt != getProperties().end()) {
Expand All @@ -39,26 +34,23 @@ Ichor::Task<tl::expected<void, Ichor::StartError>> Ichor::TcpConnectionService::
::setsockopt(_socket, IPPROTO_TCP, TCP_NODELAY, &setting, sizeof(setting));

timeval timeout{};
timeout.tv_usec = _recvTimeout;
setsockopt(_socket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout));
timeout.tv_usec = _sendTimeout;
setsockopt(_socket, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout));

auto flags = ::fcntl(_socket, F_GETFL, 0);
::fcntl(_socket, F_SETFL, flags | O_NONBLOCK);
ICHOR_LOG_TRACE(_logger, "[{}] Starting TCP connection for existing socket", _id);
ICHOR_LOG_DEBUG(_logger, "[{}] Starting TCP connection for existing socket", getServiceId());
} else {
auto addrIt = getProperties().find("Address");
auto portIt = getProperties().find("Port");

if(addrIt == getProperties().end()) {
ICHOR_LOG_ERROR(_logger, "[{}] Missing address", _id);
ICHOR_LOG_ERROR(_logger, "[{}] Missing address", getServiceId());
co_return tl::unexpected(StartError::FAILED);
}
if(portIt == getProperties().end()) {
ICHOR_LOG_ERROR(_logger, "[{}] Missing port", _id);
ICHOR_LOG_ERROR(_logger, "[{}] Missing port", getServiceId());
co_return tl::unexpected(StartError::FAILED);
}
ICHOR_LOG_TRACE(_logger, "[{}] connecting to {}:{}", getServiceId(), Ichor::any_cast<std::string&>(addrIt->second), Ichor::any_cast<uint16_t>(portIt->second));

// The start function possibly gets called multiple times due to trying to recover from not being able to connect
if(_socket == -1) {
Expand All @@ -72,14 +64,9 @@ Ichor::Task<tl::expected<void, Ichor::StartError>> Ichor::TcpConnectionService::
::setsockopt(_socket, IPPROTO_TCP, TCP_NODELAY, &setting, sizeof(setting));

timeval timeout{};
timeout.tv_usec = _recvTimeout;
setsockopt(_socket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout));
timeout.tv_usec = _sendTimeout;
setsockopt(_socket, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout));

auto flags = ::fcntl(_socket, F_GETFL, 0);
::fcntl(_socket, F_SETFL, flags | O_NONBLOCK);

sockaddr_in address{};
address.sin_family = AF_INET;
address.sin_port = htons(Ichor::any_cast<uint16_t>(portIt->second));
Expand All @@ -90,61 +77,69 @@ Ichor::Task<tl::expected<void, Ichor::StartError>> Ichor::TcpConnectionService::
throw std::runtime_error("inet_pton invalid address for given address family (has to be ipv4-valid address)");
}

bool connected{};
while(!connected && connect(_socket, (struct sockaddr *)&address, sizeof(address)) < 0) {
ICHOR_LOG_ERROR(_logger, "[{}] connect error {}", _id, errno);
bool connected = connect(_socket, (struct sockaddr *)&address, sizeof(address)) < 0;
while(!connected && _attempts < 5) {
connected = connect(_socket, (struct sockaddr *)&address, sizeof(address)) < 0;
if(connected) {
break;
}
ICHOR_LOG_TRACE(_logger, "[{}] connect error {}", getServiceId(), errno);
if(errno == EINPROGRESS) {
while(_attempts++ >= 5) {
pollfd pfd{};
pfd.fd = _socket;
pfd.events = POLLOUT;
ret = poll(&pfd, 1, static_cast<int>(_sendTimeout));
// This is left over from when the socket was marked as non-blocking; it is probably no longer necessary.
pollfd pfd{};
pfd.fd = _socket;
pfd.events = POLLOUT;
ret = poll(&pfd, 1, static_cast<int>(_sendTimeout/1'000));

if(ret < 0) {
ICHOR_LOG_ERROR(_logger, "[{}] poll error {}", getServiceId(), errno);
continue;
}

if(ret < 0) {
ICHOR_LOG_ERROR(_logger, "[{}] poll error {}", _id, errno);
continue;
}
// timeout
if(ret == 0) {
continue;
}

if(pfd.revents & POLLERR) {
ICHOR_LOG_ERROR(_logger, "[{}] POLLERR {}", getServiceId(), pfd.revents);
} else if(pfd.revents & POLLHUP) {
ICHOR_LOG_ERROR(_logger, "[{}] POLLHUP {}", getServiceId(), pfd.revents);
} else if(pfd.revents & POLLOUT) {
int connect_result{};
socklen_t result_len = sizeof(connect_result);
ret = getsockopt(_socket, SOL_SOCKET, SO_ERROR, &connect_result, &result_len);

// timeout
if(ret == 0) {
continue;
if(ret < 0) {
throw std::runtime_error("getsocketopt error: Couldn't connect");
}

if(pfd.revents & POLLERR) {
ICHOR_LOG_ERROR(_logger, "[{}] POLLERR {} {} {}", _id, pfd.revents);
} else if(pfd.revents & POLLHUP) {
ICHOR_LOG_ERROR(_logger, "[{}] POLLHUP {} {} {}", _id, pfd.revents);
} else if(pfd.revents & POLLOUT) {
int connect_result{};
socklen_t result_len = sizeof(connect_result);
ret = getsockopt(_socket, SOL_SOCKET, SO_ERROR, &connect_result, &result_len);

if(ret < 0) {
throw std::runtime_error("getsocketopt error: Couldn't connect");
}

// connect failed, retry
if(connect_result < 0) {
break;
}
connected = true;
// connect failed, retry
if(connect_result < 0) {
ICHOR_LOG_ERROR(_logger, "[{}] POLLOUT {} {}", getServiceId(), pfd.revents, connect_result);
break;
}
connected = true;
break;
}
} else if(errno == EISCONN) {
connected = true;
break;
} else if(errno == EALREADY) {
std::this_thread::sleep_for(std::chrono::microseconds(_sendTimeout));
} else {
_attempts++;
}

// we don't want to increment attempts in the EINPROGRESS case, but we do want to check it here
if(_attempts >= 5) {
throw std::runtime_error("Couldn't connect");
}
}

auto *ip = ::inet_ntoa(address.sin_addr);
ICHOR_LOG_TRACE(_logger, "[{}] Starting TCP connection for {}:{}", _id, ip, ::ntohs(address.sin_port));

if(!connected) {
ICHOR_LOG_ERROR(_logger, "[{}] Couldn't start TCP connection for {}:{}", getServiceId(), ip, ::ntohs(address.sin_port));
GetThreadLocalEventQueue().pushEvent<StopServiceEvent>(getServiceId(), getServiceId(), true);
co_return tl::unexpected(StartError::FAILED);
}
ICHOR_LOG_DEBUG(_logger, "[{}] Starting TCP connection for {}:{}", getServiceId(), ip, ::ntohs(address.sin_port));
}

_timer = &_timerFactory->createTimer();
Expand All @@ -160,6 +155,7 @@ Ichor::Task<tl::expected<void, Ichor::StartError>> Ichor::TcpConnectionService::

Ichor::Task<void> Ichor::TcpConnectionService::stop() {
_quit = true;
ICHOR_LOG_INFO(_logger, "[{}] stopping service", getServiceId());

if(_socket >= 0) {
::shutdown(_socket, SHUT_RDWR);
Expand Down Expand Up @@ -189,12 +185,13 @@ Ichor::Task<tl::expected<void, Ichor::IOError>> Ichor::TcpConnectionService::sen
size_t sent_bytes = 0;

if(_quit) {
ICHOR_LOG_TRACE(_logger, "[{}] quitting, no send", _id);
ICHOR_LOG_TRACE(_logger, "[{}] quitting, no send", getServiceId());
co_return tl::unexpected(IOError::SERVICE_QUITTING);
}

while(sent_bytes < msg.size()) {
auto ret = ::send(_socket, msg.data() + sent_bytes, msg.size() - sent_bytes, MSG_NOSIGNAL);
ICHOR_LOG_TRACE(_logger, "[{}] queued sending {} bytes, errno = {}", getServiceId(), ret, errno);

if(ret < 0) {
co_return tl::unexpected(IOError::FAILED);
Expand All @@ -208,7 +205,7 @@ Ichor::Task<tl::expected<void, Ichor::IOError>> Ichor::TcpConnectionService::sen

Ichor::Task<tl::expected<void, Ichor::IOError>> Ichor::TcpConnectionService::sendAsync(std::vector<std::vector<uint8_t>> &&msgs) {
if(_quit) {
ICHOR_LOG_TRACE(_logger, "[{}] quitting, no send", _id);
ICHOR_LOG_TRACE(_logger, "[{}] quitting, no send", getServiceId());
co_return tl::unexpected(IOError::SERVICE_QUITTING);
}

Expand All @@ -217,6 +214,7 @@ Ichor::Task<tl::expected<void, Ichor::IOError>> Ichor::TcpConnectionService::sen

while(sent_bytes < msg.size()) {
auto ret = ::send(_socket, msg.data() + sent_bytes, msg.size() - sent_bytes, 0);
ICHOR_LOG_TRACE(_logger, "[{}] queued sending {} bytes", getServiceId(), ret);

if(ret < 0) {
co_return tl::unexpected(IOError::FAILED);
Expand Down Expand Up @@ -253,26 +251,33 @@ void Ichor::TcpConnectionService::setReceiveHandler(std::function<void(std::span
void Ichor::TcpConnectionService::recvHandler() {
ScopeGuard sg{[this]() {
if(!_quit) {
_timer->startTimer();
if(!_timer->startTimer()) {
GetThreadLocalEventQueue().pushEvent<RunFunctionEvent>(getServiceId(), [this]() {
if(!_timer->startTimer()) {
std::terminate();
}
});
}
} else {
ICHOR_LOG_TRACE(_logger, "[{}] quitting, no push", _id);
ICHOR_LOG_TRACE(_logger, "[{}] quitting, no push", getServiceId());
}
}};
std::vector<uint8_t> msg{};
int64_t ret{};
ssize_t ret{};
{
std::array<uint8_t, 1024> buf;
std::array<uint8_t, 4096> buf;
do {
ret = recv(_socket, buf.data(), buf.size(), 0);
ret = recv(_socket, buf.data(), buf.size(), MSG_DONTWAIT);
if (ret > 0) {
auto data = std::span<uint8_t const>{reinterpret_cast<uint8_t const*>(buf.data()), static_cast<decltype(buf.size())>(ret)};
msg.insert(msg.end(), data.begin(), data.end());
}
} while (ret > 0 && !_quit);
}
ICHOR_LOG_TRACE(_logger, "[{}] last received {} bytes, msg size = {}, errno = {}", getServiceId(), ret, msg.size(), errno);

if (_quit) {
ICHOR_LOG_TRACE(_logger, "[{}] quitting", _id);
ICHOR_LOG_TRACE(_logger, "[{}] quitting", getServiceId());
return;
}

Expand All @@ -286,6 +291,7 @@ void Ichor::TcpConnectionService::recvHandler() {

if(ret == 0) {
// closed connection
ICHOR_LOG_INFO(_logger, "[{}] peer closed connection", getServiceId());
GetThreadLocalEventQueue().pushEvent<StopServiceEvent>(getServiceId(), getServiceId(), true);
return;
}
Expand All @@ -294,7 +300,7 @@ void Ichor::TcpConnectionService::recvHandler() {
if(errno == EAGAIN) {
return;
}
ICHOR_LOG_ERROR(_logger, "[{}] Error receiving from socket: {}", _id, errno);
ICHOR_LOG_ERROR(_logger, "[{}] Error receiving from socket: {}", getServiceId(), errno);
GetThreadLocalEventQueue().pushEvent<StopServiceEvent>(getServiceId(), getServiceId(), true);
return;
}
Expand Down
1 change: 0 additions & 1 deletion src/services/network/tcp/TcpHostService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ Ichor::AsyncGenerator<Ichor::IchorBehaviour> Ichor::TcpHostService::handleEvent(
props.emplace("Priority", Ichor::make_any<uint64_t>(_priority));
props.emplace("Socket", Ichor::make_any<int>(evt.socket));
props.emplace("TimeoutSendUs", Ichor::make_any<int64_t>(_sendTimeout));
props.emplace("TimeoutRecvUs", Ichor::make_any<int64_t>(_recvTimeout));
_connections.emplace_back(GetThreadLocalManager().template createServiceManager<TcpConnectionService, IConnectionService>(std::move(props))->getServiceId());

co_return {};
Expand Down
2 changes: 1 addition & 1 deletion src/services/timer/Timer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ bool Ichor::Timer::startTimer(bool fireImmediately) {
}
std::unique_lock l{_m};
INTERNAL_IO_DEBUG("timer {} for {} startTimer({}) {} {}", _timerId, _requestingServiceId, fireImmediately, _state, _quitCbs.size());
if(_state == TimerState::STOPPED) {
if(_state == TimerState::STOPPED || _state == TimerState::STOPPING) {
l.unlock();
if(_eventInsertionThread && _eventInsertionThread->joinable()) {
_eventInsertionThread->join();
Expand Down
16 changes: 8 additions & 8 deletions test/TcpTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ TEST_CASE("TcpTests") {
auto queue = std::make_unique<QIMPL>(500, true);
#endif
ServiceIdType tcpClientId;
evtGate = false;
evtGate = 0;

std::thread t([&]() {
#ifdef TEST_URING
Expand Down Expand Up @@ -145,16 +145,16 @@ TEST_CASE("TcpTests") {
_evt = std::make_unique<Ichor::AsyncManualResetEvent>();
auto queue = std::make_unique<QIMPL>(true);
ServiceIdType tcpClientId;
evtGate = false;
evtGate = 0;

std::thread t([&]() {
#ifdef TEST_URING
REQUIRE(queue->createEventLoop());
#endif
auto &dm = queue->createManager();
uint64_t priorityToEnsureHostStartingFirst = 51;
dm.createServiceManager<CoutFrameworkLogger, IFrameworkLogger>(Properties{{"DefaultLogLevel", Ichor::make_any<LogLevel>(LogLevel::LOG_TRACE)}}, priorityToEnsureHostStartingFirst);
dm.createServiceManager<LoggerFactory<CoutLogger>, ILoggerFactory>(Properties{{"DefaultLogLevel", Ichor::make_any<LogLevel>(LogLevel::LOG_TRACE)}}, priorityToEnsureHostStartingFirst);
dm.createServiceManager<CoutFrameworkLogger, IFrameworkLogger>(Properties{{"DefaultLogLevel", Ichor::make_any<LogLevel>(LogLevel::LOG_DEBUG)}}, priorityToEnsureHostStartingFirst);
dm.createServiceManager<LoggerFactory<CoutLogger>, ILoggerFactory>(Properties{{"DefaultLogLevel", Ichor::make_any<LogLevel>(LogLevel::LOG_DEBUG)}}, priorityToEnsureHostStartingFirst);
dm.createServiceManager<HOSTIMPL, IHostService>(Properties{{"Address", Ichor::make_any<std::string>("127.0.0.1"s)}, {"Port", Ichor::make_any<uint16_t>(static_cast<uint16_t>(8001))}, {"BufferEntries", Ichor::make_any<uint32_t>(static_cast<uint16_t>(16))}, {"BufferEntrySize", Ichor::make_any<uint32_t>(static_cast<uint16_t>(16'384))}}, priorityToEnsureHostStartingFirst);
dm.createServiceManager<ClientFactory<CONNIMPL>, IClientFactory>();
#ifndef TEST_URING
Expand Down Expand Up @@ -221,7 +221,7 @@ TEST_CASE("TcpTests") {
_evt = std::make_unique<Ichor::AsyncManualResetEvent>();
auto queue = std::make_unique<QIMPL>(true);
ServiceIdType tcpClientId;
evtGate = false;
evtGate = 0;

std::thread t([&]() {
#ifdef TEST_URING
Expand Down Expand Up @@ -308,7 +308,7 @@ TEST_CASE("TcpTests") {
_evt = std::make_unique<Ichor::AsyncManualResetEvent>();
auto queue = std::make_unique<QIMPL>(true);
ServiceIdType tcpClientId;
evtGate = false;
evtGate = 0;

std::thread t([&]() {
#ifdef TEST_URING
Expand Down Expand Up @@ -384,7 +384,7 @@ TEST_CASE("TcpTests") {
_evt = std::make_unique<Ichor::AsyncManualResetEvent>();
auto queue = std::make_unique<QIMPL>(true);
ServiceIdType tcpClientId;
evtGate = false;
evtGate = 0;

std::thread t([&]() {
#ifdef TEST_URING
Expand Down Expand Up @@ -467,7 +467,7 @@ TEST_CASE("TcpTests") {
_evt = std::make_unique<Ichor::AsyncManualResetEvent>();
auto queue = std::make_unique<QIMPL>(true);
ServiceIdType tcpClientId;
evtGate = false;
evtGate = 0;

std::thread t([&]() {
#ifdef TEST_URING
Expand Down
Loading

0 comments on commit e1f832d

Please sign in to comment.