From cdb67623c8194d70b37c94d81af3c2f4e68cdc64 Mon Sep 17 00:00:00 2001 From: wayangalihpratama Date: Mon, 9 Dec 2024 12:44:56 +0800 Subject: [PATCH] [#143] Double check non english detection with nltk english stop_words n words --- assistant/assistant.py | 48 +++++++++++++++++++++++++++++-- assistant/requirements.txt | 1 + assistant/tests/test_assistant.py | 11 +++++++ 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/assistant/assistant.py b/assistant/assistant.py index 057ef976..c45c2ed9 100644 --- a/assistant/assistant.py +++ b/assistant/assistant.py @@ -3,17 +3,33 @@ import asyncio import logging import chromadb +import nltk from time import sleep from openai import OpenAI -from langdetect import detect + +from langdetect import detect, DetectorFactory +from langdetect.lang_detect_exception import LangDetectException from datetime import datetime, timezone from Akvo_rabbitmq_client import rabbitmq_client from typing import Optional from db import connect_to_sqlite, get_stable_prompt +from nltk.corpus import words, stopwords + logger = logging.getLogger(__name__) +# Ensure reproducibility +DetectorFactory.seed = 0 + +# Ensure the NLTK resources are downloaded +nltk.download("punkt_tab") +nltk.download("words") +nltk.download("stopwords") + +# Get English stopwords and the set of English words +english_stopwords = set(stopwords.words("english")) +english_words = set(w.lower() for w in words.words()) openai = OpenAI() @@ -251,7 +267,29 @@ async def publish_reliably(queue_message: str) -> None: def get_language(user_prompt) -> str: - detected_language = detect(user_prompt) + try: + # Use langdetect to get an initial guess + detected_language = detect(user_prompt) + except LangDetectException: + return "en" + + # Double check to make sure not en + if detected_language != "en": + # Tokenize the text to get individual words + tokens = nltk.word_tokenize(user_prompt.lower()) + + # Count the number of English stopwords and valid English words + 
stopwords_count = sum(word in english_stopwords for word in tokens) + valid_word_count = sum(word in english_words for word in tokens) + + # Heuristic: over half the tokens are English -> treat as en + if ( + stopwords_count > len(tokens) / 2 + or valid_word_count > len(tokens) / 2 + ): + detected_language = "en" + + # Check whether the detected language is supported in env settings if detected_language not in assistant_data: logger.warning( f"[ASSISTANT] -> Unsupported lang detected: {detected_language}" @@ -283,6 +321,12 @@ async def on_message(body: str) -> None: user_prompt = from_client["body"] detected_language = get_language(user_prompt) + logger.info(f"[ASSISTANT] -> user prompt: {user_prompt}") + logger.info(f"[ASSISTANT] -> detected lang: {detected_language}") + logger.info( + f"[ASSISTANT] -> assistant_data: {assistant_data[detected_language]}" + ) + knowledge_base = assistant_data[detected_language]["knowledge_base"] system_prompt = assistant_data[detected_language]["system_prompt"] rag_prompt = assistant_data[detected_language]["rag_prompt"] diff --git a/assistant/requirements.txt b/assistant/requirements.txt index b423dc51..448ef17a 100644 --- a/assistant/requirements.txt +++ b/assistant/requirements.txt @@ -4,5 +4,6 @@ pika==1.3.2 black==24.4.2 flake8==7.1.0 langdetect==1.0.9 +nltk==3.9.1 pytest==8.3.2 pandas==1.5.3 \ No newline at end of file diff --git a/assistant/tests/test_assistant.py b/assistant/tests/test_assistant.py index 5200161f..6f51f3ac 100644 --- a/assistant/tests/test_assistant.py +++ b/assistant/tests/test_assistant.py @@ -98,6 +98,17 @@ def test_language_support(): assert rag_prompt == sw_prompt["rag_prompt"] assert ragless_prompt == sw_prompt["ragless_prompt"] + detected_language = get_language("I have question about coffee") + assert detected_language == "en" + knowledge_base = assistant_data[detected_language]["knowledge_base"] + system_prompt = assistant_data[detected_language]["system_prompt"] + rag_prompt =
assistant_data[detected_language]["rag_prompt"] + ragless_prompt = assistant_data[detected_language]["ragless_prompt"] + assert knowledge_base.name == "EPPO-datasheets-en" + assert system_prompt == en_prompt["system_prompt"] + assert rag_prompt == en_prompt["rag_prompt"] + assert ragless_prompt == en_prompt["ragless_prompt"] + def test_query_llm(): with patch("assistant.OpenAI") as mock_openai: