Skip to content

Commit

Permalink
feat(chat): Add Rate Limit (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
phoenixpereira authored Sep 25, 2024
1 parent ac4601a commit 36aea0e
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .example.env
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ SKULLBOARD_CHANNEL_ID=SKULLBOARD_CHANNEL_ID
REQUIRED_REACTIONS=5
TENOR_API_KEY="TENOR_API_KEY"
GEMINI_API_KEY="GEMINI_API_KEY"
REQUESTS_PER_MINUTE=3
LIMIT_WINDOW=60
4 changes: 4 additions & 0 deletions .github/workflows/production.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ jobs:
REQUIRED_REACTIONS: ${{ secrets.REQUIRED_REACTIONS }}
TENOR_API_KEY: ${{ secrets.TENOR_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
REQUESTS_PER_MINUTE: ${{ secrets.REQUESTS_PER_MINUTE }}
LIMIT_WINDOW: ${{ secrets.LIMIT_WINDOW }}
run: |
echo "$KEY" > private_key && chmod 600 private_key
ssh -v -o StrictHostKeyChecking=no -i private_key ${USER}@${HOSTNAME} '
Expand All @@ -94,6 +96,8 @@ jobs:
echo REQUIRED_REACTIONS=${{ secrets.REQUIRED_REACTIONS }} >> .env
echo TENOR_API_KEY=${{ secrets.TENOR_API_KEY }} >> .env
echo GEMINI_API_KEY=${{ secrets.GEMINI_API_KEY }} >> .env
echo REQUESTS_PER_MINUTE=${{ secrets.REQUESTS_PER_MINUTE }} >> .env
echo LIMIT_WINDOW=${{ secrets.LIMIT_WINDOW }} >> .env
docker load -i duckbot.tar.gz
docker compose up -d
'
78 changes: 76 additions & 2 deletions src/commands/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@
import os.path
import re
import tempfile
import time
from collections import defaultdict
import random
import requests

from discord import Embed
from google.generativeai.types import HarmCategory, HarmBlockThreshold, File
import google.generativeai as genai
from dotenv import load_dotenv

from constants.colours import LIGHT_YELLOW

# Load environment variables from .env file
load_dotenv()

SAFETY_SETTINGS = {
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
Expand Down Expand Up @@ -51,9 +59,15 @@ class Errors(IntEnum):


class GeminiBot:
REQUESTS_PER_MINUTE = int(os.environ["REQUESTS_PER_MINUTE"])
LIMIT_WINDOW = int(os.environ["LIMIT_WINDOW"])

def __init__(self, model_name, data_csv_path, bot, api_key):
genai.configure(api_key=api_key)

# Dictionary to track users and their request timestamps
self.user_requests = defaultdict(list)

system_instruction = (
"You are DuckBot, the official discord bot for the Computer Science Club of the University of Adelaide. "
"Your main purpose is to answer CS questions and FAQs by users. "
Expand Down Expand Up @@ -102,6 +116,49 @@ def __init__(self, model_name, data_csv_path, bot, api_key):
# Gemini API provides a chat option to maintain a conversation
self.chat = self.model.start_chat()

def check_rate_limit(self, author_id):
"""Check if the user has exceeded their rate limit."""
current_time = time.time()
request_times = self.user_requests[author_id]

# Filter out requests that happened more than a minute ago
request_times = [
timestamp
for timestamp in request_times
if current_time - timestamp < self.LIMIT_WINDOW
]

# Update the user's request history with only the recent ones
self.user_requests[author_id] = request_times

# If the user has made more than the allowed requests in the past minute, deny the request
if len(request_times) >= self.REQUESTS_PER_MINUTE:
return False

# Otherwise, log the current request
self.user_requests[author_id].append(current_time)
return True

async def get_random_leetcode_problem(self):
response = requests.get("https://leetcode.com/api/problems/all/")
if response.status_code == 200:
data = response.json()
# This contains the list of problems
problems = data["stat_status_pairs"]

if problems:
# Select a random problem
random_problem = random.choice(problems)
question_slug = random_problem["stat"]["question__title_slug"]
question_url = f"https://leetcode.com/problems/{question_slug}/"
return question_url
else:
print("No problems found.")
return None
else:
print(f"Failed to retrieve problems: {response.status_code}")
return None

async def prompt_gemini(
self, author, input_msg=None, attachment=None, show_input=True
) -> (Embed, Errors):
Expand Down Expand Up @@ -170,9 +227,26 @@ async def prompt_gemini(

return response_embeds, None

async def query(self, author, message=None, attachment=None) -> list[Embed]:

async def query(
self, author_id, author, message=None, attachment=None
) -> list[Embed]:
response_embeds = []
# Check the rate limit before processing the query
if not self.check_rate_limit(author_id):
# User exceeded the rate limit
problem_url = await self.get_random_leetcode_problem()
return [
Embed(
title="Take a break",
description=(
f"Maybe instead of wasting your time spamming, you can do a Leetcode instead 😉\n"
f"{problem_url}"
), # Added a comma here
color=LIGHT_YELLOW,
)
]

# Process the message and attachment
response_image_url = None
errors = []

Expand Down
1 change: 1 addition & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ async def on_message(message: Message):
attachment = message.attachments[0] if message.attachments else None

bot_response = await client.gemini_model.query(
author_id=message.author.id,
author=message.author.display_name,
message=message.clean_content.replace("d.chat", ""),
attachment=attachment,
Expand Down

0 comments on commit 36aea0e

Please sign in to comment.