From 9b7de7cc78e5ae3e21c5702f968b6df1241d02e4 Mon Sep 17 00:00:00 2001
From: Jeremy Pinto
Date: Tue, 27 Jun 2023 13:31:21 -0400
Subject: [PATCH] Add question pre-processing support (#108)

* add support for verifying validity of questions before generating a response

* skip reranking when documents are empty

* pass message along with question when checking relevance

* add a default invalid question response
---
 buster/busterbot.py       | 22 ++++++++++++++++++++--
 buster/completers/base.py | 31 +++++++++++++++++++++++--------
 buster/examples/cfg.py    |  1 +
 buster/validators/base.py | 22 +++++++++++++++++++++-
 4 files changed, 65 insertions(+), 11 deletions(-)

diff --git a/buster/busterbot.py b/buster/busterbot.py
index 51d8d2a..20e3ed4 100644
--- a/buster/busterbot.py
+++ b/buster/busterbot.py
@@ -2,6 +2,8 @@
 from dataclasses import dataclass, field
 from typing import Any
 
+import pandas as pd
+
 from buster.completers import Completer, Completion
 from buster.retriever import Retriever
 from buster.validators import Validator
@@ -81,13 +83,29 @@ def process_input(self, user_input: str, source: str = None) -> Completion:
         if not user_input.endswith("\n"):
             user_input += "\n"
 
-        matched_documents = self.retriever.retrieve(user_input, source=source)
+        # The returned message is either a generic invalid question message or an error handling message
+        question_relevant, irrelevant_question_message = self.validator.check_question_relevance(user_input)
+
+        if question_relevant:
+            # question is relevant, get completor to generate completion
+            matched_documents = self.retriever.retrieve(user_input, source=source)
+            completion = self.completer.get_completion(user_input=user_input, matched_documents=matched_documents)
 
-        completion = self.completer.get_completion(user_input=user_input, matched_documents=matched_documents)
+        else:
+            # question was determined irrelevant, so we instead return a generic response set by the user.
+            completion = Completion(
+                error=False,
+                user_input=user_input,
+                matched_documents=pd.DataFrame(),
+                completor=irrelevant_question_message,
+                answer_relevant=False,
+                question_relevant=False,
+            )
 
         logger.info(f"Completion:\n{completion}")
 
         return completion
 
     def postprocess_completion(self, completion) -> Completion:
+        """This will check if the answer is relevant, and rerank the sources by relevance too."""
         return self.validator.validate(completion=completion)
diff --git a/buster/completers/base.py b/buster/completers/base.py
index 70df0bd..3ca4cc8 100644
--- a/buster/completers/base.py
+++ b/buster/completers/base.py
@@ -39,10 +39,13 @@ class Completion:
     matched_documents: pd.DataFrame
     completor: Iterator | str
     answer_relevant: bool = None
+    question_relevant: bool = None
 
     # private property, should not be set at init
     _completor: Iterator | str = field(init=False, repr=False)  # e.g. a streamed response from openai.ChatCompletion
-    _text: str = None
+    _text: str = (
+        None  # once the generator of the completor is exhausted, the text will be available in the self.text property
+    )
 
     @property
     def text(self):
@@ -101,6 +104,7 @@ def encode_df(df: pd.DataFrame) -> dict:
             "text": self.text,
             "matched_documents": self.matched_documents,
             "answer_relevant": self.answer_relevant,
+            "question_relevant": self.question_relevant,
             "error": self.error,
         }
         return jsonable_encoder(to_encode, custom_encoder=custom_encoder)
@@ -128,11 +132,13 @@ def __init__(
         prompt_formatter: PromptFormatter,
         completion_kwargs: dict,
         no_documents_message: str = "No documents were found that match your question.",
+        completion_class: Completion = Completion,
     ):
         self.completion_kwargs = completion_kwargs
         self.documents_formatter = documents_formatter
         self.prompt_formatter = prompt_formatter
         self.no_documents_message = no_documents_message
+        self.completion_class = completion_class
 
     @abstractmethod
     def complete(self, prompt: str, user_input: str) -> Completion:
@@ -151,34 +157,43 @@ def prepare_prompt(self, matched_documents) -> str:
         return prompt
 
     def get_completion(self, user_input: str, matched_documents: pd.DataFrame) -> Completion:
-        # Call the API to generate a response
+        """Generate a completion to a user's question based on matched documents."""
+
+        # The completor assumes a question was previously determined valid, otherwise it would not be called.
+        question_relevant = True
 
         logger.info(f"{user_input=}")
 
         if len(matched_documents) == 0:
-            logger.warning("no documents found...")
             # no document was found, pass the appropriate message instead...
+            logger.warning("no documents found...")
 
             # empty dataframe
             matched_documents = pd.DataFrame(columns=matched_documents.columns)
 
-            completion = Completion(
+            # because we are proceeding with a completion, we assume the question is relevant.
+            completion = self.completion_class(
                 user_input=user_input,
                 completor=self.no_documents_message,
                 error=False,
                 matched_documents=matched_documents,
+                question_relevant=question_relevant,
            )
             return completion
 
-        # prepare the prompt
+        # prepare the prompt with matched documents
         prompt = self.prepare_prompt(matched_documents)
         logger.info(f"{prompt=}")
 
         logger.info(f"querying model with parameters: {self.completion_kwargs}...")
         completor = self.complete(prompt=prompt, user_input=user_input, **self.completion_kwargs)
 
-        completion = Completion(
-            completor=completor, error=self.error, matched_documents=matched_documents, user_input=user_input
+        completion = self.completion_class(
+            completor=completor,
+            error=self.error,
+            matched_documents=matched_documents,
+            user_input=user_input,
+            question_relevant=question_relevant,
         )
         return completion
@@ -224,13 +239,13 @@ def complete(self, prompt: str, user_input, **completion_kwargs) -> str | Iterat
             {"role": "system", "content": prompt},
             {"role": "user", "content": user_input},
         ]
+        self.error = False
         try:
             response = openai.ChatCompletion.create(
                 messages=messages,
                 **completion_kwargs,
             )
 
-            self.error = False
             if completion_kwargs.get("stream") is True:
                 # We are entering streaming mode, so here were just wrapping the streamed
                 # openai response to be easier to handle later
diff --git a/buster/examples/cfg.py b/buster/examples/cfg.py
index cd6a1ce..612b6dd 100644
--- a/buster/examples/cfg.py
+++ b/buster/examples/cfg.py
@@ -6,6 +6,7 @@
         "unknown_threshold": 0.85,
         "embedding_model": "text-embedding-ada-002",
         "use_reranking": True,
+        "invalid_question_response": "This question does not seem relevant to my current knowledge.",
     },
     retriever_cfg={
         "db_path": "documents.db",
diff --git a/buster/validators/base.py b/buster/validators/base.py
index 826c37c..255833d 100644
--- a/buster/validators/base.py
+++ b/buster/validators/base.py
@@ -11,11 +11,19 @@
 
 
 class Validator:
-    def __init__(self, embedding_model: str, unknown_threshold: float, unknown_prompt: str, use_reranking: bool):
+    def __init__(
+        self,
+        embedding_model: str,
+        unknown_threshold: float,
+        unknown_prompt: str,
+        use_reranking: bool,
+        invalid_question_response: str = "This question is not relevant to my knowledge.",
+    ):
         self.embedding_model = embedding_model
         self.unknown_threshold = unknown_threshold
         self.unknown_prompt = unknown_prompt
         self.use_reranking = use_reranking
+        self.invalid_question_response = invalid_question_response
 
     @staticmethod
     @lru_cache
@@ -23,6 +31,16 @@ def get_embedding(query: str, engine: str):
         logger.info("generating embedding")
         return get_embedding(query, engine=engine)
 
+    def check_question_relevance(self, question: str) -> tuple[bool, str]:
+        """Determines whether a question is relevant or not for our given framework."""
+        # Override this method to suit your needs.
+        # By default, no checks happen.
+        # You could for example use a GPT call to check your question validity, at extra cost/latency.
+        # The message will be what's printed should question_relevant be False.
+        question_relevant = True
+        message: str = self.invalid_question_response
+        return question_relevant, message
+
     def check_answer_relevance(self, answer: str, unknown_prompt: str = None) -> bool:
         """Check to see if a generated answer is relevant to the chatbot's knowledge or not.
 
@@ -60,6 +78,8 @@ def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFr
         This score could be used to determine whether a document was actually relevant to generation.
         An extra column is added in-place for the similarity score.
         """
+        if len(matched_documents) == 0:
+            return matched_documents
         logger.info("Reranking documents based on answer similarity...")
 
         answer_embedding = self.get_embedding(
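
Note: as the comments in the new check_question_relevance method state, the default implementation performs no check and always returns question_relevant=True along with invalid_question_response. Below is a minimal sketch, not part of this patch, of how a downstream user might override it with a GPT-based check; the subclass name, system prompt, and model choice are illustrative assumptions, and the call uses the same pre-1.0 openai.ChatCompletion.create API that the completer in this patch already relies on.

import openai

from buster.validators import Validator


class GPTQuestionValidator(Validator):
    """Hypothetical subclass that rejects questions unrelated to the indexed documentation."""

    def check_question_relevance(self, question: str) -> tuple[bool, str]:
        # Single yes/no call to the chat model (extra cost/latency, as noted in the base class).
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "system",
                    "content": "Answer 'yes' if the question is about the documentation, otherwise answer 'no'.",
                },
                {"role": "user", "content": question},
            ],
            temperature=0,
            max_tokens=1,
        )
        verdict = response["choices"][0]["message"]["content"].strip().lower()
        question_relevant = verdict.startswith("yes")
        # invalid_question_response is supplied via validator_cfg (see cfg.py above).
        return question_relevant, self.invalid_question_response

With such a validator wired into Buster, process_input skips retrieval entirely when the check returns False and instead returns a Completion whose completor is the configured invalid_question_response, with question_relevant=False.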