Skip to content

Commit

Permalink
[#143] Double check non-English detection with nltk English stop_words and words
Browse files Browse the repository at this point in the history
  • Loading branch information
wayangalihpratama committed Dec 9, 2024
1 parent 4ce1a79 commit cdb6762
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 2 deletions.
48 changes: 46 additions & 2 deletions assistant/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,33 @@
import asyncio
import logging
import chromadb
import nltk

from time import sleep
from openai import OpenAI
from langdetect import detect

from langdetect import detect, DetectorFactory
from langdetect.lang_detect_exception import LangDetectException
from datetime import datetime, timezone
from Akvo_rabbitmq_client import rabbitmq_client
from typing import Optional
from db import connect_to_sqlite, get_stable_prompt
from nltk.corpus import words, stopwords


logger = logging.getLogger(__name__)

# Pin langdetect's seed so detect() gives the same answer for the same
# input on every run.
DetectorFactory.seed = 0

# Fetch the NLTK data this module relies on: "punkt_tab" for
# nltk.word_tokenize, and the "words" / "stopwords" corpora loaded below.
for _nltk_resource in ("punkt_tab", "words", "stopwords"):
    nltk.download(_nltk_resource)

# Lookup sets used to double check a non-English guess from langdetect.
# The word list is lower-cased so it can be matched against lower-cased
# tokens.
english_stopwords = {*stopwords.words("english")}
english_words = {entry.lower() for entry in words.words()}

openai = OpenAI()

Expand Down Expand Up @@ -251,7 +267,29 @@ async def publish_reliably(queue_message: str) -> None:


def get_language(user_prompt) -> str:
detected_language = detect(user_prompt)
try:
# Use langdetect to get an initial guess
detected_language = detect(user_prompt)
except LangDetectException:
return "en"

# Double check to make sure not en
if detected_language != "en":
# Tokenize the text to get individual words
tokens = nltk.word_tokenize(user_prompt.lower())

# Count the number of English stopwords and valid English words
stopwords_count = sum(word in english_stopwords for word in tokens)
valid_word_count = sum(word in english_words for word in tokens)

# Heuristic: check English stopwords or words, reconsider
if (
stopwords_count > len(tokens) / 2
or valid_word_count > len(tokens) / 2
):
detected_language = "en"

# check if detected language supported in env settings
if detected_language not in assistant_data:
logger.warning(
f"[ASSISTANT] -> Unsupported lang detected: {detected_language}"
Expand Down Expand Up @@ -283,6 +321,12 @@ async def on_message(body: str) -> None:
user_prompt = from_client["body"]

detected_language = get_language(user_prompt)
logger.info(f"[ASSISTANT] -> user prompt: {user_prompt}")
logger.info(f"[ASSISTANT] -> detected lang: {detected_language}")
logger.info(
f"[ASSISTANT] -> assistant_data: {assistant_data[detected_language]}"
)

knowledge_base = assistant_data[detected_language]["knowledge_base"]
system_prompt = assistant_data[detected_language]["system_prompt"]
rag_prompt = assistant_data[detected_language]["rag_prompt"]
Expand Down
1 change: 1 addition & 0 deletions assistant/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ pika==1.3.2
black==24.4.2
flake8==7.1.0
langdetect==1.0.9
nltk==3.9.1
pytest==8.3.2
pandas==1.5.3
11 changes: 11 additions & 0 deletions assistant/tests/test_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ def test_language_support():
assert rag_prompt == sw_prompt["rag_prompt"]
assert ragless_prompt == sw_prompt["ragless_prompt"]

detected_language = get_language("I have question about coffee")
assert detected_language == "en"
knowledge_base = assistant_data[detected_language]["knowledge_base"]
system_prompt = assistant_data[detected_language]["system_prompt"]
rag_prompt = assistant_data[detected_language]["rag_prompt"]
ragless_prompt = assistant_data[detected_language]["ragless_prompt"]
assert knowledge_base.name == "EPPO-datasheets-en"
assert system_prompt == en_prompt["system_prompt"]
assert rag_prompt == en_prompt["rag_prompt"]
assert ragless_prompt == en_prompt["ragless_prompt"]


def test_query_llm():
with patch("assistant.OpenAI") as mock_openai:
Expand Down

0 comments on commit cdb6762

Please sign in to comment.