From 332a7223616251c6acdb9b49863a3477b0d29eb4 Mon Sep 17 00:00:00 2001 From: Jeremy Pinto Date: Thu, 29 Jun 2023 14:33:50 -0400 Subject: [PATCH] Refactor relevance check (#112) Another refactor to make answer_relevant an actual property. When called, it will trigger a validation automatically. This is just better UX. The user needs to otherwise be aware of the fact that buster.postprocess_completion needs to be called. This PR will now: * explicitly runs all postprocessing computations at the end of a generation * renames variables * accepts answer_text directly as input for a Completion In short, after the text is generated in answer_generator, a postprocess function is automatically executed, computing if necessary answer relevance as well as re-ranking documents --- buster/busterbot.py | 14 +-- buster/completers/base.py | 159 ++++++++++++++++++++++++---------- buster/examples/gradio_app.py | 4 +- buster/validators/base.py | 16 ---- tests/test_chatbot.py | 23 ++--- tests/test_read_write.py | 7 +- 6 files changed, 137 insertions(+), 86 deletions(-) diff --git a/buster/busterbot.py b/buster/busterbot.py index 20e3ed4..81ad5ef 100644 --- a/buster/busterbot.py +++ b/buster/busterbot.py @@ -89,7 +89,12 @@ def process_input(self, user_input: str, source: str = None) -> Completion: if question_relevant: # question is relevant, get completor to generate completion matched_documents = self.retriever.retrieve(user_input, source=source) - completion = self.completer.get_completion(user_input=user_input, matched_documents=matched_documents) + completion = self.completer.get_completion( + user_input=user_input, + matched_documents=matched_documents, + validator=self.validator, + question_relevant=question_relevant, + ) else: # question was determined irrelevant, so we instead return a generic response set by the user. 
@@ -97,15 +102,12 @@ def process_input(self, user_input: str, source: str = None) -> Completion: error=False, user_input=user_input, matched_documents=pd.DataFrame(), - completor=irrelevant_question_message, + answer_generator=irrelevant_question_message, answer_relevant=False, question_relevant=False, + validator=self.validator, ) logger.info(f"Completion:\n{completion}") return completion - - def postprocess_completion(self, completion) -> Completion: - """This will check if the answer is relevant, and rerank the sources by relevance too.""" - return self.validator.validate(completion=completion) diff --git a/buster/completers/base.py b/buster/completers/base.py index 3ca4cc8..ddcdc34 100644 --- a/buster/completers/base.py +++ b/buster/completers/base.py @@ -32,47 +32,113 @@ openai.api_key = os.environ.get("OPENAI_API_KEY") -@dataclass class Completion: - error: bool - user_input: str - matched_documents: pd.DataFrame - completor: Iterator | str - answer_relevant: bool = None - question_relevant: bool = None - - # private property, should not be set at init - _completor: Iterator | str = field(init=False, repr=False) # e.g. 
a streamed response from openai.ChatCompletion - _text: str = ( - None # once the generator of the completor is exhausted, the text will be available in the self.text property - ) + def __init__( + self, + error: bool, + user_input: str, + matched_documents: pd.DataFrame, + answer_generator: Optional[Iterator] = None, + answer_text: Optional[str] = None, + answer_relevant: Optional[bool] = None, + question_relevant: Optional[bool] = None, + validator=None, + ): + self.error = error + self.user_input = user_input + self.matched_documents = matched_documents + self.validator = validator + self._answer_relevant = answer_relevant + self._question_relevant = question_relevant + + self._validate_arguments(answer_generator, answer_text) + + def _validate_arguments(self, answer_generator: Optional[Iterator], answer_text: Optional[str]): + """Sets answer_generator and answer_text properties depending on the provided inputs. + + Checks that one of either answer_generator or answer_text is not None. + If answer_text is set, a generator can simply be inferred from answer_text. + If answer_generator is set, answer_text will be set only once the generator gets called. Set to None for now. + """ + if (answer_generator is None and answer_text is None) or ( + answer_generator is not None and answer_text is not None + ): + raise ValueError("Only one of 'answer_generator' and 'answer_text' must be set.") + + # If text is provided, the generator can be inferred + if answer_text is not None: + assert isinstance(answer_text, str) + answer_generator = (msg for msg in answer_text) + + self._answer_text = answer_text + self._answer_generator = answer_generator @property - def text(self): - if self._text is None: - # generates the text if it wasn't already generated - self._text = "".join([i for i in self.completor]) - return self._text + def answer_relevant(self) -> bool: + """Property determining the relevance of an answer (bool). 
- @text.setter - def text(self, value: str) -> None: - self._text = value + If an error occurred, the relevance is False. + If no documents were retrieved, the relevance is also False. + Otherwise, the relevance is computed as defined by the validator (e.g. comparing to embeddings) + """ + if self.error: + self._answer_relevant = False + elif len(self.matched_documents) == 0: + self._answer_relevant = False + elif self._answer_relevant is not None: + return self._answer_relevant + else: + # Check the answer relevance by looking at the embeddings + self._answer_relevant = self.validator.check_answer_relevance(self.answer_text) + return self._answer_relevant @property - def completor(self): - if isinstance(self._completor, str): - # convert str to iterator - self._completor = (msg for msg in self._completor) + def question_relevant(self): + """Property determining the relevance of the question asked (bool).""" + return self._question_relevant + @property + def answer_text(self): + if self._answer_text is None: + # generates the text if it wasn't already generated + self._answer_text = "".join([i for i in self.answer_generator]) + return self._answer_text + + @answer_text.setter + def answer_text(self, value: str) -> None: + self._answer_text = value + + @property + def answer_generator(self): # keeps track of the yielded text - self._text = "" - for token in self._completor: - self._text += token + self._answer_text = "" + for token in self._answer_generator: + self._answer_text += token yield token - @completor.setter - def completor(self, value: str) -> None: - self._completor = value + self.postprocess() + + @answer_generator.setter + def answer_generator(self, generator: Iterator) -> None: + self._answer_generator = generator + + def postprocess(self): + """Function executed after the answer text is generated by the answer_generator""" + + if self.validator is None: + # TODO: This should only happen if declaring a Completion using .from_dict() method. 
+ # This behaviour is not ideal and we may want to remove support for .from_dict() in the future. + logger.info("No validator was set, skipping postprocessing.") + return + + if self.validator.use_reranking: + # rerank docs in order of cosine similarity to the question + self.matched_documents = self.validator.rerank_docs( + answer=self.answer_text, matched_documents=self.matched_documents + ) + + # access the property so it gets set if not computed already + self.answer_relevant def to_json(self, columns_to_ignore: Optional[list[str]] = None) -> Any: """Converts selected attributes of the object to a JSON format. @@ -101,7 +167,7 @@ def encode_df(df: pd.DataFrame) -> dict: to_encode = { "user_input": self.user_input, - "text": self.text, + "answer_text": self.answer_text, "matched_documents": self.matched_documents, "answer_relevant": self.answer_relevant, "question_relevant": self.question_relevant, @@ -118,10 +184,6 @@ def from_dict(cls, completion_dict: dict): else: raise ValueError(f"Unknown type for matched_documents: {type(completion_dict['matched_documents'])}") - # avoids setting a property at init. the .text method will still be available. - completion_dict["completor"] = completion_dict["text"] - del completion_dict["text"] - return cls(**completion_dict) @@ -156,11 +218,12 @@ def prepare_prompt(self, matched_documents) -> str: prompt = self.prompt_formatter.format(formatted_documents) return prompt - def get_completion(self, user_input: str, matched_documents: pd.DataFrame) -> Completion: - """Generate a completion to a user's question based on matched documents.""" + def get_completion( + self, user_input: str, matched_documents: pd.DataFrame, validator, question_relevant: bool = True + ) -> Completion: + """Generate a completion to a user's question based on matched documents. - # The completor assumes a question was previously determined valid, otherwise it would not be called. 
- question_relevant = True + It is safe to assume the question_relevance to be True if we made it here.""" logger.info(f"{user_input=}") @@ -174,10 +237,11 @@ def get_completion(self, user_input: str, matched_documents: pd.DataFrame) -> Co # because we are proceeding with a completion, we assume the question is relevant. completion = self.completion_class( user_input=user_input, - completor=self.no_documents_message, + answer_generator=self.no_documents_message, error=False, matched_documents=matched_documents, question_relevant=question_relevant, + validator=validator, ) return completion @@ -186,14 +250,15 @@ def get_completion(self, user_input: str, matched_documents: pd.DataFrame) -> Co logger.info(f"{prompt=}") logger.info(f"querying model with parameters: {self.completion_kwargs}...") - completor = self.complete(prompt=prompt, user_input=user_input, **self.completion_kwargs) + answer_generator = self.complete(prompt=prompt, user_input=user_input, **self.completion_kwargs) completion = self.completion_class( - completor=completor, + answer_generator=answer_generator, error=self.error, matched_documents=matched_documents, user_input=user_input, question_relevant=question_relevant, + validator=validator, ) return completion @@ -218,12 +283,12 @@ def complete(self, prompt, user_input, **completion_kwargs): self.error = False if completion_kwargs.get("stream") is True: - def completor(): + def answer_generator(): for chunk in response: token: str = chunk["choices"][0].get("text") yield token - return completor() + return answer_generator() else: return response["choices"][0]["text"] except Exception as e: @@ -249,12 +314,12 @@ def complete(self, prompt: str, user_input, **completion_kwargs) -> str | Iterat if completion_kwargs.get("stream") is True: # We are entering streaming mode, so here were just wrapping the streamed # openai response to be easier to handle later - def completor(): + def answer_generator(): for chunk in response: token: str = 
chunk["choices"][0]["delta"].get("content", "") yield token - return completor() + return answer_generator() else: full_response: str = response["choices"][0]["message"]["content"] diff --git a/buster/examples/gradio_app.py b/buster/examples/gradio_app.py index 4538673..35f97c9 100644 --- a/buster/examples/gradio_app.py +++ b/buster/examples/gradio_app.py @@ -40,8 +40,6 @@ def format_sources(matched_documents: pd.DataFrame) -> str: def add_sources(history, completion): - completion = buster.postprocess_completion(completion) - if completion.answer_relevant: formatted_sources = format_sources(completion.matched_documents) history.append([None, formatted_sources]) @@ -61,7 +59,7 @@ def chat(history): history[-1][1] = "" - for token in completion.completor: + for token in completion.answer_generator: history[-1][1] += token yield history, completion diff --git a/buster/validators/base.py b/buster/validators/base.py index 255833d..ba1f6f7 100644 --- a/buster/validators/base.py +++ b/buster/validators/base.py @@ -92,19 +92,3 @@ def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFr ) return matched_documents.sort_values(by=col, ascending=False) - - def validate(self, completion: Completion) -> Completion: - if completion.error: - completion.answer_relevant = False - elif len(completion.matched_documents) == 0: - completion.answer_relevant = False - else: - completion.answer_relevant = self.check_answer_relevance(completion.text) - - completion.matched_documents = self.rerank_docs(completion.text, completion.matched_documents) - - return completion - - -def validator_factory(validator_cfg: dict) -> Validator: - return Validator(validator_cfg=validator_cfg) diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index 1d4bc4b..9db2c18 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -80,9 +80,13 @@ def prepare_prompt(self, user_input, matched_documents): def complete(self): return - def get_completion(self, user_input, 
matched_documents) -> Completion: + def get_completion(self, user_input, matched_documents, validator, *arg, **kwarg) -> Completion: return Completion( - completor=self.expected_answer, error=False, user_input=user_input, matched_documents=matched_documents + answer_generator=self.expected_answer, + error=False, + user_input=user_input, + matched_documents=matched_documents, + validator=validator, ) @@ -158,8 +162,8 @@ def test_chatbot_mock_data(tmp_path, monkeypatch): validator = MockValidator(**buster_cfg.validator_cfg) buster = Buster(retriever=retriever, completer=completer, validator=validator) completion = buster.process_input("What is a transformer?", source="fake_source") - assert isinstance(completion.text, str) - assert completion.text.startswith(gpt_expected_answer) + assert isinstance(completion.answer_text, str) + assert completion.answer_text.startswith(gpt_expected_answer) def test_chatbot_real_data__chatGPT(database_file): @@ -177,8 +181,7 @@ def test_chatbot_real_data__chatGPT(database_file): buster: Buster = Buster(retriever=retriever, completer=completer, validator=validator) completion = buster.process_input("What is backpropagation?") - completion = buster.postprocess_completion(completion) - assert isinstance(completion.text, str) + assert isinstance(completion.answer_text, str) assert completion.answer_relevant == True @@ -212,8 +215,7 @@ def test_chatbot_real_data__chatGPT_OOD(database_file): buster: Buster = Buster(retriever=retriever, completer=completer, validator=validator) completion = buster.process_input("What is a good recipe for brocolli soup?") - completion = buster.postprocess_completion(completion) - assert isinstance(completion.text, str) + assert isinstance(completion.answer_text, str) assert completion.answer_relevant == False @@ -239,8 +241,7 @@ def test_chatbot_real_data__no_docs_found(database_file): buster: Buster = Buster(retriever=retriever, completer=completer, validator=validator) completion = 
buster.process_input("What is backpropagation?") - completion = buster.postprocess_completion(completion) - assert isinstance(completion.text, str) + assert isinstance(completion.answer_text, str) assert completion.answer_relevant == False - assert completion.text == "No documents available." + assert completion.answer_text == "No documents available." diff --git a/tests/test_read_write.py b/tests/test_read_write.py index 2e1a3ab..1e1c745 100644 --- a/tests/test_read_write.py +++ b/tests/test_read_write.py @@ -10,7 +10,7 @@ def __init__(self): def check_answer_relevance(self, completion: Completion) -> bool: return True - def rerank_docs(self, completion: Completion, matched_documents: pd.DataFrame) -> bool: + def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> bool: return matched_documents @@ -29,15 +29,16 @@ def test_read_write_completion(): c = Completion( user_input="What is the meaning of life?", error=False, - completor="This is my completed answer", + answer_generator="This is my actual answer", matched_documents=matched_documents, + validator=MockValidator(), ) c_json = c.to_json() c_back = Completion.from_dict(c_json) assert c.error == c_back.error - assert c.text == c.text + assert c.answer_text == c_back.answer_text assert c.user_input == c_back.user_input assert c.answer_relevant == c_back.answer_relevant for col in c_back.matched_documents.columns.tolist():