diff --git a/src/socket/bot/PacketParser.cpp b/src/socket/bot/PacketParser.cpp index 03fce66c..40012d15 100644 --- a/src/socket/bot/PacketParser.cpp +++ b/src/socket/bot/PacketParser.cpp @@ -232,25 +232,37 @@ std::optional readPacket(const TgBotSocket::Context& context) { LOG(ERROR) << "While reading data, failed"; return std::nullopt; } + packet.data = *data; + if (!decryptPacket(packet)) { + return std::nullopt; + } + return packet; +} + +bool Socket_API decryptPacket(TgBotSocket::Packet& packet) { + auto& header = packet.header; + auto& data = packet.data; + std::string_view session_token(header.session_token.data(), header.session_token.size()); - if (session_token.empty()) { - LOG(WARNING) << "No session token provided"; - return std::nullopt; + if (header.session_token == + TgBotSocket::Packet::Header::session_token_type{}) { + LOG(WARNING) << "No session token provided, decryption will be skipped."; + return true; } if (packet.header.hmac != - HMAC::compute(static_cast(data->get()), data->size(), + HMAC::compute(static_cast(data.get()), data.size(), session_token)) { LOG(ERROR) << "HMAC mismatch"; - return std::nullopt; + return false; } - packet.data = decrypt_payload(header.session_token, data.value(), - packet.header.init_vector); + packet.data = + decrypt_payload(header.session_token, data, packet.header.init_vector); if (!static_cast(packet.data)) { LOG(ERROR) << "Decryption failed"; - return std::nullopt; + return false; } - return packet; + return true; } Packet Socket_API @@ -266,12 +278,18 @@ createPacket(const Command command, const void* data, if (data != nullptr && length > 0) { packet.data.resize(length); packet.data.assignFrom(data, length); - packet.data = encrypt_payload(sessionToken, packet.data, - packet.header.init_vector); + + if (sessionToken != Packet::Header::session_token_type{}) { + packet.data = encrypt_payload(sessionToken, packet.data, + packet.header.init_vector); + packet.header.hmac = + HMAC::compute(static_cast(packet.data.get()), + packet.header.data_size, sessionToken.data()); + } else { + LOG(WARNING) + << "No session token provided, encryption will be skipped"; + } packet.header.data_size = packet.data.size(); - packet.header.hmac = - HMAC::compute(static_cast(packet.data.get()), - packet.header.data_size, sessionToken.data()); } return packet; } diff --git a/src/socket/bot/PacketParser.hpp b/src/socket/bot/PacketParser.hpp index 1949f827..4e1037b7 100644 --- a/src/socket/bot/PacketParser.hpp +++ b/src/socket/bot/PacketParser.hpp @@ -27,6 +27,18 @@ namespace TgBotSocket { std::optional Socket_API readPacket(const TgBotSocket::Context& context); + +/** + * @brief Decrypts a packet using the provided context. + * + * This function attempts to decrypt the given packet using the provided context. + * If successful, it returns `true`. If decryption fails, it returns `false`. + * + * @param packet The packet to decrypt + * @return `true` if the packet was successfully decrypted; otherwise, `false`. + */ +bool Socket_API decryptPacket(TgBotSocket::Packet& packet); + /** * @brief Creates a packet with the given command and data. * diff --git a/src/third-party/tgbot-cpp b/src/third-party/tgbot-cpp index 6f519c7f..35cc4ad4 160000 --- a/src/third-party/tgbot-cpp +++ b/src/third-party/tgbot-cpp @@ -1 +1 @@ -Subproject commit 6f519c7faf3ea7be26b29e58ccb6c682e4dc208f +Subproject commit 35cc4ad40090a373a56bb131d659ab41bbea46b4 diff --git a/tests/SocketDataHandlerTest.cpp b/tests/SocketDataHandlerTest.cpp index 5114dc08..acfaa9ae 100644 --- a/tests/SocketDataHandlerTest.cpp +++ b/tests/SocketDataHandlerTest.cpp @@ -75,6 +75,7 @@ class SocketDataHandlerTest : public ::testing::Test { SharedMalloc packetData; TgBotSocket::Packet::Header recv_header; + EXPECT_TRUE(TgBotSocket::decryptPacket(pkt)); EXPECT_CALL(*_mockImpl, write(_)) .WillOnce(DoAll(SaveArg<0>(&packetData), Return(true))); mockInterface->handlePacket(*_mockImpl, std::move(pkt)); @@ -313,7 +314,8 @@ TEST_F(SocketDataHandlerTest, TestCmdUploadFileDryExistsOptSaidNo) { TEST_F(SocketDataHandlerTest, TestCmdUploadFileOK) { // Prepare file contents const auto filemem = createFileMem(); - SharedMalloc mem(sizeof(TgBotSocket::data::UploadFile) + filemem.size()); + SharedMalloc mem(sizeof(TgBotSocket::data::UploadFileMeta) + + filemem.size()); auto* uploadfile = static_cast(mem.get()); uploadfile->srcfilepath = {"sourcefile"}; uploadfile->destfilepath = {"destinationfile"}; @@ -322,11 +324,11 @@ TEST_F(SocketDataHandlerTest, TestCmdUploadFileOK) { uploadfile->options.overwrite = true; uploadfile->options.dry_run = false; mem.assignTo(filemem.get(), filemem.size(), - sizeof(TgBotSocket::data::UploadFile)); + sizeof(TgBotSocket::data::UploadFileMeta)); // Set expectations TgBotSocket::Packet pkt = TgBotSocket::createPacket( - TgBotSocket::Command::CMD_UPLOAD_FILE_DRY, mem.get(), mem.size(), + TgBotSocket::Command::CMD_UPLOAD_FILE, mem.get(), mem.size(), TgBotSocket::PayloadType::Binary, {}); EXPECT_CALL(*_mockVFS, writeFile(FSP(uploadfile->destfilepath.data()), _, filemem.size()))