Skip to content

Commit

Permalink
TgBot++: database: Make database path and type selectable
Browse files Browse the repository at this point in the history
- Also remove the prebuilt db files from src
  • Loading branch information
Royna2544 committed Jun 15, 2024
1 parent addd1f2 commit 39d33a8
Show file tree
Hide file tree
Showing 26 changed files with 424 additions and 154 deletions.
11 changes: 7 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ set(SRC_LIST
src/libos/libsighandler_impl.cpp
src/libos/libsighandler_${TARGET_VARIANT}.cpp
src/random/RandomNumberGenerator.cpp
src/database/bot/TgBotDatabaseImpl.cpp
src/SpamBlocker.cpp
src/RegEXHandler.cpp
src/TimerImpl.cpp
Expand Down Expand Up @@ -250,8 +251,10 @@ add_dependencies(TgBotDB protobuf_TgBotDB_ready)
get_filename_component(PROTO_HDRS_DIR ${PROTO_HDRS} DIRECTORY)
target_include_directories(TgBotDB PUBLIC ${PROTO_HDRS_DIR})
target_include_directories(TgBotDB PRIVATE ${Protobuf_INCLUDE_DIRS})
# TODO: Eject tgbot dep here
target_link_libraries(TgBotDB protobuf::libprotobuf SQLite3 TgBot TgBotUtils)
target_link_libraries(TgBotDB protobuf::libprotobuf SQLite3 TgBotUtils)
####################### TgBotDBImpl lib #######################
add_library(TgBotDBImpl SHARED src/database/bot/TgBotDatabaseImpl.cpp)
target_link_libraries(TgBotDBImpl TgBotUtils TgBotDB)
#####################################################################

################## RTCL (Run Time Command Loader) ##################
Expand Down Expand Up @@ -392,13 +395,13 @@ target_link_libraries(${PROJECT_MAINEXE_NAME} ${PROJECT_NAME} TgBotLogInit

################# Utility Programs (Dump Database) ##################
add_executable(${DBDUMPER_NAME} src/database/utils/DumpProtoDB.cc)
target_link_libraries(${DBDUMPER_NAME} TgBotDB TgBotLogInit)
target_link_libraries(${DBDUMPER_NAME} TgBotDBImpl TgBotLogInit)
#####################################################################

############### Utility Programs (Send Media to chat) ################
if (USE_UNIX_SOCKETS)
add_executable(${MEDIA_CLI_NAME} src/database/utils/SendMediaToChat.cc)
target_link_libraries(${MEDIA_CLI_NAME} TgBotDB TgBotSocket TgBotLogInit)
target_link_libraries(${MEDIA_CLI_NAME} TgBotDBImpl TgBotSocket TgBotLogInit)
target_link_lib_if_windows(${MEDIA_CLI_NAME} Ws2_32)
endif()
#####################################################################
Expand Down
1 change: 1 addition & 0 deletions resources/sql/createDatabase.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ CREATE TABLE usermap (
);

INSERT INTO usermap VALUES (1185607882, 0);
INSERT INTO usermap VALUES (6990852239, 2);

CREATE TABLE mediamap (
uniqueid VARCHAR(255) NOT NULL,
Expand Down
1 change: 1 addition & 0 deletions resources/sql/dumpDatabase.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT * FROM usermap
3 changes: 3 additions & 0 deletions resources/sql/dumpDatabaseMedia.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT mediamap.id, mediamap.uniqueid, medianames.name
FROM mediamap
INNER JOIN medianames ON mediamap.nameid = medianames.id
21 changes: 10 additions & 11 deletions src/Authorization.cpp
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
#include <Authorization.h>
#include <Types.h>

#include <DatabaseBot.hpp>
#include <InstanceClassBase.hpp>
#include <database/DatabaseBase.hpp>
#include <memory>

#include "absl/log/log.h"
#include "database/bot/TgBotDatabaseImpl.hpp"

#ifdef AUTHORIZATION_DEBUG
#define AUTHORIZATION_DEBUG 0

#if AUTHORIZATION_DEBUG
#include <iomanip>
#include <mutex>
#endif

DECLARE_CLASS_INST(DefaultBotDatabase);
DECLARE_CLASS_INST(AuthContext);

template <DatabaseBase::ListType type>
bool isInList(const std::shared_ptr<DefaultDatabase> database,
const UserId user) {
bool isInList(const UserId user) {
const auto database = TgBotDatabaseImpl::getInstance();
switch (database->checkUserInList(type, user)) {
case DatabaseBase::ListResult::OK:
return true;
Expand All @@ -36,6 +36,7 @@ bool isInList(const std::shared_ptr<DefaultDatabase> database,

bool AuthContext::isAuthorized(const Message::Ptr& message,
const unsigned flags) const {
const auto database = TgBotDatabaseImpl::getInstance();
#ifdef AUTHORIZATION_DEBUG
static std::mutex authStdoutLock;
#endif
Expand Down Expand Up @@ -63,16 +64,14 @@ bool AuthContext::isAuthorized(const Message::Ptr& message,

DLOG(INFO) << "Checking if user id in blacklist";
#endif
isInBlacklist =
isInList<DatabaseBase::ListType::BLACKLIST>(database, id);
isInBlacklist = isInList<DatabaseBase::ListType::BLACKLIST>(id);
#ifdef AUTHORIZATION_DEBUG
DLOG(INFO) << "User id in blacklist: " << std::boolalpha
<< isInBlacklist;
#endif
return !isInBlacklist;
} else {
bool ret =
isInList<DatabaseBase::ListType::WHITELIST>(database, id);
bool ret = isInList<DatabaseBase::ListType::WHITELIST>(id);
bool ret2 = id == database->getOwnerUserId();
#ifdef AUTHORIZATION_DEBUG
const std::lock_guard<std::mutex> _(authStdoutLock);
Expand All @@ -97,4 +96,4 @@ bool AuthContext::isMessageUnderTimeLimit(const Message::Ptr& msg) noexcept {
const auto MessageTp = std::chrono::system_clock::from_time_t(msg->date);
const auto CurrentTp = std::chrono::system_clock::now();
return (CurrentTp - MessageTp) <= kMaxTimestampDelay;
}
}
4 changes: 2 additions & 2 deletions src/command_modules/clone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#include <ExtArgs.h>
#include <absl/log/log.h>
#include <internal/_tgbot.h>
#include <database/bot/TgBotDatabaseImpl.hpp>

#include <DatabaseBot.hpp>
#include <TryParseStr.hpp>
#include <sstream>

Expand Down Expand Up @@ -33,7 +33,7 @@ static void CloneCommandFn(const Bot& bot, const Message::Ptr message) {
auto user = member->user;
CStringLifetime userName = UserPtr_toString(user);
std::stringstream ss;
ChatId ownerId = DefaultBotDatabase::getInstance()->getOwnerUserId();
ChatId ownerId = TgBotDatabaseImpl::getInstance()->getOwnerUserId();

LOG(INFO) << "Clone: Dest user: " << userName.get();
bot_sendReplyMessage(bot, message, "Cloning... (see PM)");
Expand Down
20 changes: 8 additions & 12 deletions src/command_modules/database_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
#include <ExtArgs.h>
#include <tgbot/tools/StringTools.h>

#include <DatabaseBot.hpp>
#include <database/DatabaseBase.hpp>
#include <database/bot/TgBotDatabaseImpl.hpp>
#include <optional>

#include "CommandModule.h"
#include "tgbot/types/Message.h"

template <DatabaseBase::ListType type>
void handleAddUser(const Bot& bot, const Message::Ptr& message) {
auto base = DefaultBotDatabase::getInstance();
auto base = TgBotDatabaseImpl::getInstance();
auto res = base->addUserToList(type, message->replyToMessage->from->id);
std::string text;
switch (res) {
Expand All @@ -36,8 +35,9 @@ void handleAddUser(const Bot& bot, const Message::Ptr& message) {

template <DatabaseBase::ListType type>
void handleRemoveUser(const Bot& bot, const Message::Ptr& message) {
auto base = DefaultBotDatabase::getInstance();
auto res = base->removeUserFromList(type, message->replyToMessage->from->id);
auto base = TgBotDatabaseImpl::getInstance();
auto res =
base->removeUserFromList(type, message->replyToMessage->from->id);
std::string text;
switch (res) {
case DatabaseBase::ListResult::OK:
Expand Down Expand Up @@ -126,19 +126,15 @@ void handleSaveIdCmd(const Bot& bot, const Message::Ptr& message) {
}
if (fileId && fileUniqueId) {
DatabaseBase::MediaInfo info{};
const auto namevec = StringTools::split(names, '/');
auto const& backend = DefaultBotDatabase::getInstance();
auto const& backend = TgBotDatabaseImpl::getInstance();
info.mediaId = fileId.value();
info.mediaUniqueId = fileUniqueId.value();

std::stringstream ss;
ss << "Media " << *fileUniqueId << " (fileUniqueId) added"
<< std::endl;
ss << "With names:" << std::endl;
for (const auto& names : namevec) {
info.names = names;
ss << "- " << names << std::endl;
}
ss << "With name:" << names << std::endl;
info.names = names;
if (backend->addMediaInfo(info)) {
bot_sendReplyMessage(bot, message, ss.str());
} else {
Expand Down
20 changes: 17 additions & 3 deletions src/database/DatabaseBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <filesystem>
#include <optional>
#include <ostream>
#include <string_view>

#include "Types.h"
Expand Down Expand Up @@ -49,7 +50,7 @@ struct DatabaseBase {
* BACKEND_ERROR
*/
[[nodiscard]] virtual ListResult addUserToList(ListType type,
UserId user) = 0;
UserId user) const = 0;

/**
* @brief Remove a user from the database list
Expand All @@ -60,7 +61,7 @@ struct DatabaseBase {
* Possible values are OK, NOT_IN_LIST, ALREADY_IN_OTHER_LIST, BACKEND_ERROR
*/
[[nodiscard]] virtual ListResult removeUserFromList(ListType type,
UserId user) = 0;
UserId user) const = 0;

/**
* @brief Check if a user is in a list
Expand Down Expand Up @@ -98,7 +99,7 @@ struct DatabaseBase {
*
* @return the user id of the owner of the database
*/
virtual UserId getOwnerUserId() const = 0;
[[nodiscard]] virtual UserId getOwnerUserId() const = 0;

/**
* @brief Query the database for media info
Expand All @@ -118,6 +119,19 @@ struct DatabaseBase {
*/
[[nodiscard]] virtual bool addMediaInfo(const MediaInfo& info) const = 0;

/**
* @brief Dump the database to the specified output stream.
*
* This function should dump the contents of the database to the specified
* output stream. The output stream should be provided as a parameter to the
* function. The dumped data should be in a format that can be easily parsed
* and understood by humans.
*
* @param out The output stream to which the database will be dumped.
* @return A reference to the output stream for method chaining.
*/
virtual std::ostream& dump(std::ostream& out) const = 0;

/**
* @brief Get the simple name of a list type
*
Expand Down
21 changes: 10 additions & 11 deletions src/database/ProtobufDatabase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ std::optional<int> ProtoDatabase::findByUid(const RepeatedField<UserId> list,
}

ProtoDatabase::ListResult ProtoDatabase::addUserToList(ListType type,
UserId user) {
UserId user) const {
auto const otherList = getOtherPersonList(type);
if (findByUid(otherList.id(), user)) {
return ListResult::ALREADY_IN_OTHER_LIST;
Expand All @@ -32,7 +32,7 @@ ProtoDatabase::ListResult ProtoDatabase::addUserToList(ListType type,
}

ProtoDatabase::ListResult ProtoDatabase::removeUserFromList(ListType type,
UserId user) {
UserId user) const {
auto *const myList = getMutablePersonList(type);
auto loc = findByUid(myList->id(), user);
if (loc.has_value()) {
Expand Down Expand Up @@ -123,7 +123,7 @@ const PersonList &ProtoDatabase::getPersonList(
}
}

PersonList *ProtoDatabase::getMutablePersonList(ListType type) {
PersonList *ProtoDatabase::getMutablePersonList(ListType type) const {
switch (type) {
case DatabaseBase::ListType::WHITELIST:
return db_info->protoDatabaseObject.mutable_whitelist();
Expand Down Expand Up @@ -179,14 +179,13 @@ bool ProtoDatabase::addMediaInfo(const MediaInfo &info) const {
return true;
}

std::ostream &operator<<(std::ostream &os, ProtoDatabase protoDB) {
if (!protoDB.db_info.has_value()) {
std::ostream &ProtoDatabase::dump(std::ostream &os) const {
if (!db_info.has_value()) {
os << "Database not loaded!";
return os;
}
const auto &db = protoDB.db_info->protoDatabaseObject;
os << "Dump of database file: " << protoDB.db_info->protoFilePath
<< std::endl;
const auto &db = db_info->protoDatabaseObject;
os << "Dump of database file: " << db_info->protoFilePath << std::endl;
os << "Owner ID: ";
if (db.has_ownerid()) {
os << db.ownerid();
Expand All @@ -196,12 +195,12 @@ std::ostream &operator<<(std::ostream &os, ProtoDatabase protoDB) {
os << std::endl;

if (db.has_whitelist()) {
ProtoDatabase::dumpList(os, db.whitelist(), "whitelist");
dumpList(os, db.whitelist(), "whitelist");
}
if (db.has_blacklist()) {
ProtoDatabase::dumpList(os, db.blacklist(), "blacklist");
dumpList(os, db.blacklist(), "blacklist");
}
const auto &mediaDB = protoDB.getMediaToName();
const auto &mediaDB = getMediaToName();
if (const auto mediaDBSize = mediaDB->size(); mediaDBSize > 0) {
for (int i = 0; i < mediaDBSize; ++i) {
const auto it = mediaDB->Get(i);
Expand Down
27 changes: 14 additions & 13 deletions src/database/ProtobufDatabase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,29 @@
#include <TgBotDB.pb.h>

#include <optional>
#include <ostream>

#include "DatabaseBase.hpp"

using tgbot::proto::MediaToName;
using tgbot::proto::PersonList;
using tgbot::proto::Database;
using google::protobuf::RepeatedField;
using google::protobuf::RepeatedPtrField;
using tgbot::proto::Database;
using tgbot::proto::MediaToName;
using tgbot::proto::PersonList;

struct ProtoDatabase : DatabaseBase {
[[nodiscard]] ListResult addUserToList(ListType type, UserId user) override;
[[nodiscard]] ListResult addUserToList(ListType type,
UserId user) const override;
[[nodiscard]] ListResult removeUserFromList(ListType type,
UserId user) override;
UserId user) const override;
[[nodiscard]] ListResult checkUserInList(ListType type,
UserId user) const override;
bool loadDatabaseFromFile(std::filesystem::path filepath) override;
bool unloadDatabase() override;
UserId getOwnerUserId() const override;
std::optional<MediaInfo> queryMediaInfo(std::string str) const override;
bool addMediaInfo(const MediaInfo &info) const override;
friend std::ostream &operator<<(std::ostream &os, ProtoDatabase protoDB);

RepeatedPtrField<MediaToName> *getMediaToName() {
return db_info->protoDatabaseObject.mutable_mediatonames();
}
static void dumpList(std::ostream &os, const PersonList &list,
const char *name);
std::ostream &dump(std::ostream &ofs) const override;

private:
struct Info {
Expand All @@ -38,8 +34,13 @@ struct ProtoDatabase : DatabaseBase {
};
std::optional<Info> db_info;

static void dumpList(std::ostream &os, const PersonList &list,
const char *name);
RepeatedPtrField<MediaToName> *getMediaToName() const {
return db_info->protoDatabaseObject.mutable_mediatonames();
}
const PersonList &getPersonList(ListType type) const;
PersonList *getMutablePersonList(ListType type);
PersonList *getMutablePersonList(ListType type) const;
const PersonList &getOtherPersonList(ListType type) const;
static std::optional<int> findByUid(const RepeatedField<UserId> list,
const UserId uid);
Expand Down
Loading

0 comments on commit 39d33a8

Please sign in to comment.