Skip to content

Commit

Permalink
Refactor relevance check (#112)
Browse files Browse the repository at this point in the history
Another refactor to make answer_relevant an actual property. When called, it will trigger a validation automatically.
This is just better UX. Otherwise, the user needs to be aware that buster.postprocess_completion must be called explicitly. This PR now:

* explicitly runs all postprocessing computations at the end of a generation

* renames variables

* accepts answer_text directly as input for a Completion

In short, after the text is generated in answer_generator, a postprocess function is automatically executed, computing the answer relevance if necessary and re-ranking the documents.
  • Loading branch information
jerpint authored Jun 29, 2023
1 parent 9b7de7c commit 332a722
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 86 deletions.
14 changes: 8 additions & 6 deletions buster/busterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,23 +89,25 @@ def process_input(self, user_input: str, source: str = None) -> Completion:
if question_relevant:
# question is relevant, get completor to generate completion
matched_documents = self.retriever.retrieve(user_input, source=source)
completion = self.completer.get_completion(user_input=user_input, matched_documents=matched_documents)
completion = self.completer.get_completion(
user_input=user_input,
matched_documents=matched_documents,
validator=self.validator,
question_relevant=question_relevant,
)

else:
# question was determined irrelevant, so we instead return a generic response set by the user.
completion = Completion(
error=False,
user_input=user_input,
matched_documents=pd.DataFrame(),
completor=irrelevant_question_message,
answer_generator=irrelevant_question_message,
answer_relevant=False,
question_relevant=False,
validator=self.validator,
)

logger.info(f"Completion:\n{completion}")

return completion

def postprocess_completion(self, completion) -> Completion:
    """Run post-generation validation on a completion.

    Delegates to the configured validator, which checks whether the answer
    is relevant and reranks the matched documents by relevance.
    """
    return self.validator.validate(completion=completion)
159 changes: 112 additions & 47 deletions buster/completers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,47 +32,113 @@
openai.api_key = os.environ.get("OPENAI_API_KEY")


@dataclass
class Completion:
error: bool
user_input: str
matched_documents: pd.DataFrame
completor: Iterator | str
answer_relevant: bool = None
question_relevant: bool = None

# private property, should not be set at init
_completor: Iterator | str = field(init=False, repr=False) # e.g. a streamed response from openai.ChatCompletion
_text: str = (
None # once the generator of the completor is exhausted, the text will be available in the self.text property
)
def __init__(
    self,
    error: bool,
    user_input: str,
    matched_documents: pd.DataFrame,
    answer_generator: Optional[Iterator] = None,
    answer_text: Optional[str] = None,
    answer_relevant: Optional[bool] = None,
    question_relevant: Optional[bool] = None,
    validator=None,
):
    """Store a generated completion and its metadata.

    Args:
        error: whether an error occurred while generating the answer.
        user_input: the raw question asked by the user.
        matched_documents: documents retrieved for the question.
        answer_generator: iterator streaming the answer tokens; mutually
            exclusive with answer_text (exactly one of the two must be set).
        answer_text: the full answer as a string; mutually exclusive with
            answer_generator.
        answer_relevant: precomputed answer relevance, if already known.
        question_relevant: whether the question was deemed relevant.
        validator: object used for postprocessing.
            NOTE(review): presumably only None when built via .from_dict() — confirm.
    """
    self.error = error
    self.user_input = user_input
    self.matched_documents = matched_documents
    self.validator = validator
    # Stored privately; exposed through lazily-computed properties.
    self._answer_relevant = answer_relevant
    self._question_relevant = question_relevant

    # Sets self._answer_text and self._answer_generator from the two inputs.
    self._validate_arguments(answer_generator, answer_text)

def _validate_arguments(self, answer_generator: Optional[Iterator], answer_text: Optional[str]):
"""Sets answer_generator and answer_text properties depending on the provided inputs.
Checks that one of either answer_generator or answer_text is not None.
If answer_text is set, a generator can simply be inferred from answer_text.
If answer_generator is set, answer_text will be set only once the generator gets called. Set to None for now.
"""
if (answer_generator is None and answer_text is None) or (
answer_generator is not None and answer_text is not None
):
raise ValueError("Only one of 'answer_generator' and 'answer_text' must be set.")

# If text is provided, the genrator can be inferred
if answer_text is not None:
assert isinstance(answer_text, str)
answer_generator = (msg for msg in answer_text)

self._answer_text = answer_text
self._answer_generator = answer_generator

@property
def text(self):
if self._text is None:
# generates the text if it wasn't already generated
self._text = "".join([i for i in self.completor])
return self._text
def answer_relevant(self) -> bool:
"""Property determining the relevance of an answer (bool).
@text.setter
def text(self, value: str) -> None:
self._text = value
If an error occured, the relevance is False.
If no documents were retrieved, the relevance is also False.
Otherwise, the relevance is computed as defined by the validator (e.g. comparing to embeddings)
"""
if self.error:
self._answer_relevant = False
elif len(self.matched_documents) == 0:
self._answer_relevant = False
elif self._answer_relevant is not None:
return self._answer_relevant
else:
# Check the answer relevance by looking at the embeddings
self._answer_relevant = self.validator.check_answer_relevance(self.answer_text)
return self._answer_relevant

@property
def completor(self):
if isinstance(self._completor, str):
# convert str to iterator
self._completor = (msg for msg in self._completor)
def question_relevant(self):
"""Property determining the relevance of the question asked (bool)."""
return self._question_relevant

@property
def answer_text(self):
    """The complete answer text.

    If the text was not supplied up-front, it is materialized here on first
    access by exhausting the answer generator, then cached.
    """
    if self._answer_text is None:
        self._answer_text = "".join(self.answer_generator)
    return self._answer_text

@answer_text.setter
def answer_text(self, value: str) -> None:
    """Directly override the cached answer text."""
    self._answer_text = value

@property
def answer_generator(self):
    """Generator over the answer tokens.

    Accumulates each yielded token into self._answer_text so the full text is
    available once the stream is exhausted, then runs postprocessing.
    (Fix: removed stale deleted-side diff lines that were interleaved here —
    the old `self._text` accumulator and the old `@completor.setter`.)
    """
    # keeps track of the yielded text
    self._answer_text = ""
    for token in self._answer_generator:
        self._answer_text += token
        yield token

    # Once the stream is fully consumed, compute relevance / rerank documents.
    self.postprocess()

@answer_generator.setter
def answer_generator(self, generator: Iterator) -> None:
    self._answer_generator = generator

def postprocess(self):
    """Function executed after the answer text is generated by the answer_generator.

    Reranks the matched documents (when the validator enables reranking) and
    triggers computation of the answer relevance. Skipped entirely when no
    validator is configured.
    """

    if self.validator is None:
        # TODO: This should only happen if declaring a Completion using .from_dict() method.
        # This behaviour is not ideal and we may want to remove support for .from_dict() in the future.
        logger.info("No validator was set, skipping postprocessing.")
        return

    if self.validator.use_reranking:
        # rerank docs in order of cosine similarity to the question
        self.matched_documents = self.validator.rerank_docs(
            answer=self.answer_text, matched_documents=self.matched_documents
        )

    # access the property so it gets computed and cached if not done already
    self.answer_relevant

def to_json(self, columns_to_ignore: Optional[list[str]] = None) -> Any:
"""Converts selected attributes of the object to a JSON format.
Expand Down Expand Up @@ -101,7 +167,7 @@ def encode_df(df: pd.DataFrame) -> dict:

to_encode = {
"user_input": self.user_input,
"text": self.text,
"answer_text": self.answer_text,
"matched_documents": self.matched_documents,
"answer_relevant": self.answer_relevant,
"question_relevant": self.question_relevant,
Expand All @@ -118,10 +184,6 @@ def from_dict(cls, completion_dict: dict):
else:
raise ValueError(f"Unknown type for matched_documents: {type(completion_dict['matched_documents'])}")

# avoids setting a property at init. the .text method will still be available.
completion_dict["completor"] = completion_dict["text"]
del completion_dict["text"]

return cls(**completion_dict)


Expand Down Expand Up @@ -156,11 +218,12 @@ def prepare_prompt(self, matched_documents) -> str:
prompt = self.prompt_formatter.format(formatted_documents)
return prompt

def get_completion(self, user_input: str, matched_documents: pd.DataFrame) -> Completion:
"""Generate a completion to a user's question based on matched documents."""
def get_completion(
self, user_input: str, matched_documents: pd.DataFrame, validator, question_relevant: bool = True
) -> Completion:
"""Generate a completion to a user's question based on matched documents.
# The completor assumes a question was previously determined valid, otherwise it would not be called.
question_relevant = True
It is safe to assume the question_relevance to be True if we made it here."""

logger.info(f"{user_input=}")

Expand All @@ -174,10 +237,11 @@ def get_completion(self, user_input: str, matched_documents: pd.DataFrame) -> Co
# because we are proceeding with a completion, we assume the question is relevant.
completion = self.completion_class(
user_input=user_input,
completor=self.no_documents_message,
answer_generator=self.no_documents_message,
error=False,
matched_documents=matched_documents,
question_relevant=question_relevant,
validator=validator,
)
return completion

Expand All @@ -186,14 +250,15 @@ def get_completion(self, user_input: str, matched_documents: pd.DataFrame) -> Co
logger.info(f"{prompt=}")

logger.info(f"querying model with parameters: {self.completion_kwargs}...")
completor = self.complete(prompt=prompt, user_input=user_input, **self.completion_kwargs)
answer_generator = self.complete(prompt=prompt, user_input=user_input, **self.completion_kwargs)

completion = self.completion_class(
completor=completor,
answer_generator=answer_generator,
error=self.error,
matched_documents=matched_documents,
user_input=user_input,
question_relevant=question_relevant,
validator=validator,
)

return completion
Expand All @@ -218,12 +283,12 @@ def complete(self, prompt, user_input, **completion_kwargs):
self.error = False
if completion_kwargs.get("stream") is True:

def completor():
def answer_generator():
for chunk in response:
token: str = chunk["choices"][0].get("text")
yield token

return completor()
return answer_generator()
else:
return response["choices"][0]["text"]
except Exception as e:
Expand All @@ -249,12 +314,12 @@ def complete(self, prompt: str, user_input, **completion_kwargs) -> str | Iterat
if completion_kwargs.get("stream") is True:
# We are entering streaming mode, so here were just wrapping the streamed
# openai response to be easier to handle later
def completor():
def answer_generator():
for chunk in response:
token: str = chunk["choices"][0]["delta"].get("content", "")
yield token

return completor()
return answer_generator()

else:
full_response: str = response["choices"][0]["message"]["content"]
Expand Down
4 changes: 1 addition & 3 deletions buster/examples/gradio_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ def format_sources(matched_documents: pd.DataFrame) -> str:


def add_sources(history, completion):
completion = buster.postprocess_completion(completion)

if completion.answer_relevant:
formatted_sources = format_sources(completion.matched_documents)
history.append([None, formatted_sources])
Expand All @@ -61,7 +59,7 @@ def chat(history):

history[-1][1] = ""

for token in completion.completor:
for token in completion.answer_generator:
history[-1][1] += token

yield history, completion
Expand Down
16 changes: 0 additions & 16 deletions buster/validators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,3 @@ def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFr
)

return matched_documents.sort_values(by=col, ascending=False)

def validate(self, completion: Completion) -> Completion:
    """Set the completion's answer relevance and rerank its matched documents."""
    # An errored completion, or one with no retrieved documents, is never relevant.
    if completion.error or len(completion.matched_documents) == 0:
        completion.answer_relevant = False
    else:
        completion.answer_relevant = self.check_answer_relevance(completion.text)

    completion.matched_documents = self.rerank_docs(completion.text, completion.matched_documents)

    return completion


def validator_factory(validator_cfg: dict) -> Validator:
    """Instantiate a Validator from its configuration dict."""
    return Validator(validator_cfg=validator_cfg)
23 changes: 12 additions & 11 deletions tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,13 @@ def prepare_prompt(self, user_input, matched_documents):
def complete(self):
    # No-op stub: does nothing and returns None.
    return

def get_completion(self, user_input, matched_documents) -> Completion:
def get_completion(self, user_input, matched_documents, validator, *arg, **kwarg) -> Completion:
return Completion(
completor=self.expected_answer, error=False, user_input=user_input, matched_documents=matched_documents
answer_generator=self.expected_answer,
error=False,
user_input=user_input,
matched_documents=matched_documents,
validator=validator,
)


Expand Down Expand Up @@ -158,8 +162,8 @@ def test_chatbot_mock_data(tmp_path, monkeypatch):
validator = MockValidator(**buster_cfg.validator_cfg)
buster = Buster(retriever=retriever, completer=completer, validator=validator)
completion = buster.process_input("What is a transformer?", source="fake_source")
assert isinstance(completion.text, str)
assert completion.text.startswith(gpt_expected_answer)
assert isinstance(completion.answer_text, str)
assert completion.answer_text.startswith(gpt_expected_answer)


def test_chatbot_real_data__chatGPT(database_file):
Expand All @@ -177,8 +181,7 @@ def test_chatbot_real_data__chatGPT(database_file):
buster: Buster = Buster(retriever=retriever, completer=completer, validator=validator)

completion = buster.process_input("What is backpropagation?")
completion = buster.postprocess_completion(completion)
assert isinstance(completion.text, str)
assert isinstance(completion.answer_text, str)

assert completion.answer_relevant == True

Expand Down Expand Up @@ -212,8 +215,7 @@ def test_chatbot_real_data__chatGPT_OOD(database_file):
buster: Buster = Buster(retriever=retriever, completer=completer, validator=validator)

completion = buster.process_input("What is a good recipe for brocolli soup?")
completion = buster.postprocess_completion(completion)
assert isinstance(completion.text, str)
assert isinstance(completion.answer_text, str)

assert completion.answer_relevant == False

Expand All @@ -239,8 +241,7 @@ def test_chatbot_real_data__no_docs_found(database_file):
buster: Buster = Buster(retriever=retriever, completer=completer, validator=validator)

completion = buster.process_input("What is backpropagation?")
completion = buster.postprocess_completion(completion)
assert isinstance(completion.text, str)
assert isinstance(completion.answer_text, str)

assert completion.answer_relevant == False
assert completion.text == "No documents available."
assert completion.answer_text == "No documents available."
7 changes: 4 additions & 3 deletions tests/test_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(self):
def check_answer_relevance(self, completion: Completion) -> bool:
    # Mock: always considers the answer relevant.
    return True

def rerank_docs(self, completion: Completion, matched_documents: pd.DataFrame) -> bool:
def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> bool:
return matched_documents


Expand All @@ -29,15 +29,16 @@ def test_read_write_completion():
c = Completion(
user_input="What is the meaning of life?",
error=False,
completor="This is my completed answer",
answer_generator="This is my actual answer",
matched_documents=matched_documents,
validator=MockValidator(),
)

c_json = c.to_json()
c_back = Completion.from_dict(c_json)

assert c.error == c_back.error
assert c.text == c.text
assert c.answer_text == c_back.answer_text
assert c.user_input == c_back.user_input
assert c.answer_relevant == c_back.answer_relevant
for col in c_back.matched_documents.columns.tolist():
Expand Down

0 comments on commit 332a722

Please sign in to comment.