Skip to content

Commit

Permalink
TgBot++: socket: Update Selector class usage, implement connection ti…
Browse files Browse the repository at this point in the history
…meouts
  • Loading branch information
Royna2544 committed Jul 2, 2024
1 parent 59e27cf commit 72f65ff
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 21 deletions.
1 change: 1 addition & 0 deletions cmake/tgbotsocket.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ if (UNIX)
${SOCKET_SRC_INTERFACE}/impl/local/SocketPosixLocal.cpp
${SOCKET_SRC_INTERFACE}/impl/inet/SocketPosixIPv4.cpp
${SOCKET_SRC_INTERFACE}/impl/inet/SocketPosixIPv6.cpp
${SOCKET_SRC_INTERFACE}/impl/helper/SocketHelperPosix.cpp
src/socket/selector/SelectorPosixPoll.cpp
src/socket/selector/SelectorPosixSelect.cpp
src/socket/selector/SelectorUnix.cpp
Expand Down
2 changes: 1 addition & 1 deletion src/command_modules/ibash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ struct InteractiveBashContext : BotClassBase {
kRun = false;
}
}
});
}, Selector::Mode::READ);
while (kRun) {
switch (selector.poll()) {
case Selector::SelectorPollResult::OK:
Expand Down
2 changes: 2 additions & 0 deletions src/socket/TgBotSocketClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ int main(int argc, char** argv) {
}

SocketClientWrapper backend(SocketInterfaceBase::LocalHelper::getSocketPath());
backend->options.use_connect_timeout.set(true);
backend->options.connect_timeout.set(3s);
auto handle = backend->createClientSocket();

if (handle) {
Expand Down
11 changes: 11 additions & 0 deletions src/socket/interface/SocketBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@

#include <SocketDescriptor_defs.hpp>
#include <TgBotSocket_Export.hpp>
#include <chrono>
#include <cstddef>
#include <filesystem>
#include <functional>
#include <optional>
#include <string>

using std::chrono_literals::operator""s;

struct SocketConnContext {
socket_handle_t cfd{}; // connection socket file descriptor
SharedMalloc addr; // struct sockaddr_*'s address
Expand Down Expand Up @@ -213,6 +216,9 @@ struct SocketInterfaceBase {
// Default constructor
Option() = default;

// Option with default value
explicit Option(T defaultValue) : data(defaultValue) {}

// Function to set the data
void set(T dataIn) { data = dataIn; }

Expand Down Expand Up @@ -242,6 +248,11 @@ struct SocketInterfaceBase {
Option<int> port;
// Option to specify whether to use UDP for socket operations
Option<bool> use_udp;
// Option to specify whether to use connection timeouts for client
Option<bool> use_connect_timeout;
// Option to specify the timeout for socket operations
// Used if use_connect_timeout is true
Option<std::chrono::seconds> connect_timeout{10s};
} options;

protected:
Expand Down
5 changes: 3 additions & 2 deletions src/socket/interface/impl/SocketPosix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <cstring>
#include <socket/selector/SelectorPosix.hpp>
#include "socket/selector/Selectors.hpp"

void SocketInterfaceUnix::startListening(socket_handle_t handle,
const listener_callback_t onNewData) {
Expand Down Expand Up @@ -37,7 +38,7 @@ void SocketInterfaceUnix::startListening(socket_handle_t handle,
PLOG(ERROR) << "Reading data from forcestop fd";
}
should_break = true;
});
}, Selector::Mode::READ);
selector.add(handle, [handle, this, &should_break, onNewData] {
struct sockaddr addr {};
socklen_t len = sizeof(addr);
Expand All @@ -51,7 +52,7 @@ void SocketInterfaceUnix::startListening(socket_handle_t handle,
should_break = onNewData(ctx);
closeSocketHandle(cfd);
}
});
}, Selector::Mode::READ);
while (!should_break) {
switch (selector.poll()) {
case Selector::PollResult::FAILED:
Expand Down
21 changes: 20 additions & 1 deletion src/socket/interface/impl/SocketPosix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <SocketBase.hpp>

#include "SharedMalloc.hpp"
#include "SocketDescriptor_defs.hpp"

struct SocketInterfaceUnix : SocketInterfaceBase {
bool isValidSocketHandle(socket_handle_t handle) override {
Expand All @@ -23,9 +24,27 @@ struct SocketInterfaceUnix : SocketInterfaceBase {
std::optional<SharedMalloc> readFromSocket(SocketConnContext context,
buffer_len_t length) override;

SocketInterfaceUnix() = default;
SocketInterfaceUnix() : posixHelper(this) {}
~SocketInterfaceUnix() override = default;

struct PosixHelper {
// SOCK_DGRAM or SOCK_STREAM?
int getSocketType();
// Connection timeout for clients are enabled?
bool connectionTimeoutEnabled();

// Handle connection timeout. Specific works before connect() in client
static void handleConnectTimeoutPre(socket_handle_t socket);
// Handle connection timeout. Specific works after connect() in client
bool handleConnectTimeoutPost(socket_handle_t socket);

explicit PosixHelper(SocketInterfaceUnix* _interface)
: interface(_interface) {}

private:
SocketInterfaceUnix* interface;
} posixHelper;

protected:
Pipe kListenTerminate{};
static void bindToInterface(const socket_handle_t sock,
Expand Down
3 changes: 2 additions & 1 deletion src/socket/interface/impl/SocketWindows.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "SocketDescriptor_defs.hpp"
#include "helper/HelperWindows.hpp"
#include "socket/selector/Selectors.hpp"

std::string SocketInterfaceWindows::WSALastErrorStr() {
char *s = nullptr;
Expand Down Expand Up @@ -51,7 +52,7 @@ void SocketInterfaceWindows::startListening(
should_break = onNewData(ctx);
closesocket(cfd);
}
});
}, Selector::Mode::READ);
while (!should_break && kRun) {
switch (selector.poll()) {
case Selector::SelectorPollResult::FAILED:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include <absl/log/log.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <sys/socket.h>

#include <array>
#include <cstddef>
#include <functional>
#include <socket/selector/SelectorPosix.hpp>
#include <string_view>

#include "SocketBase.hpp"
Expand Down Expand Up @@ -125,10 +127,19 @@ void forEachINetAddress(

constexpr std::string_view kLocalInterface = "lo";

inline int getSocketType(const SocketInterfaceBase* base) {
if (static_cast<bool>(base->options.use_udp) &&
base->options.use_udp.get()) {
return SOCK_DGRAM;
template <int flag, bool add>
void setSocketFlags(socket_handle_t handle) {
int flags = fcntl(handle, F_GETFL, 0);
if (flags == -1) {
PLOG(ERROR) << "fcntl(F_GETFL) failed";
return;
}
return SOCK_STREAM;
}
if constexpr (add) {
flags |= flag;
} else {
flags &= ~flag;
}
if (fcntl(handle, F_SETFL, flags) == -1) {
PLOG(ERROR) << "fcntl(F_SETFL) failed";
}
}
54 changes: 54 additions & 0 deletions src/socket/interface/impl/helper/SocketHelperPosix.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include <impl/SocketPosix.hpp>

#include "HelperPosix.hpp"

int SocketInterfaceUnix::PosixHelper::getSocketType() {
if (static_cast<bool>(interface->options.use_udp) &&
interface->options.use_udp.get()) {
return SOCK_DGRAM;
}
return SOCK_STREAM;
}

bool SocketInterfaceUnix::PosixHelper::connectionTimeoutEnabled() {
return static_cast<bool>(interface->options.use_connect_timeout) &&
interface->options.use_connect_timeout.get();
}

void SocketInterfaceUnix::PosixHelper::handleConnectTimeoutPre(
socket_handle_t socket) {
setSocketFlags<O_NONBLOCK, true>(socket);
}

bool SocketInterfaceUnix::PosixHelper::handleConnectTimeoutPost(
socket_handle_t socket) {
UnixSelector selector;

LOG(INFO) << "Connecting timeout mode enabled";
selector.add(socket, []() {}, Selector::Mode::WRITE);
selector.enableTimeout(true);
selector.setTimeout(interface->options.connect_timeout.get());
switch (selector.poll()) {
case Selector::PollResult::OK:
break;
case Selector::PollResult::FAILED:
case Selector::PollResult::TIMEOUT:
LOG(ERROR) << "Connecting timeout";
return false;
}
int error = 0;
socklen_t len = sizeof(error);
if (getsockopt(socket, SOL_SOCKET, SO_ERROR, &error, &len) == -1) {
PLOG(ERROR) << "getsockopt(SO_ERROR) failed";
return false;
}

if (error != 0) {
LOG(ERROR) << "Failed to connect: " << strerror(error);
return false;
}

LOG(INFO) << "Connected";
setSocketFlags<O_NONBLOCK, false>(socket);
return true;
}
18 changes: 14 additions & 4 deletions src/socket/interface/impl/inet/SocketPosixIPv4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include <impl/SocketPosix.hpp>

#include "HelperPosix.hpp"
#include "../helper/HelperPosix.hpp"
#include "SocketBase.hpp"

std::optional<socket_handle_t> SocketInterfaceUnixIPv4::createServerSocket() {
Expand All @@ -15,7 +15,7 @@ std::optional<socket_handle_t> SocketInterfaceUnixIPv4::createServerSocket() {
struct sockaddr_in name {};
auto* _name = reinterpret_cast<struct sockaddr*>(&name);

socket_handle_t sfd = socket(AF_INET, getSocketType(this), 0);
socket_handle_t sfd = socket(AF_INET, posixHelper.getSocketType(), 0);

if (!isValidSocketHandle(sfd)) {
PLOG(ERROR) << "Failed to create socket";
Expand Down Expand Up @@ -58,7 +58,7 @@ std::optional<SocketConnContext> SocketInterfaceUnixIPv4::createClientSocket() {
struct sockaddr_in name {};
const auto* _name = reinterpret_cast<struct sockaddr*>(&name);

ctx.cfd = socket(AF_INET, getSocketType(this), 0);
ctx.cfd = socket(AF_INET, posixHelper.getSocketType(), 0);
if (!isValidSocketHandle(ctx.cfd)) {
PLOG(ERROR) << "Failed to create socket";
return std::nullopt;
Expand All @@ -67,11 +67,21 @@ std::optional<SocketConnContext> SocketInterfaceUnixIPv4::createClientSocket() {
name.sin_family = AF_INET;
name.sin_port = htons(helper.inet.getPortNum());
inet_pton(AF_INET, options.address.get().c_str(), &name.sin_addr);
if (connect(ctx.cfd, _name, sizeof(name)) != 0) {

if (posixHelper.connectionTimeoutEnabled()) {
posixHelper.handleConnectTimeoutPre(ctx.cfd);
}

// Blocking sockets wont return EINPROGRESS anyway...
if (connect(ctx.cfd, _name, sizeof(name)) != 0 && errno != EINPROGRESS) {
PLOG(ERROR) << "Failed to connect to socket";
closeSocketHandle(ctx.cfd);
return std::nullopt;
}
if (posixHelper.connectionTimeoutEnabled() && !posixHelper.handleConnectTimeoutPost(ctx.cfd)) {
closeSocketHandle(ctx.cfd);
return std::nullopt;
}
ctx.addr.assignFrom(name);
return ctx;
}
Expand Down
15 changes: 12 additions & 3 deletions src/socket/interface/impl/inet/SocketPosixIPv6.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
#include <functional>
#include <impl/SocketPosix.hpp>

#include "HelperPosix.hpp"
#include "../helper/HelperPosix.hpp"

std::optional<socket_handle_t> SocketInterfaceUnixIPv6::createServerSocket() {
socket_handle_t ret = kInvalidFD;
bool iface_done = false;
struct sockaddr_in6 name {};
auto* _name = reinterpret_cast<struct sockaddr*>(&name);
socket_handle_t sfd = socket(AF_INET6, getSocketType(this), 0);
socket_handle_t sfd = socket(AF_INET6, posixHelper.getSocketType(), 0);

if (!isValidSocketHandle(sfd)) {
PLOG(ERROR) << "Failed to create socket";
Expand Down Expand Up @@ -58,7 +58,7 @@ std::optional<SocketConnContext> SocketInterfaceUnixIPv6::createClientSocket() {
struct sockaddr_in6 name {};
auto* _name = reinterpret_cast<struct sockaddr*>(&name);

ctx.cfd = socket(AF_INET6, getSocketType(this), 0);
ctx.cfd = socket(AF_INET6, posixHelper.getSocketType(), 0);
if (!isValidSocketHandle(ctx.cfd)) {
PLOG(ERROR) << "Failed to create socket";
return std::nullopt;
Expand All @@ -67,11 +67,20 @@ std::optional<SocketConnContext> SocketInterfaceUnixIPv6::createClientSocket() {
name.sin6_family = AF_INET6;
name.sin6_port = htons(helper.inet.getPortNum());
inet_pton(AF_INET6, options.address.get().c_str(), &name.sin6_addr);

if (posixHelper.connectionTimeoutEnabled()) {
posixHelper.handleConnectTimeoutPre(ctx.cfd);
}

if (connect(ctx.cfd, _name, sizeof(name)) != 0) {
PLOG(ERROR) << "Failed to connect to socket";
closeSocketHandle(ctx.cfd);
return std::nullopt;
}
if (posixHelper.connectionTimeoutEnabled() && !posixHelper.handleConnectTimeoutPost(ctx.cfd)) {
closeSocketHandle(ctx.cfd);
return std::nullopt;
}
ctx.addr.assignFrom(name);
return ctx;
}
Expand Down
6 changes: 3 additions & 3 deletions src/socket/interface/impl/local/SocketPosixLocal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
#include <optional>

#include "SocketBase.hpp"
#include "../inet/HelperPosix.hpp"
#include "../helper/HelperPosix.hpp"

bool SocketInterfaceUnixLocal::createLocalSocket(SocketConnContext *ctx) {
ctx->cfd = socket(AF_UNIX, getSocketType(this), 0);
ctx->cfd = socket(AF_UNIX, posixHelper.getSocketType(), 0);
if (ctx->cfd < 0) {
PLOG(ERROR) << "Failed to create socket";
return false;
Expand All @@ -26,7 +26,7 @@ std::optional<socket_handle_t> SocketInterfaceUnixLocal::createServerSocket() {
SocketConnContext ret = SocketConnContext::create<sockaddr_un>();
const auto *_name = reinterpret_cast<struct sockaddr *>(ret.addr.get());

LOG(INFO) << "Creating socket at " << LocalHelper::getSocketPath().string();
LOG(INFO) << "Creating socket at " << options.address.get();
if (!createLocalSocket(&ret)) {
return std::nullopt;
}
Expand Down

0 comments on commit 72f65ff

Please sign in to comment.