From 36aea0e1b5a701013326983353465919414d9a53 Mon Sep 17 00:00:00 2001 From: Phoenix Pereira <47909638+phoenixpereira@users.noreply.github.com> Date: Wed, 25 Sep 2024 20:42:42 +0930 Subject: [PATCH] feat(chat): Add Rate Limit (#61) --- .example.env | 2 + .github/workflows/production.yml | 4 ++ src/commands/gemini.py | 78 +++++++++++++++++++++++++++++++- src/main.py | 1 + 4 files changed, 83 insertions(+), 2 deletions(-) diff --git a/.example.env b/.example.env index 6f4b9e6..39dcef4 100644 --- a/.example.env +++ b/.example.env @@ -4,3 +4,5 @@ SKULLBOARD_CHANNEL_ID=SKULLBOARD_CHANNEL_ID REQUIRED_REACTIONS=5 TENOR_API_KEY="TENOR_API_KEY" GEMINI_API_KEY="GEMINI_API_KEY" +REQUESTS_PER_MINUTE=3 +LIMIT_WINDOW=60 diff --git a/.github/workflows/production.yml b/.github/workflows/production.yml index 017c74e..42cbc36 100644 --- a/.github/workflows/production.yml +++ b/.github/workflows/production.yml @@ -82,6 +82,8 @@ jobs: REQUIRED_REACTIONS: ${{ secrets.REQUIRED_REACTIONS }} TENOR_API_KEY: ${{ secrets.TENOR_API_KEY }} GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + REQUESTS_PER_MINUTE: ${{ secrets.REQUESTS_PER_MINUTE }} + LIMIT_WINDOW: ${{ secrets.LIMIT_WINDOW }} run: | echo "$KEY" > private_key && chmod 600 private_key ssh -v -o StrictHostKeyChecking=no -i private_key ${USER}@${HOSTNAME} ' @@ -94,6 +96,8 @@ jobs: echo REQUIRED_REACTIONS=${{ secrets.REQUIRED_REACTIONS }} >> .env echo TENOR_API_KEY=${{ secrets.TENOR_API_KEY }} >> .env echo GEMINI_API_KEY=${{ secrets.GEMINI_API_KEY }} >> .env + echo REQUESTS_PER_MINUTE=${{ secrets.REQUESTS_PER_MINUTE }} >> .env + echo LIMIT_WINDOW=${{ secrets.LIMIT_WINDOW }} >> .env docker load -i duckbot.tar.gz docker compose up -d ' diff --git a/src/commands/gemini.py b/src/commands/gemini.py index d3fe3d0..3cbf4e7 100644 --- a/src/commands/gemini.py +++ b/src/commands/gemini.py @@ -5,13 +5,21 @@ import os.path import re import tempfile +import time +from collections import defaultdict +import random +import requests from discord import Embed from google.generativeai.types import HarmCategory, HarmBlockThreshold, File import google.generativeai as genai +from dotenv import load_dotenv from constants.colours import LIGHT_YELLOW +# Load environment variables from .env file +load_dotenv() + SAFETY_SETTINGS = { HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, @@ -51,9 +59,15 @@ class Errors(IntEnum): class GeminiBot: + REQUESTS_PER_MINUTE = int(os.environ["REQUESTS_PER_MINUTE"]) + LIMIT_WINDOW = int(os.environ["LIMIT_WINDOW"]) + def __init__(self, model_name, data_csv_path, bot, api_key): genai.configure(api_key=api_key) + # Dictionary to track users and their request timestamps + self.user_requests = defaultdict(list) + system_instruction = ( "You are DuckBot, the official discord bot for the Computer Science Club of the University of Adelaide. " "Your main purpose is to answer CS questions and FAQs by users. " @@ -102,6 +116,49 @@ def __init__(self, model_name, data_csv_path, bot, api_key): # Gemini API provides a chat option to maintain a conversation self.chat = self.model.start_chat() + def check_rate_limit(self, author_id): + """Check if the user has exceeded their rate limit.""" + current_time = time.time() + request_times = self.user_requests[author_id] + + # Filter out requests that happened more than a minute ago + request_times = [ + timestamp + for timestamp in request_times + if current_time - timestamp < self.LIMIT_WINDOW + ] + + # Update the user's request history with only the recent ones + self.user_requests[author_id] = request_times + + # If the user has made more than the allowed requests in the past minute, deny the request + if len(request_times) >= self.REQUESTS_PER_MINUTE: + return False + + # Otherwise, log the current request + self.user_requests[author_id].append(current_time) + return True + + async def get_random_leetcode_problem(self): + response = requests.get("https://leetcode.com/api/problems/all/") + if response.status_code == 200: + data = response.json() + # This contains the list of problems + problems = data["stat_status_pairs"] + + if problems: + # Select a random problem + random_problem = random.choice(problems) + question_slug = random_problem["stat"]["question__title_slug"] + question_url = f"https://leetcode.com/problems/{question_slug}/" + return question_url + else: + print("No problems found.") + return None + else: + print(f"Failed to retrieve problems: {response.status_code}") + return None + async def prompt_gemini( self, author, input_msg=None, attachment=None, show_input=True ) -> (Embed, Errors): @@ -170,9 +227,26 @@ async def prompt_gemini( return response_embeds, None - async def query(self, author, message=None, attachment=None) -> list[Embed]: - + async def query( + self, author_id, author, message=None, attachment=None + ) -> list[Embed]: response_embeds = [] + # Check the rate limit before processing the query + if not self.check_rate_limit(author_id): + # User exceeded the rate limit + problem_url = await self.get_random_leetcode_problem() + return [ + Embed( + title="Take a break", + description=( + f"Maybe instead of wasting your time spamming, you can do a Leetcode instead 😉\n" + f"{problem_url}" + ), # Added a comma here + color=LIGHT_YELLOW, + ) + ] + + # Process the message and attachment response_image_url = None errors = [] diff --git a/src/main.py b/src/main.py index da44457..049c5ed 100644 --- a/src/main.py +++ b/src/main.py @@ -204,6 +204,7 @@ async def on_message(message: Message): attachment = message.attachments[0] if message.attachments else None bot_response = await client.gemini_model.query( + author_id=message.author.id, author=message.author.display_name, message=message.clean_content.replace("d.chat", ""), attachment=attachment,