diff --git a/.gitmodules b/.gitmodules index b716f00a..33d33219 100644 --- a/.gitmodules +++ b/.gitmodules @@ -19,6 +19,3 @@ [submodule "src/third-party/googletest"] path = src/third-party/googletest url = https://github.com/google/googletest/ -[submodule "src/hash/third-party/sha-2"] - path = src/hash/sha-2 - url = https://github.com/amosnier/sha-2/ diff --git a/CMakeLists.txt b/CMakeLists.txt index d77d706e..7291efe5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,7 +94,7 @@ endfunction() ## Compiler specific settings set(GLOBAL_COMPILE_OPTIONS) -set(GLOBAL_DEFINITIONS __TGBOT__) +set(GLOBAL_DEFINITIONS) set(GLOBAL_INCLUDE_DIRS ${CMAKE_SOURCE_DIR}/src/include ${CMAKE_SOURCE_DIR}/src/) # Test proper c++20 jthread, stop_token @@ -424,7 +424,6 @@ endif() add_subdirectory(src/api) add_subdirectory(src/database) add_subdirectory(src/random) -add_subdirectory(src/hash) add_subdirectory(src/imagep) add_subdirectory(src/stringres) add_subdirectory(src/logging) diff --git a/TODO b/TODO index caffbf87..26637ed0 100644 --- a/TODO +++ b/TODO @@ -1,4 +1,5 @@ - Handle forked process in forkandrun when the main process exits due to network. - Better installer on Windows. - dependencies resolving in kernelbuilder -- optional kernelbuilder \ No newline at end of file +- optional kernelbuilder +- SocketDataHandler JSON version test \ No newline at end of file diff --git a/src/api/TgBotApiImpl.cpp b/src/api/TgBotApiImpl.cpp index cee788d4..f3bcd7c3 100644 --- a/src/api/TgBotApiImpl.cpp +++ b/src/api/TgBotApiImpl.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include "tgbot/net/CurlHttpClient.h" @@ -266,6 +267,7 @@ void TgBotApiImpl::startPoll() { // Start the long poll loop. while (!SignalHandler::isSignaled()) { longPoll->start(); + std::this_thread::sleep_for(std::chrono::seconds(1)); } } diff --git a/src/database/CMakeLists.txt b/src/database/CMakeLists.txt index e4e6c973..48603439 100644 --- a/src/database/CMakeLists.txt +++ b/src/database/CMakeLists.txt @@ -74,6 +74,7 @@ add_my_executable( DBImpl Socket DBLoading + JsonCpp::JsonCpp RELATION Socket ) ##################################################################### \ No newline at end of file diff --git a/src/database/utils/SendMediaToChat.cc b/src/database/utils/SendMediaToChat.cc index f53eed3d..3fd6d776 100644 --- a/src/database/utils/SendMediaToChat.cc +++ b/src/database/utils/SendMediaToChat.cc @@ -1,15 +1,20 @@ #include #include +#include +#include #include #include +#include +#include #include #include #include -#include #include #include + #include "ConfigManager.hpp" +#include "TgBotSocket_Export.hpp" [[noreturn]] static void usage(const char* argv0, const int exitCode) { std::cerr << "Usage: " << argv0 << " " @@ -17,6 +22,30 @@ exit(exitCode); } +std::optional parseAndCheck( + const void* buf, TgBotSocket::Packet::Header::length_type length, + const std::initializer_list nodes) { + Json::Value root; + Json::Reader reader; + if (!reader.parse(std::string(static_cast(buf), length), + root)) { + LOG(WARNING) << "Failed to parse json: " + << reader.getFormattedErrorMessages(); + return std::nullopt; + } + if (!root.isObject()) { + LOG(WARNING) << "Expected an object in json"; + return std::nullopt; + } + for (const auto& node : nodes) { + if (!root.isMember(node)) { + LOG(WARNING) << fmt::format("Missing node '{}' in json", node); + return std::nullopt; + } + } + return root; +} + int main(int argc, char** argv) { ChatId chatId = 0; TgBotSocket::data::SendFileToChatId data = {}; @@ -55,14 +84,63 @@ int main(int argc, char** argv) { LOG(INFO) << "Found, sending (fileid " << info->mediaId << ") to chat " << chatId; } - copyTo(data.filePath, info->mediaId.c_str()); + copyTo(data.filePath, info->mediaId); data.chat = chatId; data.fileType = TgBotSocket::data::FileType::TYPE_DOCUMENT; - struct TgBotSocket::Packet pkt( - TgBotSocket::Command::CMD_SEND_FILE_TO_CHAT_ID, data); SocketClientWrapper wrapper; - wrapper.connect(TgBotSocket::Context::kTgBotHostPort, TgBotSocket::Context::hostPath()); - wrapper->write(pkt); - backend->unload(); + if (wrapper.connect(TgBotSocket::Context::kTgBotHostPort, + TgBotSocket::Context::hostPath())) { + using namespace TgBotSocket; + DLOG(INFO) << "Connected to server"; + Packet openSession = createPacket(Command::CMD_OPEN_SESSION, nullptr, 0, + PayloadType::Binary, {}); + wrapper->write(openSession); + DLOG(INFO) << "Wrote open session packet"; + auto openSessionAck = + TgBotSocket::readPacket(wrapper.chosen_interface()); + if (!openSessionAck || + openSessionAck->header.cmd != Command::CMD_OPEN_SESSION_ACK) { + LOG(ERROR) << "Failed to open session"; + return EXIT_FAILURE; + } + auto _root = parseAndCheck(openSessionAck->data.get(), + openSessionAck->data.size(), + {"session_token", "expiration_time"}); + if (!_root) { + LOG(ERROR) << "Invalid open session ack json"; + return EXIT_FAILURE; + } + auto root = *_root; + LOG(INFO) << "Opened session. Token: " << root["session_token"] + << " expiration_time: " << root["expiration_time"]; + + std::string session_token_str = root["session_token"].asString(); + Packet::Header::session_token_type session_token{}; + copyTo(session_token, session_token_str); + auto pkt = + createPacket(TgBotSocket::Command::CMD_SEND_FILE_TO_CHAT_ID, &data, + sizeof(data), PayloadType::Binary, session_token); + if (!wrapper->write(pkt)) { + LOG(ERROR) << "Failed to write send file to chat id packet"; + backend->unload(); + return EXIT_FAILURE; + } + auto result = TgBotSocket::readPacket(wrapper.chosen_interface()); + if (!result || result->header.cmd != Command::CMD_GENERIC_ACK) { + LOG(ERROR) << "Failed to send file to chat id"; + backend->unload(); + return EXIT_FAILURE; + } + TgBotSocket::callback::GenericAck genericAck; + result->data.assignTo(genericAck); + if (genericAck.result != TgBotSocket::callback::AckType::SUCCESS) { + LOG(ERROR) << "Failed to send file to chat id: " + << genericAck.error_msg.data(); + backend->unload(); + return EXIT_FAILURE; + } + DLOG(INFO) << "File sent successfully"; + backend->unload(); + } } diff --git a/src/hash/CMakeLists.txt b/src/hash/CMakeLists.txt deleted file mode 100644 index 2a6d081e..00000000 --- a/src/hash/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -add_my_library( - NAME minimalhash - SRCS crc32.cpp sha256.cpp sha-2/sha-256.c - STATIC -) \ No newline at end of file diff --git a/src/hash/crc32.cpp b/src/hash/crc32.cpp deleted file mode 100644 index c3816b5c..00000000 --- a/src/hash/crc32.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include "crc32.hpp" - -#include -#include -#include -#include - -template , bool> = true> -constexpr IntT reverseBits(IntT value) { - IntT result = 0; - for (int i = 0; i < std::numeric_limits::digits; ++i) { - result |= ((value >> i) & 1) - << (std::numeric_limits::digits - 1 - i); - } - return result; -} - -constexpr std::array CRC32::generateCRCTable() { - constexpr uint32_t REVERSED_GENERATOR_POLYNOMIAL = - reverseBits(GENERATOR_POLYNOMIAL); - std::array table = {}; - for (uint32_t i = 0; i < CRC32::TABLE_LENGTH; ++i) { - uint32_t crc = i; - for (int j = 0; j < std::numeric_limits::digits; ++j) { - if ((crc & 1) != 0) { - crc = (crc >> 1) ^ REVERSED_GENERATOR_POLYNOMIAL; - } else { - crc >>= 1; - } - } - table[i] = crc; - } - return table; -} - -uint32_t CRC32::compute(const uint8_t* data, std::size_t length) { - static constexpr std::array crcTable = - generateCRCTable(); - uint32_t crc = ~0U; - while ((length--) != 0) { - crc = (crc >> std::numeric_limits::digits) ^ - crcTable[(crc ^ *data++) & std::numeric_limits::max()]; - } - return ~crc; -} diff --git a/src/hash/crc32.hpp b/src/hash/crc32.hpp deleted file mode 100644 index fa0b00f5..00000000 --- a/src/hash/crc32.hpp +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once - -#include -#include - -class CRC32 { - static constexpr unsigned int GENERATOR_POLYNOMIAL = - 0b00000100110000010001110110110111; - static constexpr int TABLE_LENGTH = 256; - - static constexpr std::array generateCRCTable(); - - public: - using result_type = uint32_t; - - static result_type compute(const uint8_t* data, std::size_t length); -}; diff --git a/src/hash/sha-2 b/src/hash/sha-2 deleted file mode 160000 index 565f6500..00000000 --- a/src/hash/sha-2 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 565f65009bdd98267361b17d50cddd7c9beb3e6c diff --git a/src/hash/sha256.cpp b/src/hash/sha256.cpp deleted file mode 100644 index 558d3153..00000000 --- a/src/hash/sha256.cpp +++ /dev/null @@ -1,10 +0,0 @@ -#include "sha256.hpp" - -SHA256::result_type SHA256::compute(const uint8_t* data, std::size_t length) { - struct Sha_256 sha_256{}; - result_type hash {}; - sha_256_init(&sha_256, hash.data()); - sha_256_write(&sha_256, data, length); - sha_256_close(&sha_256); - return hash; -} \ No newline at end of file diff --git a/src/include/SharedMalloc.hpp b/src/include/SharedMalloc.hpp index 31367f8c..b8749851 100644 --- a/src/include/SharedMalloc.hpp +++ b/src/include/SharedMalloc.hpp @@ -1,15 +1,14 @@ #pragma once -#include +#include #include #include +#include #include #include #include -#include "trivial_helpers/_class_helper_macros.h" - #ifndef __cpp_concepts #define requires(x) #endif @@ -72,6 +71,7 @@ struct SharedMalloc { parent = std::make_shared(size); } } + explicit SharedMalloc(std::nullptr_t) : SharedMalloc() {} SharedMalloc() { parent = std::make_shared(); } template , bool> = true> @@ -79,26 +79,15 @@ struct SharedMalloc { parent = std::make_shared(sizeof(T)); assignFrom(value); } - explicit SharedMalloc(std::nullptr_t /*value*/) { - parent = std::make_shared(); - } + explicit operator bool() const { return parent->size() != 0; } bool operator!=(std::nullptr_t value) { return parent.get() != value; } - - template - explicit operator T() const { - T value; - assignTo(value); - return value; - } [[nodiscard]] size_t size() const noexcept { return parent->size(); } void resize(size_t newSize) const noexcept { parent->realloc(newSize); } private: // A fortify check. inline void validateBoundsForSize(const size_t newSize) const { - DCHECK_LE(newSize, size()) - << ": Operation size exceeds allocated memory size"; if (newSize > size()) { throw std::out_of_range( "Operation size exceeds allocated memory size"); @@ -106,7 +95,6 @@ struct SharedMalloc { } inline void offsetCheck(const size_t offset) const { - DCHECK_LE(offset, size()) << ": Offset exceeds allocated memory bounds"; if (offset > size()) { throw std::out_of_range("Offset exceeds allocated memory bounds"); } @@ -194,8 +182,10 @@ struct SharedMalloc { template requires(!std::is_pointer_v) void assignFrom(const T &ref) { - CHECK_LE(sizeof(T), size()) - << ": *this Must have bigger size than sizeof(T)"; + if (sizeof(T) > size()) { + throw std::out_of_range( + "Operation size exceeds allocated memory size"); + } assignFrom(&ref, sizeof(T)); } diff --git a/src/include/global_handlers/SpamBlock.hpp b/src/include/global_handlers/SpamBlock.hpp index 84b37465..77537250 100644 --- a/src/include/global_handlers/SpamBlock.hpp +++ b/src/include/global_handlers/SpamBlock.hpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include using TgBot::Chat; diff --git a/src/socket/CMakeLists.txt b/src/socket/CMakeLists.txt index f2c9cba7..f7f93f55 100644 --- a/src/socket/CMakeLists.txt +++ b/src/socket/CMakeLists.txt @@ -1,6 +1,9 @@ +find_package(OpenSSL REQUIRED) add_subdirectory(selector) +add_subdirectory(hash) find_package(Boost 1.70 CONFIG COMPONENTS system) + add_my_library( NAME Socket SRCS @@ -12,17 +15,15 @@ add_my_library( bot/FileHelperNew.cpp backends/ClientBackend.cpp TgBotCommandMap.cpp - PUBLIC_INC include interface - LIBS Utils minimalhash Boost::system + PUBLIC_INC interface + LIBS Utils minimalhash Boost::system OpenSSL::Crypto LIBS_WIN32 wsock32 Ws2_32 ) -find_package(ZLIB REQUIRED) - add_my_executable( NAME SocketCli SRCS TgBotSocketClient.cpp - LIBS Socket + LIBS Socket JsonCpp::JsonCpp RELATION Socket ) diff --git a/src/socket/TgBotCommandMap.hpp b/src/socket/TgBotCommandMap.hpp index f2591bd9..e049da74 100644 --- a/src/socket/TgBotCommandMap.hpp +++ b/src/socket/TgBotCommandMap.hpp @@ -5,7 +5,7 @@ #include #include -#include "include/TgBotSocket_Export.hpp" +#include "TgBotSocket_Export.hpp" template <> struct fmt::formatter : formatter { @@ -29,11 +29,14 @@ struct fmt::formatter : formatter { DEFINE_STR(CMD_UPLOAD_FILE); DEFINE_STR(CMD_DOWNLOAD_FILE); DEFINE_STR(CMD_CLIENT_MAX); - DEFINE_STR(CMD_SERVER_INTERNAL_START); + DEFINE_STR(CMD_GET_UPTIME_CALLBACK); DEFINE_STR(CMD_GENERIC_ACK); DEFINE_STR(CMD_UPLOAD_FILE_DRY); DEFINE_STR(CMD_UPLOAD_FILE_DRY_CALLBACK); DEFINE_STR(CMD_DOWNLOAD_FILE_CALLBACK); + DEFINE_STR(CMD_OPEN_SESSION); + DEFINE_STR(CMD_OPEN_SESSION_ACK); + DEFINE_STR(CMD_CLOSE_SESSION); #undef DEFINE_STR default: break; diff --git a/src/socket/TgBotSocketClient.cpp b/src/socket/TgBotSocketClient.cpp index edd41624..c26347f6 100644 --- a/src/socket/TgBotSocketClient.cpp +++ b/src/socket/TgBotSocketClient.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -78,7 +79,7 @@ std::string_view AckTypeToStr(callback::AckType type) { } // namespace -void handle_CommandPacket(SocketClientWrapper wrapper, const Packet& pkt) { +void handleCallback(SocketClientWrapper& connector, const Packet& pkt) { using callback::AckType; using callback::GenericAck; std::string resultText; @@ -106,12 +107,6 @@ void handle_CommandPacket(SocketClientWrapper wrapper, const Packet& pkt) { if (callbackData.result != AckType::SUCCESS) { LOG(ERROR) << "Reason: " << callbackData.error_msg.data(); } else { - wrapper->close(); - if (!wrapper.connect(Context::kTgBotHostPort, - Context::hostPath())) { - LOG(ERROR) << "Failed to recreate client socket"; - return; - } LOG(INFO) << "Recreated client socket"; auto params_in = callbackData.requestdata; SocketFile2DataHelper::DataFromFileParam param; @@ -121,12 +116,13 @@ void handle_CommandPacket(SocketClientWrapper wrapper, const Packet& pkt) { auto newPkt = helper .DataFromFile( - param); + param, pkt.header.session_token); LOG(INFO) << "Sending the actual file content again..."; - wrapper->write(newPkt.value()); - auto it2 = TgBotSocket::readPacket(wrapper.chosen_interface()); + connector->write(newPkt.value()); + auto it2 = + TgBotSocket::readPacket(connector.chosen_interface()); if (it2) { - handle_CommandPacket(std::move(wrapper), it2.value()); + handleCallback(connector, it2.value()); } } break; @@ -148,9 +144,114 @@ void handle_CommandPacket(SocketClientWrapper wrapper, const Packet& pkt) { } } +std::optional parseAndCheck( + const void* buf, TgBotSocket::Packet::Header::length_type length, + const std::initializer_list nodes) { + Json::Value root; + Json::Reader reader; + if (!reader.parse(std::string(static_cast(buf), length), + root)) { + LOG(WARNING) << "Failed to parse json: " + << reader.getFormattedErrorMessages(); + return std::nullopt; + } + if (!root.isObject()) { + LOG(WARNING) << "Expected an object in json"; + return std::nullopt; + } + for (const auto& node : nodes) { + if (!root.isMember(node)) { + LOG(WARNING) << fmt::format("Missing node '{}' in json", node); + return std::nullopt; + } + } + return root; +} + +template +std::optional parseArgs(char** argv) = delete; + +template <> +std::optional parseArgs(char** argv) { + data::WriteMsgToChatId data{}; + if (!try_parse(argv[0], &data.chat)) { + return std::nullopt; + } + copyTo(data.message, argv[1]); + return data; +} + +template <> +std::optional parseArgs(char** argv) { + data::CtrlSpamBlock data; + if (parseOneEnum(&data, data::CtrlSpamBlock::MAX, argv[0], "spamblock")) { + return data; + } + return std::nullopt; +} + +template <> +std::optional parseArgs(char** argv) { + data::ObserveChatId data{}; + if (try_parse(argv[0], &data.chat) && try_parse(argv[1], &data.observe)) { + return data; + } + return std::nullopt; +} + +template <> +std::optional parseArgs(char** argv) { + data::SendFileToChatId data{}; + ChatId id; + data::FileType fileType; + if (try_parse(argv[0], &id) && + parseOneEnum(&fileType, data::FileType::TYPE_MAX, argv[1], "type")) { + data.chat = id; + data.fileType = fileType; + copyTo(data.filePath, argv[2]); + return data; + } + return std::nullopt; +} + +template <> +std::optional parseArgs(char** argv) { + data::ObserveAllChats data{}; + bool observe = false; + if (try_parse(argv[0], &observe)) { + data.observe = observe; + return data; + } + return std::nullopt; +} + +struct None {}; + +template <> +std::optional parseArgs(char** argv) { + return None{}; +} + +template <> +std::optional parseArgs(char** argv) { + SocketFile2DataHelper::DataFromFileParam params; + params.filepath = argv[0]; + params.destfilepath = argv[1]; + params.options.hash_ignore = false; //= true; + params.options.overwrite = true; + return params; +} + +template <> +std::optional parseArgs(char** argv) { + data::DownloadFile data{}; + copyTo(data.filepath, argv[0]); + copyTo(data.destfilename, argv[1]); + return data; +} + int main(int argc, char** argv) { enum Command cmd {}; - std::optional pkt; const char* exe = argv[0]; TgBot_AbslLogInit(); @@ -179,105 +280,137 @@ int main(int argc, char** argv) { usage(exe, false); } - RealFS realfs; - SocketFile2DataHelper helper(&realfs); + SocketClientWrapper backend; + if (backend.connect(Context::kTgBotHostPort, Context::hostPath())) { + DLOG(INFO) << "Connected to server"; + Packet openSession = createPacket(Command::CMD_OPEN_SESSION, nullptr, 0, + PayloadType::Binary, {}); + backend->write(openSession); + DLOG(INFO) << "Wrote open session packet"; + auto openSessionAck = + TgBotSocket::readPacket(backend.chosen_interface()); + if (!openSessionAck || + openSessionAck->header.cmd != Command::CMD_OPEN_SESSION_ACK) { + LOG(ERROR) << "Failed to open session"; + return EXIT_FAILURE; + } + auto _root = parseAndCheck(openSessionAck->data.get(), + openSessionAck->data.size(), + {"session_token", "expiration_time"}); + if (!_root) { + LOG(ERROR) << "Invalid open session ack json"; + return EXIT_FAILURE; + } + auto root = *_root; + LOG(INFO) << "Opened session. Token: " << root["session_token"] + << " expiration_time: " << root["expiration_time"]; + + std::string session_token_str = root["session_token"].asString(); + Packet::Header::session_token_type session_token{}; + copyTo(session_token, session_token_str); - switch (cmd) { - case Command::CMD_WRITE_MSG_TO_CHAT_ID: { - data::WriteMsgToChatId data{}; - ChatId id; - if (!try_parse(argv[0], &id)) { + std::optional pkt; + switch (cmd) { + case Command::CMD_WRITE_MSG_TO_CHAT_ID: { + auto args = parseArgs(argv); + if (!args) { + usage(exe, false); + return EXIT_FAILURE; + } + pkt = createPacket(cmd, &args.value(), sizeof(*args), + PayloadType::Binary, session_token); break; } - data.chat = id; - copyTo(data.message, argv[1]); - pkt = Packet(cmd, data); - break; - } - case Command::CMD_CTRL_SPAMBLOCK: { - data::CtrlSpamBlock data; - if (parseOneEnum(&data, data::CtrlSpamBlock::MAX, argv[0], - "spamblock")) { - pkt = Packet(cmd, data); + case Command::CMD_CTRL_SPAMBLOCK: { + auto args = parseArgs(argv); + if (!args) { + usage(exe, false); + return EXIT_FAILURE; + } + pkt = createPacket(cmd, &args.value(), sizeof(*args), + PayloadType::Binary, session_token); + break; } - break; - } - case Command::CMD_OBSERVE_CHAT_ID: { - data::ObserveChatId data{}; - bool observe; - ChatId id; - if (try_parse(argv[0], &id) && try_parse(argv[1], &observe)) { - data.chat = id; - data.observe = observe; - pkt = Packet(cmd, data); + case Command::CMD_OBSERVE_CHAT_ID: { + auto args = parseArgs(argv); + if (!args) { + usage(exe, false); + return EXIT_FAILURE; + } + pkt = createPacket(cmd, &args.value(), sizeof(*args), + PayloadType::Binary, session_token); + break; } - break; - } - case Command::CMD_SEND_FILE_TO_CHAT_ID: { - data::SendFileToChatId data{}; - ChatId id; - data::FileType fileType; - if (try_parse(argv[0], &id) && - parseOneEnum(&fileType, data::FileType::TYPE_MAX, argv[1], - "type")) { - data.chat = id; - data.fileType = fileType; - copyTo(data.filePath, argv[2]); - pkt = Packet(cmd, data); + case Command::CMD_SEND_FILE_TO_CHAT_ID: { + auto args = parseArgs(argv); + if (!args) { + usage(exe, false); + return EXIT_FAILURE; + } + pkt = createPacket(cmd, &args.value(), sizeof(*args), + PayloadType::Binary, session_token); + } break; + case Command::CMD_OBSERVE_ALL_CHATS: { + auto args = parseArgs(argv); + if (!args) { + usage(exe, false); + return EXIT_FAILURE; + } + pkt = createPacket(cmd, &args.value(), sizeof(*args), + PayloadType::Binary, session_token); + } break; + case Command::CMD_GET_UPTIME: { + auto args = parseArgs(argv); + if (!args) { + usage(exe, false); + return EXIT_FAILURE; + } + pkt = createPacket(cmd, &args.value(), sizeof(*args), + PayloadType::Binary, session_token); + break; + } + case Command::CMD_UPLOAD_FILE: { + RealFS realfs; + SocketFile2DataHelper helper(&realfs); + auto args = + parseArgs(argv); + if (!args) { + usage(exe, false); + return EXIT_FAILURE; + } + pkt = helper.DataFromFile< + SocketFile2DataHelper::Pass::UPLOAD_FILE_DRY>( + args.value(), session_token); + break; } - } break; - case Command::CMD_OBSERVE_ALL_CHATS: { - data::ObserveAllChats data{}; - bool observe = false; - if (try_parse(argv[0], &observe)) { - data.observe = observe; - pkt = Packet(cmd, data); + case Command::CMD_DOWNLOAD_FILE: { + auto args = parseArgs(argv); + if (!args) { + usage(exe, false); + return EXIT_FAILURE; + } + pkt = createPacket(cmd, &args.value(), sizeof(*args), + PayloadType::Binary, session_token); + break; } - } break; - case Command::CMD_GET_UPTIME: { - // Data is unused in this case - pkt = Packet(cmd, 1); - break; - } - case Command::CMD_UPLOAD_FILE: { - SocketFile2DataHelper::DataFromFileParam params; - params.filepath = argv[0]; - params.destfilepath = argv[1]; - params.options.hash_ignore = false; //= true; - params.options.overwrite = true; - pkt = - helper - .DataFromFile( - params); - break; - } - case Command::CMD_DOWNLOAD_FILE: { - data::DownloadFile data{}; - copyTo(data.filepath, argv[0]); - copyTo(data.destfilename, argv[1]); - pkt = Packet(cmd, data); - break; - } - default: - LOG(FATAL) << fmt::format("Unhandled command: {}", cmd); - }; - - if (!pkt) { - LOG(ERROR) << fmt::format("Failed parsing arguments for {}", cmd); - return EXIT_FAILURE; - } else { - pkt->header.data_checksum = pkt->crc32_function(pkt->data); - } + default: + LOG(FATAL) << fmt::format("Unhandled command: {}", cmd); + }; - SocketClientWrapper backend; - if (backend.connect(Context::kTgBotHostPort, Context::hostPath())) { backend->write(*pkt); LOG(INFO) << "Sent the command: Waiting for callback..."; auto it = TgBotSocket::readPacket(backend.chosen_interface()); if (it) { - handle_CommandPacket(std::move(backend), it.value()); + handleCallback(backend, it.value()); + } + auto closePacket = + createPacket(TgBotSocket::Command::CMD_CLOSE_SESSION, nullptr, 0, + PayloadType::Binary, session_token); + if (!backend->write(closePacket)) { + LOG(ERROR) << "Failed to close session"; + return EXIT_FAILURE; } } - return static_cast(!pkt.has_value()); + return EXIT_SUCCESS; } diff --git a/src/socket/include/TgBotSocket_Export.hpp b/src/socket/TgBotSocket_Export.hpp similarity index 76% rename from src/socket/include/TgBotSocket_Export.hpp rename to src/socket/TgBotSocket_Export.hpp index ce3f7d24..002729f3 100644 --- a/src/socket/include/TgBotSocket_Export.hpp +++ b/src/socket/TgBotSocket_Export.hpp @@ -1,22 +1,17 @@ #pragma once // A header export for the TgBot's socket connection +#include + +#include #include #include #include #include #include -#ifdef __TGBOT__ -#include - -#include -#else -#include "../../include/SharedMalloc.hpp" -#include "../../include/Types.h" -#endif -#include "../../hash/crc32.hpp" -#include "../../hash/sha256.hpp" +#include "hash/hmac.hpp" +#include "hash/sha256.hpp" template inline bool arraycmp(const std::array& lhs, @@ -28,8 +23,8 @@ inline bool arraycmp(const std::array& lhs, } template -inline void copyTo(std::array& arr_in, const char* buf) { - strncpy(arr_in.data(), buf, size - 1); +inline void copyTo(std::array& arr_in, const std::string_view buf) { + strncpy(arr_in.data(), buf.data(), std::min(size - 1, buf.size())); arr_in[size - 1] = '\0'; } @@ -63,6 +58,9 @@ enum class Command : std::int32_t { CMD_UPLOAD_FILE_DRY, CMD_UPLOAD_FILE_DRY_CALLBACK, CMD_DOWNLOAD_FILE_CALLBACK, + CMD_OPEN_SESSION, + CMD_OPEN_SESSION_ACK, + CMD_CLOSE_SESSION, CMD_MAX, }; @@ -86,7 +84,6 @@ struct alignas(ALIGNMENT) Packet { * Header contains the magic value, command, and the size of the data */ struct alignas(ALIGNMENT) Header { - using length_type = uint64_t; constexpr static int64_t MAGIC_VALUE_BASE = 0xDEADFACE; // Version number, to be increased on breaking changes // 1: Initial version @@ -101,55 +98,37 @@ struct alignas(ALIGNMENT) Packet { // 9: Remove padding objects // 10: Alignments fixing for Python compliance, add INVALID_CMD at 0 // 11: Remove CMD_DELETE_CONTROLLER_BY_ID, add payload type to header - constexpr static int DATA_VERSION = 11; + // 12: Use OpenSSL's HMAC and AES-GCM algorithm encryption with nounces. + // Also use CMD_OPEN_SESSION CMD_CLOSE_SESSION for session based + // encryption + constexpr static int DATA_VERSION = 12; constexpr static int64_t MAGIC_VALUE = MAGIC_VALUE_BASE + DATA_VERSION; + using length_type = uint32_t; + using nounce_type = uint64_t; + using hmac_type = HMAC::result_type; + + // Using AES-GCM + constexpr static int IV_LENGTH = 12; + constexpr static int TAG_LENGTH = 16; + constexpr static int SESSION_TOKEN_LENGTH = 32; + + using session_token_type = std::array; + using init_vector_type = std::array; + int64_t magic = MAGIC_VALUE; ///< Magic value to verify the packet Command cmd{}; ///< Command to be executed - ///< Type of payload in the packet - PayloadType data_type{}; - ///< Size of the data in the packet - length_type data_size{}; - ///< Checksum of the packet data - CRC32::result_type data_checksum{}; + PayloadType data_type{}; ///< Type of payload in the packet + length_type data_size{}; ///< Size of the data in the packet + session_token_type session_token{}; ///< Session token + nounce_type nonce{}; ///< Nonce (Epoch timestamp is used) + hmac_type hmac{}; ///< HMAC data + init_vector_type init_vector{}; ///< Initialization vector for AES-GCM }; static_assert(offsetof(Header, magic) == 0); Header header{}; SharedMalloc data; - - explicit Packet(Header::length_type length) : data(length) { - header.magic = Header::MAGIC_VALUE; - header.data_size = length; - } - - // Constructor that takes malloc - template - explicit Packet(Command cmd, T data) : Packet(cmd, &data, sizeof(T)) { - static_assert(!std::is_pointer_v, - "This constructor should not be used with a pointer"); - } - - // Constructor that takes pointer, uses malloc but with size - template - explicit Packet(Command cmd, T in_data, Header::length_type size) - : data(size) { - static_assert(std::is_pointer_v, - "This constructor should not be used with non pointer"); - header.cmd = cmd; - header.magic = Header::MAGIC_VALUE; - data.assignFrom(in_data, header.data_size = size); - header.data_checksum = crc32_function(data); - } - - static CRC32::result_type crc32_function(const uint8_t* data, - const size_t data_size) { - return CRC32::compute(data, data_size); - } - - static CRC32::result_type crc32_function(const SharedMalloc& data) { - return crc32_function(static_cast(data.get()), data.size()); - } }; using PathStringArray = std::array; @@ -199,7 +178,7 @@ struct alignas(ALIGNMENT) ObserveAllChats { // true/false - Start/Stop observing }; -struct alignas(ALIGNMENT) UploadFileDry { +struct alignas(ALIGNMENT) UploadFileMeta { PathStringArray destfilepath{}; // Destination file name PathStringArray srcfilepath{}; // Source file name (This is not used on the // remote, used if dry=true) @@ -224,21 +203,24 @@ struct alignas(ALIGNMENT) UploadFileDry { } } options; - bool operator==(const UploadFileDry& other) const { + bool operator==(const UploadFileMeta& other) const { return arraycmp(destfilepath, other.destfilepath) && arraycmp(srcfilepath, other.srcfilepath) && arraycmp(sha256_hash, sha256_hash) && options == other.options; } }; -struct alignas(ALIGNMENT) UploadFile : public UploadFileDry { - using Options = UploadFileDry::Options; +struct alignas(ALIGNMENT) UploadFile : public UploadFileMeta { + using Options = UploadFileMeta::Options; alignas(ALIGNMENT) uint8_t buf[]; // Buffer }; -struct alignas(ALIGNMENT) DownloadFile { - PathStringArray filepath{}; // Path to file (in remote) - PathStringArray destfilename{}; // Destination file name +struct alignas(ALIGNMENT) DownloadFileMeta { + PathStringArray filepath{}; // Path to file (in remote) + PathStringArray destfilename{}; // Destination file name +}; + +struct alignas(ALIGNMENT) DownloadFile : DownloadFileMeta { alignas(ALIGNMENT) uint8_t buf[]; // Buffer }; } // namespace data @@ -279,13 +261,14 @@ struct alignas(ALIGNMENT) GenericAck { * { * "result": True|False * "__comment__": "Below two fields are optional" - * "error_type": "TGAPI_EXCEPTION"|"INVALID_ARGUMENT"|"COMMAND_IGNORED"|"RUNTIME_ERROR"|"CLIENT_ERROR", + * "error_type": + * "TGAPI_EXCEPTION"|"INVALID_ARGUMENT"|"COMMAND_IGNORED"|"RUNTIME_ERROR"|"CLIENT_ERROR", * "error_msg": "Error message" * } */ - + struct alignas(ALIGNMENT) UploadFileDryCallback : public GenericAck { - data::UploadFileDry requestdata; + data::UploadFileMeta requestdata; }; } // namespace callback @@ -304,7 +287,7 @@ ASSERT_SIZE(SendFileToChatId, 272); ASSERT_SIZE(ObserveAllChats, 8); ASSERT_SIZE(UploadFile, 552); ASSERT_SIZE(DownloadFile, 512); -ASSERT_SIZE(Packet::Header, 32); +ASSERT_SIZE(Packet::Header, 144); } // namespace TgBotSocket::data namespace TgBotSocket::callback { diff --git a/src/socket/bot/FileHelperNew.cpp b/src/socket/bot/FileHelperNew.cpp index a97496d2..2b3d329a 100644 --- a/src/socket/bot/FileHelperNew.cpp +++ b/src/socket/bot/FileHelperNew.cpp @@ -1,14 +1,11 @@ #include "FileHelperNew.hpp" #include -#include -#ifdef __TGBOT__ #include -#else -#include "../../include/StructF.hpp" -#endif +#include +#include "PacketParser.hpp" #include "TgBotSocket_Export.hpp" bool RealFS::writeFile(const std::filesystem::path& filename, @@ -60,18 +57,19 @@ bool RealFS::exists(const std::filesystem::path& path) { } void RealFS::SHA256(const SharedMalloc& memory, HashContainer& data) { - data.m_data = SHA256::compute(static_cast(memory.get()), memory.size()); + data.m_data = SHA256::compute(static_cast(memory.get()), + memory.size()); } using TgBotSocket::data::DownloadFile; using TgBotSocket::data::UploadFile; -using TgBotSocket::data::UploadFileDry; +using TgBotSocket::data::UploadFileMeta; bool SocketFile2DataHelper::DataToFile_UPLOAD_FILE_DRY( const void* ptr, TgBotSocket::Packet::Header::length_type len) { - const auto* data = static_cast(ptr); - if (len != sizeof(UploadFileDry)) { - LOG(ERROR) << "Invalid UploadFileDry packet size"; + const auto* data = static_cast(ptr); + if (len != sizeof(UploadFileMeta)) { + LOG(ERROR) << "Invalid UploadFileMeta packet size"; return false; } const char* filename = data->destfilepath.data(); @@ -133,7 +131,8 @@ bool SocketFile2DataHelper::DataToFile_DOWNLOAD_FILE( std::optional SocketFile2DataHelper::DataFromFile_UPLOAD_FILE( - const DataFromFileParam& params) { + const DataFromFileParam& params, + const TgBotSocket::Packet::Header::session_token_type& session_token) { const auto _result = vfs->readFile(params.filepath); HashContainer hash{}; @@ -148,7 +147,7 @@ SocketFile2DataHelper::DataFromFile_UPLOAD_FILE( // The front bytes of the buffer is UploadFile, hence cast it auto* uploadFile = static_cast(resultPointer.get()); // Copy destination file name info to the buffer - copyTo(uploadFile->destfilepath, params.destfilepath.string().c_str()); + copyTo(uploadFile->destfilepath, params.destfilepath.string()); // Copy source file data to the buffer memcpy(&uploadFile->buf[0], result.get(), result.size()); // Calculate SHA256 hash @@ -160,14 +159,16 @@ SocketFile2DataHelper::DataFromFile_UPLOAD_FILE( // Set dry run to false uploadFile->options.dry_run = false; - return TgBotSocket::Packet{TgBotSocket::Command::CMD_UPLOAD_FILE, - resultPointer.get(), - result.size() + sizeof(UploadFile)}; + return TgBotSocket::createPacket( + TgBotSocket::Command::CMD_UPLOAD_FILE, resultPointer.get(), + result.size() + sizeof(UploadFile), TgBotSocket::PayloadType::Binary, + session_token); } std::optional SocketFile2DataHelper::DataFromFile_UPLOAD_FILE_DRY( - const DataFromFileParam& params) { + const DataFromFileParam& params, + const TgBotSocket::Packet::Header::session_token_type& session_token) { const auto _result = vfs->readFile(params.filepath); HashContainer hash{}; @@ -177,13 +178,13 @@ SocketFile2DataHelper::DataFromFile_UPLOAD_FILE_DRY( } const auto& result = _result.value(); // Create result packet buffer - auto resultPointer = SharedMalloc(sizeof(UploadFileDry)); + auto resultPointer = SharedMalloc(sizeof(UploadFileMeta)); // The front bytes of the buffer is UploadFile, hence cast it - auto* uploadFile = static_cast(resultPointer.get()); + auto* uploadFile = static_cast(resultPointer.get()); // Copy destination file name info to the buffer - copyTo(uploadFile->destfilepath, params.destfilepath.string().c_str()); + copyTo(uploadFile->destfilepath, params.destfilepath.string()); // Copy source file name to the buffer - copyTo(uploadFile->srcfilepath, params.filepath.string().c_str()); + copyTo(uploadFile->srcfilepath, params.filepath.string()); // Calculate SHA256 hash vfs->SHA256(result, hash); @@ -194,13 +195,16 @@ SocketFile2DataHelper::DataFromFile_UPLOAD_FILE_DRY( // Set dry run to true uploadFile->options.dry_run = true; - return TgBotSocket::Packet{TgBotSocket::Command::CMD_UPLOAD_FILE_DRY, - resultPointer.get(), sizeof(UploadFileDry)}; + return TgBotSocket::createPacket( + TgBotSocket::Command::CMD_UPLOAD_FILE_DRY, resultPointer.get(), + sizeof(UploadFileMeta), TgBotSocket::PayloadType::Binary, + session_token); } std::optional SocketFile2DataHelper::DataFromFile_DOWNLOAD_FILE( - const DataFromFileParam& params) { + const DataFromFileParam& params, + const TgBotSocket::Packet::Header::session_token_type& session_token) { const auto _result = vfs->readFile(params.filepath); HashContainer hash{}; @@ -214,13 +218,14 @@ SocketFile2DataHelper::DataFromFile_DOWNLOAD_FILE( // The front bytes of the buffer is DownloadFile, hence cast it auto* downloadFile = static_cast(resultPointer.get()); // Copy destination file name info to the buffer - copyTo(downloadFile->destfilename, params.destfilepath.string().c_str()); + copyTo(downloadFile->destfilename, params.destfilepath.string()); // Copy source file data to the buffer memcpy(&downloadFile->buf[0], result.get(), result.size()); // Calculate SHA256 hash vfs->SHA256(result, hash); - return TgBotSocket::Packet{TgBotSocket::Command::CMD_DOWNLOAD_FILE_CALLBACK, - resultPointer.get(), - result.size() + sizeof(DownloadFile)}; + return TgBotSocket::createPacket( + TgBotSocket::Command::CMD_DOWNLOAD_FILE_CALLBACK, resultPointer.get(), + result.size() + sizeof(TgBotSocket::data::DownloadFileMeta), + TgBotSocket::PayloadType::Binary, session_token); } \ No newline at end of file diff --git a/src/socket/bot/FileHelperNew.hpp b/src/socket/bot/FileHelperNew.hpp index ff6d2a84..b129639e 100644 --- a/src/socket/bot/FileHelperNew.hpp +++ b/src/socket/bot/FileHelperNew.hpp @@ -1,22 +1,15 @@ #pragma once +#include + +#include #include #include #include #include #include #include - -#ifdef __TGBOT__ -#include #include -#include -#else -#define Socket_API -#define APPLE_INJECT(x) x -#define APPLE_EXPLICIT_INJECT(x) explicit x -#include "../../include/SharedMalloc.hpp" -#endif // Represents a SHA-256 hash struct Socket_API HashContainer { @@ -120,7 +113,8 @@ class Socket_API SocketFile2DataHelper { VFSOperations* vfs; public: - APPLE_EXPLICIT_INJECT(SocketFile2DataHelper(VFSOperations* vfs)) : vfs(vfs) {} + APPLE_EXPLICIT_INJECT(SocketFile2DataHelper(VFSOperations* vfs)) + : vfs(vfs) {} enum class Pass { UPLOAD_FILE_DRY, @@ -149,29 +143,33 @@ class Socket_API SocketFile2DataHelper { template std::optional DataFromFile( - const DataFromFileParam& params) { + const DataFromFileParam& params, + const TgBotSocket::Packet::Header::session_token_type& session_token) { if (P == Pass::UPLOAD_FILE_DRY) { - return DataFromFile_UPLOAD_FILE_DRY(params); + return DataFromFile_UPLOAD_FILE_DRY(params, session_token); } else if (P == Pass::UPLOAD_FILE) { - return DataFromFile_UPLOAD_FILE(params); + return DataFromFile_UPLOAD_FILE(params, session_token); } else if (P == Pass::DOWNLOAD_FILE) { - return DataFromFile_DOWNLOAD_FILE(params); + return DataFromFile_DOWNLOAD_FILE(params, session_token); } return std::nullopt; } private: - bool DataToFile_UPLOAD_FILE_DRY(const void* ptr, - TgBotSocket::Packet::Header::length_type len); + bool DataToFile_UPLOAD_FILE_DRY( + const void* ptr, TgBotSocket::Packet::Header::length_type len); bool DataToFile_UPLOAD_FILE(const void* ptr, TgBotSocket::Packet::Header::length_type len); bool DataToFile_DOWNLOAD_FILE(const void* ptr, TgBotSocket::Packet::Header::length_type len); std::optional DataFromFile_UPLOAD_FILE( - const DataFromFileParam& params); + const DataFromFileParam& params, + const TgBotSocket::Packet::Header::session_token_type& session_token); std::optional DataFromFile_UPLOAD_FILE_DRY( - const DataFromFileParam& params); + const DataFromFileParam& params, + const TgBotSocket::Packet::Header::session_token_type& session_token); std::optional DataFromFile_DOWNLOAD_FILE( - const DataFromFileParam& params); + const DataFromFileParam& params, + const TgBotSocket::Packet::Header::session_token_type& session_token); }; diff --git a/src/socket/bot/PacketParser.cpp b/src/socket/bot/PacketParser.cpp index 9481b6a4..03fce66c 100644 --- a/src/socket/bot/PacketParser.cpp +++ b/src/socket/bot/PacketParser.cpp @@ -1,11 +1,14 @@ #include "PacketParser.hpp" #include +#include +#include #include +#include #include #include -#include +#include template <> struct fmt::formatter : formatter { @@ -27,6 +30,148 @@ struct fmt::formatter : formatter { } }; +namespace { + +SharedMalloc encrypt_payload( + const TgBotSocket::Packet::Header::session_token_type key, + const SharedMalloc& payload, + TgBotSocket::Packet::Header::init_vector_type& iv) { + using Header = TgBotSocket::Packet::Header; + // Generate random IV + RAND_bytes(iv.data(), Header::IV_LENGTH); + + SharedMalloc encrypted(payload.size() + Header::TAG_LENGTH); + int len = 0; + int encrypted_payload_len = 0; + + // Initialize encryption + auto ctx = RAII::create(EVP_CIPHER_CTX_new(), + &EVP_CIPHER_CTX_free); + if (ctx == nullptr) { + LOG(ERROR) << "Error initializing encryption context"; + return {}; + } + + if (EVP_EncryptInit_ex(ctx.get(), EVP_aes_256_gcm(), nullptr, nullptr, + nullptr) == 0) { + LOG(ERROR) << "Error initializing encryption"; + return {}; + } + + if (EVP_EncryptInit_ex(ctx.get(), nullptr, nullptr, + reinterpret_cast(key.data()), + iv.data()) == 0) { + LOG(ERROR) << "Error initializing encryption with key"; + } + + // Encrypt the plaintext + auto* loc = static_cast(encrypted.get()); + if (EVP_EncryptUpdate(ctx.get(), loc, &len, + static_cast(payload.get()), + payload.size()) == 0) { + LOG(ERROR) << "Error encrypting payload"; + return {}; + } + encrypted_payload_len += len; + + // Finalize encryption + if (EVP_EncryptFinal_ex(ctx.get(), loc + encrypted_payload_len, &len) == + 0) { + LOG(ERROR) << "Error finalizing encryption"; + return {}; + } + encrypted_payload_len += len; + + // Get the authentication tag and append it + if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, Header::TAG_LENGTH, + loc + encrypted_payload_len) == 0) { + LOG(ERROR) << "Error getting authentication tag"; + return {}; + } + + encrypted_payload_len += Header::TAG_LENGTH; + encrypted.resize(encrypted_payload_len); + + ctx.reset(); + DLOG(INFO) << "Encrypted payload of size " << encrypted_payload_len + << " bytes using EVP AES_256"; + return encrypted; +} + +SharedMalloc decrypt_payload( + const TgBotSocket::Packet::Header::session_token_type& key, + const SharedMalloc& encrypted, + const TgBotSocket::Packet::Header::init_vector_type& iv) { + constexpr int tag_size = TgBotSocket::Packet::Header::TAG_LENGTH; + + // Ensure the encrypted size is valid + if (encrypted.size() < tag_size) { + LOG(ERROR) << "Encrypted payload size too small to contain tag"; + return {}; + } + + const size_t decrypted_size = encrypted.size() - tag_size; + SharedMalloc decrypted(decrypted_size); + int len = 0; + int decryped_len = 0; + + // Initialize decryption + auto ctx = RAII::create(EVP_CIPHER_CTX_new(), + &EVP_CIPHER_CTX_free); + if (ctx == nullptr) { + LOG(ERROR) << "Failed to create EVP_CIPHER_CTX"; + return {}; + } + + if (EVP_DecryptInit_ex(ctx.get(), EVP_aes_256_gcm(), nullptr, nullptr, + nullptr) == 0) { + LOG(ERROR) << "Failed to initialize decryption cipher"; + return {}; + } + + if (EVP_DecryptInit_ex(ctx.get(), nullptr, nullptr, + reinterpret_cast(key.data()), + iv.data()) == 0) { + LOG(ERROR) << "Failed to set key and IV for decryption"; + return {}; + } + + // Decrypt the ciphertext + if (EVP_DecryptUpdate(ctx.get(), + static_cast(decrypted.get()), &len, + static_cast(encrypted.get()), + decrypted_size) == 0) { + LOG(ERROR) << "Failed during EVP_DecryptUpdate"; + return {}; + } + decryped_len += len; + + // Set the authentication tag + auto* tag_loc = + static_cast(encrypted.get()) + decrypted_size; + if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_TAG, tag_size, + tag_loc) == 0) { + LOG(ERROR) << "Failed to set authentication tag"; + return {}; + } + + // Finalize decryption + auto* out_loc = static_cast(decrypted.get()) + decryped_len; + if (EVP_DecryptFinal_ex(ctx.get(), out_loc, &len) <= 0) { + LOG(ERROR) << "Authentication tag mismatch"; + return {}; + } + + decryped_len += len; + decrypted.resize(decryped_len); + + ctx.reset(); + LOG(INFO) << "Decrypted payload successfully, size: " << decryped_len; + return decrypted; +} + +} // namespace + namespace TgBotSocket { std::optional readPacket(const TgBotSocket::Context& context) { @@ -73,21 +218,61 @@ std::optional readPacket(const TgBotSocket::Context& context) { LOG(INFO) << fmt::format("Received Packet{{cmd={}, data_type={}}}", header.cmd, header.data_type); - const size_t newLength = - sizeof(TgBotSocket::Packet::Header) + header.data_size; - TgBotSocket::Packet packet(newLength); + TgBotSocket::Packet packet{ + .header = header, .data = {} // Will be filled in the next step. + }; + packet.header = header; + if (packet.header.data_size == 0) { + // In this case, no need to fetch or verify data + return packet; + } auto data = context.read(header.data_size); if (!data) { LOG(ERROR) << "While reading data, failed"; return std::nullopt; } - if (header.data_checksum != Packet::crc32_function(data.value())) { - LOG(WARNING) << "Checksum mismatch, dropping buffer"; + std::string_view session_token(header.session_token.data(), + header.session_token.size()); + if (session_token.empty()) { + LOG(WARNING) << "No session token provided"; return std::nullopt; } - packet.data = std::move(data.value()); - packet.header = header; + if (packet.header.hmac != + HMAC::compute(static_cast(data->get()), data->size(), + session_token)) { + LOG(ERROR) << "HMAC mismatch"; + return std::nullopt; + } + packet.data = decrypt_payload(header.session_token, data.value(), + packet.header.init_vector); + if (!static_cast(packet.data)) { + LOG(ERROR) << "Decryption failed"; + return std::nullopt; + } + return packet; +} + +Packet Socket_API +createPacket(const Command command, const void* data, + Packet::Header::length_type length, const PayloadType payloadType, + const Packet::Header::session_token_type& sessionToken) { + Packet packet{.header = {}, .data = {}}; + packet.header.cmd = command; + packet.header.magic = Packet::Header::MAGIC_VALUE; + packet.header.data_type = payloadType; + packet.header.session_token = sessionToken; + packet.header.nonce = std::time(nullptr); + if (data != nullptr && length > 0) { + packet.data.resize(length); + packet.data.assignFrom(data, length); + packet.data = encrypt_payload(sessionToken, packet.data, + packet.header.init_vector); + packet.header.data_size = packet.data.size(); + packet.header.hmac = + HMAC::compute(static_cast(packet.data.get()), + packet.header.data_size, sessionToken.data()); + } return packet; } } // namespace TgBotSocket \ No newline at end of file diff --git a/src/socket/bot/PacketParser.hpp b/src/socket/bot/PacketParser.hpp index 98408e92..1949f827 100644 --- a/src/socket/bot/PacketParser.hpp +++ b/src/socket/bot/PacketParser.hpp @@ -1,12 +1,48 @@ #pragma once +#ifdef TgBotSocketParse_JNI_EXPORTS +#define Socket_API +#else #include +#endif #include #include +#include "TgBotSocket_Export.hpp" + namespace TgBotSocket { -std::optional Socket_API readPacket(const TgBotSocket::Context& context); +/** + * @brief Reads a packet from the socket using the provided context. + * + * This function attempts to read a packet from the socket associated with the + * given context. If successful, it returns an optional containing the read + * packet. If no packet is available, it returns an empty optional. + * + * @param context The context to use for reading the packet. + * @return An optional containing the read packet, or an empty optional if no + * packet is available. + */ +std::optional Socket_API +readPacket(const TgBotSocket::Context& context); + +/** + * @brief Creates a packet with the given command and data. + * + * This function creates a packet with the specified command and data. The + * length of the data is determined by the provided length parameter. + * + * @param command The command to set in the packet header. + * @param data A pointer to the data to be included in the packet. + * @param length The length of the data in bytes. + * @param payloadType The type of payload used in the packet. + * @param sessionToken The session token to set in the packet header. + * @return The created packet. + */ +Packet Socket_API +createPacket(const Command command, const void* data, + Packet::Header::length_type length, const PayloadType payloadType, + const Packet::Header::session_token_type& sessionToken); -} \ No newline at end of file +} // namespace TgBotSocket \ No newline at end of file diff --git a/src/socket/bot/SocketDataHandler.cpp b/src/socket/bot/SocketDataHandler.cpp index d47e937b..4da8e5d8 100644 --- a/src/socket/bot/SocketDataHandler.cpp +++ b/src/socket/bot/SocketDataHandler.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -19,8 +20,11 @@ #include #include "FileHelperNew.hpp" +#include "SharedMalloc.hpp" #include "SocketContext.hpp" #include "SocketInterface.hpp" +#include "bot/PacketParser.hpp" +#include "tgbot/tools/StringTools.h" using TgBot::InputFile; namespace fs = std::filesystem; @@ -90,16 +94,18 @@ std::optional parseAndCheck( return root; } -Packet nodeToPacket(const Json::Value& json) { +Packet nodeToPacket(const Command& command, const Json::Value& json, + const Packet::Header::session_token_type& session_token) { std::string result; Json::FastWriter writer; result = writer.write(json); - Packet packet(Command::CMD_GENERIC_ACK, result.c_str(), result.size()); - packet.header.data_type = PayloadType::Json; + auto packet = createPacket(command, result.c_str(), result.size(), + PayloadType::Json, session_token); return packet; } -Packet toJSONPacket(const GenericAck& ack) { +Packet toJSONPacket(const GenericAck& ack, + const Packet::Header::session_token_type& session_token) { Json::Value root; root["result"] = ack.result == AckType::SUCCESS; if (ack.result != AckType::SUCCESS) { @@ -126,7 +132,7 @@ Packet toJSONPacket(const GenericAck& ack) { break; } } - return nodeToPacket(root); + return nodeToPacket(Command::CMD_GENERIC_ACK, root, session_token); } template @@ -450,7 +456,7 @@ GenericAck SocketInterfaceTgBot::handle_UploadFile( UploadFileDryCallback SocketInterfaceTgBot::handle_UploadFileDry( const void* ptr, TgBotSocket::Packet::Header::length_type len) { bool ret = false; - const auto* f = static_cast(ptr); + const auto* f = static_cast(ptr); UploadFileDryCallback callback; callback.requestdata = *f; @@ -466,14 +472,15 @@ UploadFileDryCallback SocketInterfaceTgBot::handle_UploadFileDry( return callback; } -bool SocketInterfaceTgBot::handle_DownloadFile(const TgBotSocket::Context& ctx, - const void* ptr) { +bool SocketInterfaceTgBot::handle_DownloadFile( + const TgBotSocket::Context& ctx, const void* ptr, + const TgBotSocket::Packet::Header::session_token_type& token) { const auto* data = static_cast(ptr); SocketFile2DataHelper::DataFromFileParam params; params.filepath = data->filepath.data(); params.destfilepath = data->destfilename.data(); auto pkt = helper->DataFromFile( - params); + params, token); if (!pkt) { LOG(ERROR) << "Failed to prepare download file packet"; return false; @@ -482,16 +489,17 @@ bool SocketInterfaceTgBot::handle_DownloadFile(const TgBotSocket::Context& ctx, return true; } -bool SocketInterfaceTgBot::handle_GetUptime(const TgBotSocket::Context& ctx, - const void* /*ptr*/) { +bool SocketInterfaceTgBot::handle_GetUptime( + const TgBotSocket::Context& ctx, + const TgBotSocket::Packet::Header::session_token_type& token) { auto now = std::chrono::system_clock::now(); const auto diff = to_secs(now - startTp); GetUptimeCallback callback{}; copyTo(callback.uptime, fmt::format("Uptime: {:%H:%M:%S}", diff).c_str()); LOG(INFO) << "Sending text back: " << std::quoted(callback.uptime.data()); - Packet pkt(Command::CMD_GET_UPTIME_CALLBACK, callback); - ctx.write(pkt); + ctx.write(createPacket(Command::CMD_GET_UPTIME_CALLBACK, &callback, + sizeof(callback), PayloadType::Binary, token)); return true; } @@ -507,6 +515,65 @@ bool CHECK_PACKET_SIZE(Packet& pkt) { return true; } +bool SocketInterfaceTgBot::verifyHeader(const Packet& packet) { + std::string_view their_token(packet.header.session_token.data(), + Packet::Header::SESSION_TOKEN_LENGTH); + if (their_token.empty()) { + LOG(WARNING) << "Received packet with empty session token"; + return false; + } + if (!session_table.contains(their_token.data())) { + LOG(WARNING) << fmt::format( + "Received packet with unknown session token: {}", their_token); + return false; + } + auto& session = session_table[their_token.data()]; + if (std::chrono::system_clock::now() > session.expiry) { + LOG(WARNING) << fmt::format("Session token expired: {}", their_token); + session_table.erase(their_token.data()); + return false; + } + if (session.last_nonce >= packet.header.nonce) { + LOG(WARNING) << "Received packet with outdated nonce, ignore"; + return false; + } + session.last_nonce = packet.header.nonce; + return true; +} + +void SocketInterfaceTgBot::handle_OpenSession(const TgBotSocket::Context& ctx) { + auto key = StringTools::generateRandomString( + TgBotSocket::Packet::Header::SESSION_TOKEN_LENGTH - 1); + Packet::Header::nounce_type last_nounce{}; + auto tp = std::chrono::system_clock::now() + std::chrono::hours(1); + + LOG(INFO) << "Created new session with key: " << key; + session_table.emplace(key, Session(key, last_nounce, tp)); + + Json::Value response; + response["session_token"] = key; + response["expiration_time"] = fmt::format("{:%Y-%m-%d %H:%M:%S}", tp); + + Packet::Header::session_token_type session_token; + copyTo(session_token, key); + + ctx.write( + nodeToPacket(Command::CMD_OPEN_SESSION_ACK, response, session_token)); +} + +void SocketInterfaceTgBot::handle_CloseSession( + const TgBotSocket::Packet::Header::session_token_type& token) { + auto it = session_table.find(token.data()); + if (it != session_table.end()) { + session_table.erase(it); + LOG(INFO) << "Session with key " << token.data() << " closed"; + } else { + LOG(WARNING) + << "Received close session request for unknown session token: " + << token.data(); + } +} + void SocketInterfaceTgBot::handlePacket(const TgBotSocket::Context& ctx, TgBotSocket::Packet pkt) { const void* ptr = pkt.data.get(); @@ -539,20 +606,20 @@ void SocketInterfaceTgBot::handlePacket(const TgBotSocket::Context& ctx, pkt.header.data_type); break; case Command::CMD_GET_UPTIME: - ret = handle_GetUptime(ctx, ptr); + ret = handle_GetUptime(ctx, pkt.header.session_token); break; case Command::CMD_UPLOAD_FILE: ret = handle_UploadFile(ptr, pkt.header.data_size); break; case Command::CMD_UPLOAD_FILE_DRY: - if (CHECK_PACKET_SIZE(pkt)) { + if (CHECK_PACKET_SIZE(pkt)) { ret = handle_UploadFileDry(ptr, pkt.header.data_size); } else { ret = UploadFileDryCallback(invalidPacketAck); } break; case Command::CMD_DOWNLOAD_FILE: - ret = handle_DownloadFile(ctx, ptr); + ret = handle_DownloadFile(ctx, ptr, pkt.header.session_token); break; default: if (CommandHelpers::isClientCommand(pkt.header.cmd)) { @@ -574,8 +641,10 @@ void SocketInterfaceTgBot::handlePacket(const TgBotSocket::Context& ctx, } case Command::CMD_UPLOAD_FILE_DRY: { const auto result = std::get(ret); - Packet ackpkt(Command::CMD_UPLOAD_FILE_DRY_CALLBACK, &result, - sizeof(UploadFileDryCallback)); + auto ackpkt = + createPacket(Command::CMD_UPLOAD_FILE_DRY_CALLBACK, &result, + sizeof(UploadFileDryCallback), PayloadType::Binary, + pkt.header.session_token); LOG(INFO) << "Sending CMD_UPLOAD_FILE_DRY ack: " << std::boolalpha << (result.result == AckType::SUCCESS); ctx.write(ackpkt); @@ -592,12 +661,14 @@ void SocketInterfaceTgBot::handlePacket(const TgBotSocket::Context& ctx, << (result.result == AckType::SUCCESS); switch (pkt.header.data_type) { case PayloadType::Binary: { - Packet ackpkt(Command::CMD_GENERIC_ACK, &result, - sizeof(GenericAck)); + auto ackpkt = createPacket(Command::CMD_GENERIC_ACK, + &result, sizeof(GenericAck), + TgBotSocket::PayloadType::Binary, + pkt.header.session_token); ctx.write(ackpkt); } break; case PayloadType::Json: { - ctx.write(toJSONPacket(result)); + ctx.write(toJSONPacket(result, pkt.header.session_token)); } break; } break; diff --git a/src/socket/bot/SocketInterface.cpp b/src/socket/bot/SocketInterface.cpp index 16475942..90aa77d3 100644 --- a/src/socket/bot/SocketInterface.cpp +++ b/src/socket/bot/SocketInterface.cpp @@ -3,8 +3,11 @@ #include #include #include +#include #include +#include "TgBotSocket_Export.hpp" + SocketInterfaceTgBot::SocketInterfaceTgBot(TgBotSocket::Context* _interface, TgBotApi::Ptr _api, ChatObserver* observer, @@ -19,12 +22,27 @@ SocketInterfaceTgBot::SocketInterfaceTgBot(TgBotSocket::Context* _interface, resource(resource) {} void SocketInterfaceTgBot::runFunction(const std::stop_token& token) { - bool ret = _interface->listen([this](const TgBotSocket::Context& ctx) { - auto pkt = readPacket(ctx); - if (pkt) { - handlePacket(ctx, std::move(pkt.value())); - } - }); + bool ret = + _interface->listen([this, token](const TgBotSocket::Context& ctx) { + while (!token.stop_requested()) { + std::optional pkt; + pkt = readPacket(ctx); + if (!pkt) { + break; + } + if (pkt->header.cmd == TgBotSocket::Command::CMD_OPEN_SESSION) { + handle_OpenSession(ctx); + continue; + } else if (!verifyHeader(*pkt)) { + continue; + } else if (pkt->header.cmd == TgBotSocket::Command::CMD_CLOSE_SESSION) { + handle_CloseSession(pkt->header.session_token); + break; + } else { + handlePacket(ctx, std::move(pkt.value())); + } + } + }); if (!ret) { LOG(ERROR) << "Failed to start listening on socket"; } diff --git a/src/socket/bot/SocketInterface.hpp b/src/socket/bot/SocketInterface.hpp index 8d92de72..ccef9fc7 100644 --- a/src/socket/bot/SocketInterface.hpp +++ b/src/socket/bot/SocketInterface.hpp @@ -36,6 +36,15 @@ struct SocketInterfaceTgBot : ThreadRunner { std::chrono::system_clock::time_point startTp = std::chrono::system_clock::now(); + struct Session { + std::string session_key; // Randomly generated session key + TgBotSocket::Packet::Header::nounce_type + last_nonce; // To prevent replay attacks + std::chrono::system_clock::time_point expiry; // Session expiry + }; + + std::unordered_map session_table; + // Command handlers GenericAck handle_WriteMsgToChatId( const void* ptr, TgBotSocket::Packet::Header::length_type len, @@ -56,6 +65,15 @@ struct SocketInterfaceTgBot : ThreadRunner { const void* ptr, TgBotSocket::Packet::Header::length_type len); // These have their own ack handlers - bool handle_GetUptime(const TgBotSocket::Context& ctx, const void* ptr); - bool handle_DownloadFile(const TgBotSocket::Context& ctx, const void* ptr); + bool handle_GetUptime( + const TgBotSocket::Context& ctx, + const TgBotSocket::Packet::Header::session_token_type& token); + bool handle_DownloadFile( + const TgBotSocket::Context& ctx, const void* ptr, + const TgBotSocket::Packet::Header::session_token_type& token); + + void handle_OpenSession(const TgBotSocket::Context& ctx); + void handle_CloseSession(const TgBotSocket::Packet::Header::session_token_type& token); + + bool verifyHeader(const TgBotSocket::Packet& packet); }; diff --git a/src/socket/hash/CMakeLists.txt b/src/socket/hash/CMakeLists.txt new file mode 100644 index 00000000..73b14400 --- /dev/null +++ b/src/socket/hash/CMakeLists.txt @@ -0,0 +1,6 @@ +add_my_library( + NAME minimalhash + SRCS sha256.cpp hmac.cpp + LIBS OpenSSL::Crypto + STATIC +) \ No newline at end of file diff --git a/src/socket/hash/hmac.cpp b/src/socket/hash/hmac.cpp new file mode 100644 index 00000000..d803458b --- /dev/null +++ b/src/socket/hash/hmac.cpp @@ -0,0 +1,21 @@ +#include "hmac.hpp" + +#include + +HMAC::result_type HMAC::compute(const uint8_t* data, std::size_t length, + const std::string_view key) { + // Buffer to store the HMAC result + result_type hmac_result{}; + unsigned int hmac_length = 0; + + // Compute the HMAC + ::HMAC(EVP_sha256(), key.data(), key.length(), data, length, + hmac_result.data(), &hmac_length); + + if (hmac_length != EVP_MAX_MD_SIZE) { + for (unsigned int i = hmac_length; i < EVP_MAX_MD_SIZE; i++) { + hmac_result[i] = 0; // Pad with zeros if necessary + } + } + return hmac_result; +} \ No newline at end of file diff --git a/src/socket/hash/hmac.hpp b/src/socket/hash/hmac.hpp new file mode 100644 index 00000000..d32ed43a --- /dev/null +++ b/src/socket/hash/hmac.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include + +#include +#include +#include + +class HMAC { + public: + using result_type = std::array; + + static result_type compute(const uint8_t* data, std::size_t length, + const std::string_view key); +}; diff --git a/src/socket/hash/sha256.cpp b/src/socket/hash/sha256.cpp new file mode 100644 index 00000000..d357a6c3 --- /dev/null +++ b/src/socket/hash/sha256.cpp @@ -0,0 +1,7 @@ +#include "sha256.hpp" + +SHA256::result_type SHA256::compute(const uint8_t* data, std::size_t length) { + result_type result{}; + ::SHA256(data, length, result.data()); + return result; +} \ No newline at end of file diff --git a/src/hash/sha256.hpp b/src/socket/hash/sha256.hpp similarity index 63% rename from src/hash/sha256.hpp rename to src/socket/hash/sha256.hpp index 758467cb..0ac7a9d3 100644 --- a/src/hash/sha256.hpp +++ b/src/socket/hash/sha256.hpp @@ -3,11 +3,11 @@ #include #include -#include "sha-2/sha-256.h" +#include class SHA256 { public: - using result_type = std::array; + using result_type = std::array; static result_type compute(const uint8_t* data, std::size_t length); }; \ No newline at end of file diff --git a/tests/SocketDataHandlerTest.cpp b/tests/SocketDataHandlerTest.cpp index 7801516a..5114dc08 100644 --- a/tests/SocketDataHandlerTest.cpp +++ b/tests/SocketDataHandlerTest.cpp @@ -15,6 +15,8 @@ #include #include "SocketContext.hpp" +#include "TgBotSocket_Export.hpp" +#include "bot/PacketParser.hpp" #include "global_handlers/SpamBlock.hpp" #include "mocks/DatabaseBase.hpp" #include "mocks/ResourceProvider.hpp" @@ -123,7 +125,9 @@ class SocketDataHandlerTest : public ::testing::Test { TEST_F(SocketDataHandlerTest, TestCmdGetUptime) { // data Unused for GetUptime - TgBotSocket::Packet pkt(TgBotSocket::Command::CMD_GET_UPTIME, 0); + TgBotSocket::Packet pkt = + TgBotSocket::createPacket(TgBotSocket::Command::CMD_GET_UPTIME, nullptr, + 0, TgBotSocket::PayloadType::Binary, {}); TgBotSocket::callback::GetUptimeCallback callbackData{}; sendAndVerifyHeader(pkt, @@ -210,17 +217,19 @@ TEST_F(SocketDataHandlerTest, TestCmdWriteMsgToChatIdINVALID) { TEST_F(SocketDataHandlerTest, TestCmdUploadFileDryDoesntExist) { auto data = - TgBotSocket::data::UploadFileDry{.destfilepath = {"test"}, - .srcfilepath = {"testsrc"}, - .sha256_hash = {"asdqwdsadsad"}, - .options = { - .overwrite = true, - .hash_ignore = true, - .dry_run = true, - }}; + TgBotSocket::data::UploadFileMeta{.destfilepath = {"test"}, + .srcfilepath = {"testsrc"}, + .sha256_hash = {"asdqwdsadsad"}, + .options = { + .overwrite = true, + .hash_ignore = true, + .dry_run = true, + }}; // Set expectations - TgBotSocket::Packet pkt(TgBotSocket::Command::CMD_UPLOAD_FILE_DRY, data); + TgBotSocket::Packet pkt = TgBotSocket::createPacket( + TgBotSocket::Command::CMD_UPLOAD_FILE_DRY, &data, sizeof(data), + TgBotSocket::PayloadType::Binary, {}); EXPECT_CALL(*_mockVFS, exists(FSP(data.destfilepath.data()))) .WillOnce(Return(false)); @@ -237,19 +246,21 @@ TEST_F(SocketDataHandlerTest, TestCmdUploadFileDryDoesntExist) { TEST_F(SocketDataHandlerTest, TestCmdUploadFileDryExistsHashDoesntMatch) { auto data = - TgBotSocket::data::UploadFileDry{.destfilepath = {"test"}, - .srcfilepath = {"testsrc"}, - .sha256_hash = {"asdqwdsadsad"}, - .options = { - .overwrite = true, - .hash_ignore = false, - .dry_run = true, - }}; + TgBotSocket::data::UploadFileMeta{.destfilepath = {"test"}, + .srcfilepath = {"testsrc"}, + .sha256_hash = {"asdqwdsadsad"}, + .options = { + .overwrite = true, + .hash_ignore = false, + .dry_run = true, + }}; // Prepare file contents const auto fileMem = createFileMem(); // Set expectations - TgBotSocket::Packet pkt(TgBotSocket::Command::CMD_UPLOAD_FILE_DRY, data); + TgBotSocket::Packet pkt = TgBotSocket::createPacket( + TgBotSocket::Command::CMD_UPLOAD_FILE_DRY, &data, sizeof(data), + TgBotSocket::PayloadType::Binary, {}); EXPECT_CALL(*_mockVFS, exists(FSP(data.destfilepath.data()))) .WillOnce(Return(true)); EXPECT_CALL(*_mockVFS, readFile(FSP(data.destfilepath.data()))) @@ -269,19 +280,21 @@ TEST_F(SocketDataHandlerTest, TestCmdUploadFileDryExistsHashDoesntMatch) { TEST_F(SocketDataHandlerTest, TestCmdUploadFileDryExistsOptSaidNo) { auto data = - TgBotSocket::data::UploadFileDry{.destfilepath = {"test"}, - .srcfilepath = {"testsrc"}, - .sha256_hash = {"asdqwdsadsad"}, - .options = { - .overwrite = false, - .hash_ignore = false, - .dry_run = true, - }}; + TgBotSocket::data::UploadFileMeta{.destfilepath = {"test"}, + .srcfilepath = {"testsrc"}, + .sha256_hash = {"asdqwdsadsad"}, + .options = { + .overwrite = false, + .hash_ignore = false, + .dry_run = true, + }}; // Prepare file contents const auto fileMem = createFileMem(); // Set expectations - TgBotSocket::Packet pkt(TgBotSocket::Command::CMD_UPLOAD_FILE_DRY, data); + TgBotSocket::Packet pkt = TgBotSocket::createPacket( + TgBotSocket::Command::CMD_UPLOAD_FILE_DRY, &data, sizeof(data), + TgBotSocket::PayloadType::Binary, {}); EXPECT_CALL(*_mockVFS, exists(std::filesystem::path(data.destfilepath.data()))) .WillOnce(Return(true)); @@ -308,11 +321,13 @@ TEST_F(SocketDataHandlerTest, TestCmdUploadFileOK) { uploadfile->options.hash_ignore = true; uploadfile->options.overwrite = true; uploadfile->options.dry_run = false; - mem.assignTo(filemem.get(), filemem.size(), sizeof(TgBotSocket::data::UploadFile)); + mem.assignTo(filemem.get(), filemem.size(), + sizeof(TgBotSocket::data::UploadFile)); // Set expectations - TgBotSocket::Packet pkt(TgBotSocket::Command::CMD_UPLOAD_FILE, mem.get(), - mem.size()); + TgBotSocket::Packet pkt = TgBotSocket::createPacket( + TgBotSocket::Command::CMD_UPLOAD_FILE_DRY, mem.get(), mem.size(), + TgBotSocket::PayloadType::Binary, {}); EXPECT_CALL(*_mockVFS, writeFile(FSP(uploadfile->destfilepath.data()), _, filemem.size())) .WillOnce(Return(true));