diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..0b1e1e7 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,27 @@ +**/__pycache__ +**/.venv +**/.classpath +**/.dockerignore +**/.env +**/.git +**/.gitignore +**/.project +**/.settings +**/.toolstarget +**/.vs +**/.vscode +**/*.*proj.user +**/*.dbmdl +**/*.jfm +**/bin +**/charts +**/docker-compose* +**/compose* +**/Dockerfile* +**/node_modules +**/npm-debug.log +**/obj +**/secrets.dev.yaml +**/values.dev.yaml +LICENSE +README.md diff --git a/.gitignore b/.gitignore index 89a0f39..5a97a22 100644 --- a/.gitignore +++ b/.gitignore @@ -44,6 +44,9 @@ cloudflare.ini # ignore shell configs *.config +# ignore sensitve shell scripts +*.shs + ### OS generated files ### # mac files diff --git a/bot.py b/bot.py index d60f883..39cdbcc 100644 --- a/bot.py +++ b/bot.py @@ -2,35 +2,34 @@ This is the main file for the bot. """ +import asyncio import logging -from os.path import expanduser import discord from discord.ext import commands -from helpers.database.connection import getDbConnection -from helpers.get_file import getFile +from helpers.database.connection import DatabaseConnection +from helpers.get_ssh_key import getPubKey from helpers.logs import Logger from helpers.env import getEnvVar from helpers.terminal_colors import TerminalColors -from helpers.guild.add_guilds import addGuild from helpers.help_command import getHelpCommand from helpers.core_cogs import loadCoreCogs +from event_handlers.guilds import setupGuildEvents +from event_handlers.bot import setupBotEvents + def main(): """ Main entry point for the bot. """ # Set up overall logging - Logger.setup_logging(int(getEnvVar("DISCORD_BOT_LOG_LEVEL"))) + Logger.setup_logging(int(getEnvVar("DISCORD_BOT_LOG_LEVEL", "20"))) # Set up the logger startup_logging = logging.getLogger("discord.bot.startup") - # check if the bot should use a database - use_database = getEnvVar("DISCORD_USE_DATABASE") - # Create a new bot instance bot = commands.Bot( - command_prefix=getEnvVar("DISCORD_BOT_COMMAND_PREFIX"), + command_prefix=getEnvVar("DISCORD_BOT_COMMAND_PREFIX", "."), intents=discord.Intents.all(), description=getEnvVar("DISCORD_BOT_DESCRIPTION"), owner_id=int(getEnvVar("DISCORD_BOT_OWNER_ID")), @@ -38,87 +37,29 @@ def main(): help_command=getHelpCommand(), ) - if use_database == "True": - # Test the database connection - db_conn = getDbConnection() - if db_conn is not None: - startup_logging.info("Connected to the database.") - db_conn.close() - - @bot.event - async def on_guild_join(guild): # pylint: disable=invalid-name - """ - Event: Bot joins a guild - - This event is called when the bot joins a new guild and - adds the guild to the database. - """ - if use_database == "True": - # Add the guild to the database - db_conn = getDbConnection() - if db_conn is not None: - addGuild(db_conn, guild) - db_conn.close() - - @bot.event - async def on_ready(): # pylint: disable=invalid-name - """ - Event: Bot is ready - - This event is called when the bot is ready to be used and - prints information about the bot. - """ - - if bot.user is not None: - startup_logging.info("Logged in as %s", bot.user.name) - - # Print the join URL - startup_logging.info( - "Invite URL: \ -%shttps://discord.com/api/oauth2/authorize?\ -client_id=%s&permissions=8&scope=bot%s", - TerminalColors.GREEN_BOLD, - bot.user.id, - TerminalColors.RESET_COLOR - ) - - # list all servers the bot is connected to - if bot.user is not None: - startup_logging.info( - "%s%s%s is connected to %s%s guilds %s", - TerminalColors.GREEN_BOLD, - bot.user.name, - TerminalColors.RESET_COLOR, - TerminalColors.GREEN_BOLD, - len(bot.guilds), - TerminalColors.RESET_COLOR - ) - - startup_logging.info("Loading core cogs...") - await loadCoreCogs(bot, "core") - - # check if using root or user ssh key if not set default to root - use_user = getEnvVar("DISCORD_USE_USER_SSH") or "False" - - if use_user == "True": - ssh_file = expanduser("~") + "/.ssh/id_ed25519.pub" - else: - ssh_file = "/root/.ssh/id_ed25519.pub" + if getEnvVar("DISCORD_USE_DATABASE", "False") == "True": + database_connection: DatabaseConnection = DatabaseConnection() # Read public ssh key from file and log it startup_logging.info( "Public SSH key: %s%s%s", TerminalColors.GREEN_BOLD, - getFile(ssh_file).strip("\n"), - TerminalColors.RESET_COLOR + getPubKey(), + TerminalColors.RESET_COLOR, ) + # instantiate the bot events + asyncio.run(setupBotEvents(bot)) + # Run the bot bot.run( getEnvVar("DISCORD_BOT_TOKEN"), log_handler=None, ) + # Close the database connection before exiting + database_connection.close_connection() + # Run the main function if __name__ == "__main__": diff --git a/core/admin/admin.py b/core/admin.py similarity index 100% rename from core/admin/admin.py rename to core/admin.py diff --git a/event_handlers/bot.py b/event_handlers/bot.py new file mode 100644 index 0000000..4557ae2 --- /dev/null +++ b/event_handlers/bot.py @@ -0,0 +1,60 @@ +""" +This file contains the operations to setup a new guild +in the database, and is used by the bot when a new guild is added. +""" + +import logging +from discord.ext import commands +from event_handlers.guilds import addMissingGuilds, setupGuildEvents +from helpers.env import getEnvVar +from helpers.terminal_colors import TerminalColors +from helpers.core_cogs import loadCoreCogs + + +logger = logging.getLogger("discord.bot.events") + + +async def setupBotEvents(bot: commands.Bot): + """ + This function sets up the bot events + """ + + @bot.event + async def on_ready(): # pylint: disable=invalid-name + """ + Event: Bot is ready + + This event is called when the bot is ready to be used and + prints information about the bot. + """ + + if bot.user is not None: + + # Print the join URL + logger.info( + "Invite URL: \ +%shttps://discord.com/api/oauth2/authorize?\ +client_id=%s&permissions=8&scope=bot%s", + TerminalColors.GREEN_BOLD, + bot.user.id, + TerminalColors.RESET_COLOR, + ) + + logger.info( + "%s%s%s is connected to %s%s guilds %s", + TerminalColors.GREEN_BOLD, + bot.user.name, + TerminalColors.RESET_COLOR, + TerminalColors.GREEN_BOLD, + len(bot.guilds), + TerminalColors.RESET_COLOR, + ) + + logger.info("Loading core cogs...") + await loadCoreCogs(bot, "core") + + if getEnvVar("DISCORD_USE_DATABASE", "False") == "True": + await setupGuildEvents(bot) + addMissingGuilds(bot) + + logger.info("Bot is ready") diff --git a/event_handlers/guilds.py b/event_handlers/guilds.py new file mode 100644 index 0000000..9344e1a --- /dev/null +++ b/event_handlers/guilds.py @@ -0,0 +1,100 @@ +""" +This file contains the operations to setup a new guild +in the database, and is used by the bot when a new guild is added. +""" + +import logging +from pymongo import MongoClient +from discord import Guild +from discord.ext import commands +from helpers.database.connection import DatabaseConnection +from helpers.env import getEnvVar + +logger = logging.getLogger("discord.guilds") + + +async def setupGuildEvents(bot: commands.Bot): + """ + This function sets up the bot events for the guildis + """ + + # get the database connection + db_connection = DatabaseConnection().get_connection() + + @bot.event + async def on_guild_join(guild: Guild): # pylint: disable=invalid-name + """ + Add the guild to the database when the bot joins a new guild. + """ + if getEnvVar("DISCORD_USE_DATABASE", "False") == "True": + # Add the guild to the database + addGuild(db_connection, guild) + + @bot.event + async def on_guild_remove(guild: Guild): # pylint: disable=invalid-name + """ + Remove the guild from the database when the bot leaves a guild. + """ + if getEnvVar("DISCORD_USE_DATABASE", "False") == "True": + # Remove the guild to the database + removeGuild(db_connection, guild) + + +def addGuild(db_connection: MongoClient, guild: Guild) -> bool: + """ + Add a new guild to the database. + """ + # create the database + guild_collection = db_connection["guilds"].create_collection(str(guild.id)) + + if guild_collection.name == str(guild.id): + logger.info( + "Database %s for guild %s created successfully", guild.id, guild.name + ) + return True + + logger.error( + "You shouldn't see this message, check that guild %s was created successfully", + guild.name, + ) + + return False + + +def removeGuild(db_connection: MongoClient, guild: Guild) -> bool: + """ + Remove a guild from the database. + """ + + # Connect to the guilds database + db_connection["guilds"].drop_collection(str(guild.id)) + + # check if the database was deleted + if str(guild.id) not in db_connection["guilds"].list_collection_names(): + logger.info( + "Database %s for guild %s removed successfully", guild.id, guild.name + ) + return True + + logger.error( + "You shouldn't see this message, check that guild %s was removed successfully", + guild.name, + ) + + return False + + +def addMissingGuilds(bot: commands.Bot): + """ + Add the guilds that the bot is already in to the database. + """ + # get the database connection + db_connection = DatabaseConnection().get_connection() + + # existing guilds in the database + guilds = db_connection["guilds"].list_collection_names() + + # add the guilds to the database + for guild in bot.guilds: + if str(guild.id) not in guilds: + addGuild(db_connection, guild) diff --git a/helpers/core_cogs.py b/helpers/core_cogs.py index 4fad862..d287f83 100644 --- a/helpers/core_cogs.py +++ b/helpers/core_cogs.py @@ -10,9 +10,8 @@ from discord.ext import commands from helpers.logs import TerminalColors -logger = logging.getLogger("discord.core.cog.loader") -def getClassName(filename: str) -> str: +def getClassName(filename: str, logger: logging.Logger) -> str: """ Get the class name from the filename. """ @@ -21,35 +20,74 @@ def getClassName(filename: str) -> str: if line.startswith("class"): return line.split(" ")[1].split("(")[0] - logger.warning("Could not find class name in file %s%s%s", - TerminalColors.RED_BOLD, - filename, - TerminalColors.RESET_COLOR) + logger.warning( + "Could not find class name in file %s%s%s", + TerminalColors.RED_BOLD, + filename, + TerminalColors.RESET_COLOR, + ) sys.exit(1) -async def loadCoreCogs(bot: commands.Bot , directory: str) -> None: +async def loadCoreCogs(bot: commands.Bot, directory: str) -> None: """ Iterate through the commands folder and load all commands. """ - for folder in os.listdir(directory): - try: - # get the class from the module - for files in os.listdir(f"{directory}/{folder}"): - if files.endswith(".py"): - class_name = getClassName(f"{directory}/{folder}/{files}") - module_name = f"{directory}.{folder}.{files[:-3]}" - module = importlib.import_module(module_name) - class_ = getattr(module, class_name) - await bot.add_cog(class_(bot)) - logger.info("Loaded core cog %s%s%s", - TerminalColors.GREEN_BOLD, - module_name, - TerminalColors.RESET_COLOR) - except ImportError as import_error: - logger.warning("Could not load core cog %s%s%s - Exception: %s", - TerminalColors.RED_BOLD, - folder, + + for file in os.listdir(directory): + # if the file is a directory + if os.path.isdir(f"{directory}/{file}"): + try: + # get the class from the module + for sub_file in os.listdir(f"{directory}/{file}"): + logger = logging.getLogger( + f"discord.{directory}.{sub_file}.cog.loader" + ) + if sub_file.endswith(".py"): + class_name = getClassName( + f"{directory}/{file}/{sub_file}", logger + ) + module_name = f"{directory}.{file}.{sub_file[:-3]}" + module = importlib.import_module(module_name) + class_ = getattr(module, class_name) + await bot.add_cog(class_(bot)) + logger.info( + "Loaded core cog %s%s%s", + TerminalColors.GREEN_BOLD, + module_name, TerminalColors.RESET_COLOR, - import_error) - logger.warning(traceback.format_exc()) + ) + except ImportError as import_error: + logger.warning( + "Could not load core cog %s%s%s - Exception: %s", + TerminalColors.RED_BOLD, + sub_file, + TerminalColors.RESET_COLOR, + import_error, + ) + logger.warning(traceback.format_exc()) + + if file.endswith(".py"): + logger = logging.getLogger(f"discord.{directory}.cog.loader") + try: + class_name = getClassName(f"{directory}/{file}", logger) + module_name = f"{directory}.{file[:-3]}" + module = importlib.import_module(module_name) + class_ = getattr(module, class_name) + await bot.add_cog(class_(bot)) + logger.info( + "Loaded core cog %s%s%s", + TerminalColors.GREEN_BOLD, + module_name, + TerminalColors.RESET_COLOR, + ) + except ImportError as import_error: + logger.warning( + "Could not load %s cog %s%s%s - Exception: %s", + directory, + TerminalColors.RED_BOLD, + file, + TerminalColors.RESET_COLOR, + import_error, + ) + logger.warning(traceback.format_exc()) diff --git a/helpers/database/connection.py b/helpers/database/connection.py index 70fd0bd..dbf1cc2 100644 --- a/helpers/database/connection.py +++ b/helpers/database/connection.py @@ -2,42 +2,180 @@ Class for the database configuration for the bot and all cogs. """ -import os import sys import logging +from typing import Any from pymongo import MongoClient +from helpers.env import getEnvVar -def getDbConnection() -> MongoClient: + +class MongoClientOptions: + """ + Class for the configuration of the mongoDB client. """ - Connects and returns a mongoDB Client. - Using either the following environment variables or default values: - - DISCORD_MONGO_DB_HOST_NAME - - DISCORD_MONGO_DB_PORT - - DISCORD_MONGO_DB_DATABASE_NAME + options = {} + + def __init__(self, host: str): + self.options["enabled"] = True + self.options["host"] = host + + # get the authentication method + self.add_option( + "authMechanism", + getEnvVar("DISCORD_MONGO_DB_AUTHENTICATION_METHOD", "SCRAM-SHA-256"), + ) + + # get the port if set + port = getEnvVar("DISCORD_MONGO_DB_PORT", "27017") + if port is not None: + self.add_option("port", port) + + if self.get_option("authMechanism") == "SCRAM-SHA-256": + # get the username and password + self.add_option( + "username", getEnvVar("DISCORD_MONGO_DB_USERNAME", "bot-o-cat") + ) + self.add_option("password", getEnvVar("DISCORD_MONGO_DB_PASSWORD", "None")) + + self.add_option( + "authSource", + getEnvVar("DISCORD_MONGO_DB_AUTHENTICATION_DATABASE", "admin"), + ) + + if self.get_option("authMechanism") == "MONGODB-X509": + + self.add_option("authSource", "$external") + self.add_option("tls", True) + self.add_option( + "tlsAllowInvalidCertificates", + getEnvVar("DISCORD_MONGO_DB_TLS_ALLOW_INVALID_CERTIFICATES", "false"), + ) + crl = getEnvVar("DISCORD_MONGO_DB_CRL_FILE_PATH", "None") + + if crl != "None": + self.add_option("tlsCAFile", crl) + + # get the client certificate + self.add_option( + "tlsCertificateKeyFile", + getEnvVar("DISCORD_MONGO_DB_CERTIFICATE_FILE_PATH"), + ) + + def add_option(self, option: str, value: str | bool | None) -> None: + """ + Add an option to the mongoDB client configuration. + """ + self.options[option] = value + + def return_options(self) -> dict: + """ + Return the options as a dictionary. + """ + return self.options + + def get_option(self, option: str) -> str: + """ + Get an option from the configuration. + """ + return self.options.get(option, None) + + def remove_option(self, option: str) -> None: + """ + Remove an option from the configuration. + """ + self.options.pop(option, None) + + def use_option(self, option: str, default: Any) -> Any: + """ + Get and remove an option from the configuration. + """ + value = self.options.get(option, default) + logging.debug("Using option: %s = %s", option, value) + self.remove_option(option) + return value + + +class DatabaseConnection: + """ + This class contains the database connection for the bot and all cogs. """ - # The certificate file for the database are located: - # - /etc/ssl/bot.pem - # - /etc/ssl/ca.pem + init = False - # get the host name - host = os.getenv("DISCORD_MONGO_DB_HOST_NAME",None) + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(DatabaseConnection, cls).__new__(cls) + return cls.instance - # get the port - port = os.getenv("DISCORD_MONGO_DB_PORT",None) + def __init__(self): + if self.init is False: + self.connection = self._start_database_connection() + self.init = True + logging.info("Database connection established") - if host and port is not None: - # connect to the database - db_conn = MongoClient( - host, - int(port), - tls=True, - tlsCRLFile="/etc/ssl/ca.pem", - tlsCertificateKeyFile="/etc/ssl/bot.pem" - ) + logging.debug("Database connection already established") + + def get_connection(self) -> MongoClient: + """ + Return the database connection. + """ + return self.connection + + def close_connection(self) -> None: + """ + Close the database connection. + """ + self.connection.close() + + def _start_database_connection(self) -> MongoClient: + """ + Connects and returns a mongoDB Client. + + Using either the following environment variables or default values: + - DISCORD_MONGO_DB_HOST_NAME + - DISCORD_MONGO_DB_PORT + - DISCORD_MONGO_DB_DATABASE_NAME + """ + + db_options = MongoClientOptions(getEnvVar("DISCORD_MONGO_DB_HOST_NAME")) + + if db_options.get_option("enabled") is not None: + db_options.remove_option("enabled") + + logging.debug("Connecting to the database with the following options:") + + if db_options.get_option("authMechanism") == "SCRAM-SHA-256": + db_conn = MongoClient( + host=db_options.use_option("host", "localhost"), + port=int(db_options.use_option("port", "27017")), + username=db_options.use_option("username", "bot-o-cat"), + password=db_options.use_option("password", "None"), + authSource=db_options.use_option("authSource", "admin"), + authMechanism=db_options.use_option( + "authMechanism", "SCRAM-SHA-256" + ), + ) - return db_conn + if db_options.get_option("authMechanism") == "MONGODB-X509": + db_conn = MongoClient( + host=db_options.use_option("host", "localhost"), + port=int(db_options.use_option("port", "27017")), + authMechanism=db_options.use_option( + "authMechanism", "MONGODB-X509" + ), + tls=True, + tlsAllowInvalidCertificates=db_options.use_option( + "tlsAllowInvalidCertificates", "false" + ), + tlsCertificateKeyFile=db_options.use_option( + "tlsCertificateKeyFile", None + ), + authSource=db_options.use_option("authSource", "$external"), + ) + if db_conn is not None: + logging.debug("Connected to the database.") + return db_conn - logging.error("Missing environment variables for the database connection.") - sys.exit(1) + logging.critical("Failed to connect to the database.") + sys.exit(1) diff --git a/helpers/database/create.py b/helpers/database/create.py index a1db0d8..97b9ae2 100644 --- a/helpers/database/create.py +++ b/helpers/database/create.py @@ -1,19 +1,17 @@ """ This file contains the create operations for the mongoDB database. """ + import logging from pymongo import MongoClient -from helpers.database.validation import stringValidation logger = logging.getLogger("discord.db.create") -def createDatabase(db_connection: MongoClient ,database_name: str) -> bool: + +def createDatabase(db_connection: MongoClient, database_name: str) -> bool: """ Create a database in the mongoDB. """ - if not stringValidation(database_name): - logger.error("Invalid database name.") - return False # check if the database already exists if database_name in db_connection.list_database_names(): @@ -30,21 +28,18 @@ def createDatabase(db_connection: MongoClient ,database_name: str) -> bool: logger.error( "You shouldn't see this message, check that database %s was created successfully", - database_name) + database_name, + ) return False + def createCollection( - db_connection: MongoClient, - database_name: str, - collection_name: str) -> bool: + db_connection: MongoClient, database_name: str, collection_name: str +) -> bool: """ Create a collection in the mongoDB database. """ - if not stringValidation(database_name) or not stringValidation(collection_name): - logger.error("Invalid database or collection name.") - return False - # create the collection collection = db_connection[database_name].create_collection(collection_name) @@ -55,22 +50,18 @@ def createCollection( logger.error( "You shouldn't see this message, check that collection %s was created successfully", - collection_name) + collection_name, + ) return False + def insertOneDocument( - db_connection: MongoClient, - database_name: str, - collection_name: str, - document: dict) -> bool: + db_connection: MongoClient, database_name: str, collection_name: str, document: dict +) -> bool: """ Create a document in the mongoDB collection. """ - if not stringValidation(database_name) or not stringValidation(collection_name): - logger.error("Invalid database or collection name.") - return False - # check if the collection exists if collection_name not in db_connection[database_name].list_collection_names(): logger.error("Collection %s does not exist", collection_name) @@ -85,22 +76,21 @@ def insertOneDocument( return True logger.error( - "You shouldn't see this message, check that the document was inserted successfully") + "You shouldn't see this message, check that the document was inserted successfully" + ) return False + def insertManyDocuments( - db_connection: MongoClient, - database_name: str, - collection_name: str, - documents: list) -> bool: + db_connection: MongoClient, + database_name: str, + collection_name: str, + documents: list, +) -> bool: """ Insert many documents in the mongoDB collection. """ - if not stringValidation(database_name) or not stringValidation(collection_name): - logger.error("Invalid database or collection name.") - return False - # check if the collection exists if collection_name not in db_connection[database_name].list_collection_names(): logger.error("Collection %s does not exist", collection_name) @@ -115,5 +105,6 @@ def insertManyDocuments( return True logger.error( - "You shouldn't see this message, check that the documents were inserted successfully") + "You shouldn't see this message, check that the documents were inserted successfully" + ) return False diff --git a/helpers/database/delete.py b/helpers/database/delete.py index 74e338a..dceb5df 100644 --- a/helpers/database/delete.py +++ b/helpers/database/delete.py @@ -1,61 +1,52 @@ """ This file contains the delete operations for the database. """ + import logging from typing import Any from pymongo import MongoClient -from validation import validateMultipleStrings, validateInDatabase +from validation import validateInDatabase logger = logging.getLogger("discord.db.delete") + def deleteDatabase(db_connection: MongoClient, database_name: str) -> bool: """ Delete a database in the mongoDB. """ - if not validateMultipleStrings(database_name): - logger.error("Invalid database name.") - return False # check if the database exists - if validateInDatabase( - db_connection, - logger, - database_name=database_name): + if validateInDatabase(db_connection, logger, database_name=database_name): return False # delete the database db_connection.drop_database(database_name) # check if the database was deleted - if validateInDatabase( - db_connection, - logger, - database_name=database_name): + if validateInDatabase(db_connection, logger, database_name=database_name): return True logger.error( "You shouldn't see this message, check that database %s was deleted successfully", - database_name) + database_name, + ) return False + def deleteCollection( - db_connection: MongoClient, - database_name: str, - collection_name: str) -> bool: + db_connection: MongoClient, database_name: str, collection_name: str +) -> bool: """ Delete a collection in the mongoDB database. """ - if validateMultipleStrings(database_name, collection_name): - logger.error("Invalid database or collection name.") - return False - # check if the collection exists if validateInDatabase( db_connection, logger, database_name=database_name, - collection_name=collection_name): + collection_name=collection_name, + ): return False # delete the collection @@ -66,32 +57,30 @@ def deleteCollection( db_connection, logger, database_name=database_name, - collection_name=collection_name): + collection_name=collection_name, + ): return True logger.error( "You shouldn't see this message, check that collection %s was deleted successfully", - collection_name) + collection_name, + ) return False + def deleteOneDocument( - db_connection: MongoClient, - database_name: str, - collection_name: str, - document: Any) -> bool: + db_connection: MongoClient, database_name: str, collection_name: str, document: Any +) -> bool: """ Delete a document in the mongoDB database. """ - - if not validateMultipleStrings(database_name, collection_name): - return False - # check if the collection exists if not validateInDatabase( db_connection, logger, database_name=database_name, - collection_name=collection_name): + collection_name=collection_name, + ): return False # delete the document @@ -104,5 +93,6 @@ def deleteOneDocument( logger.error( "You shouldn't see this message, check that document %s was deleted successfully", - document) + document, + ) return False diff --git a/helpers/database/read.py b/helpers/database/read.py index 7754821..30402d8 100644 --- a/helpers/database/read.py +++ b/helpers/database/read.py @@ -1,27 +1,21 @@ """ This file contains the read operations for the database. """ + import logging from typing import Any from pymongo import MongoClient -from validation import validateMultipleStrings logger = logging.getLogger("discord.db.read") + def fetchOneDocument( - db: MongoClient, - database_name: str, - collection_name: str, - query: Any) -> dict: + db: MongoClient, database_name: str, collection_name: str, query: Any +) -> dict: """ Fetch one document from the collection in the mongoDB database. """ - # validate the database and collection names - if not validateMultipleStrings(database_name, collection_name): - logger.error("Invalid database or collection name.") - return {} - # fetch one document document = db[database_name][collection_name].find_one(query) @@ -34,18 +28,12 @@ def fetchOneDocument( def fetchAllDocuments( - db: MongoClient, - database_name: str, - collection_name: str, - query: Any) -> list | bool: + db: MongoClient, database_name: str, collection_name: str, query: Any +) -> list | bool: """ Fetch all documents from the collection in the mongoDB database. """ - if not validateMultipleStrings(database_name, collection_name): - logger.error("Invalid database or collection name.") - return False - # fetch all documents documents = list(db[database_name][collection_name].find(query)) diff --git a/helpers/database/update.py b/helpers/database/update.py index f1bd21b..714dc12 100644 --- a/helpers/database/update.py +++ b/helpers/database/update.py @@ -1,25 +1,29 @@ """ This file contains the update operations for the database. """ + import logging from typing import Any from pymongo import MongoClient logger = logging.getLogger("discord.db.update") + def updateDocument( - db_connection: MongoClient, - database_name: str, - collection_name: str, - query: Any, - new_values: Any) -> bool: + db_connection: MongoClient, + database_name: str, + collection_name: str, + query: Any, + new_values: Any, +) -> bool: """ Update a document in the collection in the mongoDB database. """ # update the document updated_document = db_connection[database_name][collection_name].update_one( - query, new_values) + query, new_values + ) # check if the document was updated if updated_document.modified_count: diff --git a/helpers/database/validation.py b/helpers/database/validation.py index 1e3914e..ce7fcb1 100644 --- a/helpers/database/validation.py +++ b/helpers/database/validation.py @@ -4,68 +4,16 @@ import logging from pymongo import MongoClient -logger = logging.getLogger("discord.db.validation") - -def stringValidation(operation: str) -> bool: - """ - Validate the operation is alphanumeric. - - Arguments: - operation: The operation to validate. - - Returns: - bool: True if the operation is valid, False otherwise. - """ - - # validate the string is just alphanumeric - if not operation.isalnum(): - logger.error("Value %s is not alphanumeric", operation) - return False - - return True - -def numberValidation(operation: str) -> bool: - """ - Validate the operation is a number. - - Arguments: - operation: The operation to validate. - - Returns: - bool: True if the operation is valid, False otherwise. - """ - - # validate the string is just alphanumeric - if not operation.isnumeric(): - logger.error("Value %s is not a number", operation) - return False - - return True - -def validateMultipleStrings(*args: str) -> bool: - """ - Validate multiple strings. - - Arguments: - db: The database connection. - *args: The strings to validate. - - Returns: - bool: True if all strings are valid, False otherwise. - """ - for arg in args: - if not stringValidation(arg): - return False +logger = logging.getLogger("discord.db.validation") - return True def validateInDatabase( - db: MongoClient, - db_command_loggeer: logging.Logger, - database_name: str | None = None, - collection_name: str | None = None - ) -> bool: + db: MongoClient, + db_command_loggeer: logging.Logger, + database_name: str | None = None, + collection_name: str | None = None, +) -> bool: """ Validate the collection exists in the database. diff --git a/helpers/env.py b/helpers/env.py index 5d18212..18b3bcf 100644 --- a/helpers/env.py +++ b/helpers/env.py @@ -8,7 +8,7 @@ logger = logging.getLogger("discord.env") -def getEnvVar(var_name: str) -> str: +def getEnvVar(var_name: str, _value: str | None = None) -> str: """ Get an environment variable. @@ -16,10 +16,23 @@ def getEnvVar(var_name: str) -> str: var_name: The name of the environment variable to get. Returns: - str | bool: The value of the environment variable if it exists, False otherwise. + str | bool: The value of the environment variable if it exists, + False otherwise. """ try: return os.environ[var_name] except KeyError as key_error: - logger.error("Environment variable %s not found: %s - EXITING", var_name, key_error) + + if _value is not None: + logger.warning( + "Environment variable %s not found: - Using default value: %s", + var_name, + _value + ) + return _value + + logger.error( + "Environment variable %s not found: %s - EXITING", + var_name, + key_error) sys.exit(1) diff --git a/helpers/get_ssh_key.py b/helpers/get_ssh_key.py new file mode 100644 index 0000000..9702137 --- /dev/null +++ b/helpers/get_ssh_key.py @@ -0,0 +1,101 @@ +""" +This file is a helper file to confirm there is an ssh key available. +""" + +import os +import logging +import sys + +from helpers.env import getEnvVar + +logger = logging.getLogger("discord.ssh.key.loader") + + +def getPubKey() -> str: + """ + Get the public ssh key from the file system. + + Returns: + str The contents of the file if it exists + """ + + logger.info("Getting public ssh key") + + root_ssh_dir = getEnvVar("DISCORD_SSH_PATH", "/root/.ssh") + + pub_key = findBestPubKey(root_ssh_dir) + + logger.info("Public ssh key found: %s", os.path.join(root_ssh_dir, pub_key)) + + with open( + os.path.join(root_ssh_dir, pub_key), "r", encoding="utf-8" + ) as open_pub_key: + return open_pub_key.read().strip("\n") + + +def findBestPubKey(key_path: str) -> str: + """ + Get the best key for the bot to use. + + The best key order is + 1. id_ed25519.pub + 2. id_rsa.pub + 3. *.pub + + """ + + private_key = findBestPrivateKey(key_path) + public_key = private_key + ".pub" + + return public_key + + +def findBestPrivateKey(key_path: str) -> str: + """ + Get the best key for the bot to use. + + The best key order is + 1. id_ed25519 + 2. id_rsa + 3. * + + """ + dir_contents = os.listdir(key_path) + matches = ["id_ed25519", "id_rsa"] + no_match = ["config", "known_hosts", "authorized_keys"] + private_key = searchFiles(dir_contents, matches, no_match) + + if private_key is not None: + return private_key + + logger.critical("No ssh keys found in %s", key_path) + sys.exit(1) + + +def searchFiles( + dir_contents: list[str], search_files: list[str], no_match: list[str] +) -> str | None: + """ + Search for the first matching file in the provided directory contents, + and return the file if it exists. + + Arguments: + dir_contents: The directory contents to search + search_files: The files to search for in the order provided + """ + + for file in dir_contents: + if file.endswith(".pub"): + dir_contents.remove(file) + if file in no_match: + dir_contents.remove(file) + + if len(dir_contents) == 1: + return dir_contents[0] + + for file in search_files: + if file in dir_contents: + return file + + logger.debug("No file found in %s", dir_contents) + return None diff --git a/helpers/guild/add_guilds.py b/helpers/guild/add_guilds.py deleted file mode 100644 index c1d1502..0000000 --- a/helpers/guild/add_guilds.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -This file contains the operations to setup a new guild -in the database, and is used by the bot when a new guild is added. -""" - -import logging -from pymongo import MongoClient -from discord import Guild -from helpers.database.create import createDatabase - -logger = logging.getLogger("discord.guilds.add") - -def addGuild( - db_connection: MongoClient, - guild: Guild) -> bool: - """ - Add a new guild to the database. - """ - - database_name = str(guild.id) + "_db" - - # create the database - guild_created = createDatabase(db_connection, database_name) - - if guild_created: - logger.info("Database %s for guild %s added successfully", - database_name, - guild.name) - return True - - logger.error( - "You shouldn't see this message, check that guild %s was added successfully", - guild.name) - - return False diff --git a/scripts/fish/set_env.fish b/scripts/fish/set_env.fish new file mode 100644 index 0000000..2361d7a --- /dev/null +++ b/scripts/fish/set_env.fish @@ -0,0 +1,3 @@ +#!/usr/bin/env fish. + +set -Ux $argv[1] $argv[2..] diff --git a/scripts/set_env.sh b/scripts/set_env.sh new file mode 100755 index 0000000..70255ed --- /dev/null +++ b/scripts/set_env.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +function check_fish(){ + if [ -n $FISH_VERSION]; then + echo "Fish shell detected. Using fish syntax." + USE_FISH="True" + else + echo "Bash shell detected. Using bash syntax." + USE_BASH="True" + fi +} + +# if shell is fish, use fish syntax +function check_env_variable() { + + if [ -n "$USE_FISH" ]; then + fish scripts/fish/set_env.fish $1 $2 + fi + if [ -n "$USE_BASH" ]; then + if [ -z "${!1}" ]; then + export $1=$2 + fi + + fi +} + +check_fish +# Get the variables from the config file +source variables.config + +# iterate through all variables with the prefix DISCORD_ +for var in $(compgen -A variable | grep DISCORD_); do + check_env_variable $var ${!var} +done