Update documents formatting (#130)
* Add json formatter

* add base class to doc formatters

* update docstrings

* add tests
jerpint authored Sep 15, 2023
1 parent d3b31e0 commit ba76a84
Showing 4 changed files with 219 additions and 37 deletions.
10 changes: 4 additions & 6 deletions buster/examples/cfg.py
@@ -1,6 +1,6 @@
from buster.busterbot import Buster, BusterConfig
-from buster.completers import ChatGPTCompleter, Completer, DocumentAnswerer
-from buster.formatters.documents import DocumentsFormatter
+from buster.completers import ChatGPTCompleter, DocumentAnswerer
+from buster.formatters.documents import DocumentsFormatterJSON
from buster.formatters.prompts import PromptFormatter
from buster.retriever import DeepLakeRetriever, Retriever
from buster.tokenizers import GPTTokenizer
@@ -58,7 +58,7 @@
    },
    documents_formatter_cfg={
        "max_tokens": 3500,
-        "formatter": "{content}",
+        "columns": ["content", "title", "source"],
    },
    prompt_formatter_cfg={
        "max_tokens": 3500,
@@ -69,10 +69,8 @@
"If it isn't, simply reply that you cannot answer the question. "
"Do not refer to the documentation directly, but use the instructions provided within it to answer questions. "
"Here is the documentation: "
"<DOCUMENTS> "
),
"text_after_docs": (
"<\DOCUMENTS>\n"
"REMEMBER:\n"
"You are a chatbot assistant answering technical questions about artificial intelligence (AI)."
"Here are the rules you must follow:\n"
Expand All @@ -97,7 +95,7 @@ def setup_buster(buster_cfg: BusterConfig):
    tokenizer = GPTTokenizer(**buster_cfg.tokenizer_cfg)
    document_answerer: DocumentAnswerer = DocumentAnswerer(
        completer=ChatGPTCompleter(**buster_cfg.completion_cfg),
-        documents_formatter=DocumentsFormatter(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
+        documents_formatter=DocumentsFormatterJSON(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
        prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
        **buster_cfg.documents_answerer_cfg,
    )
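Net effect of the cfg.py change: retrieved documents are serialized to JSON records via the new "columns" option instead of being rendered through the old "{content}" template, and the <DOCUMENTS> tags move out of the prompt strings and into the formatter classes below. A minimal sketch of what the new documents_formatter_cfg yields, assuming one invented matched document:

import pandas as pd

# Hypothetical matched document, shaped like the columns in documents_formatter_cfg.
matched_documents = pd.DataFrame(
    {
        "content": ["Embeddings map text to vectors."],
        "title": ["Embeddings"],
        "source": ["docs/embeddings.md"],
    }
)

# DocumentsFormatterJSON (below) serializes the configured columns with pandas to_json:
print(matched_documents[["content", "title", "source"]].to_json(orient="records"))
# [{"content":"Embeddings map text to vectors.","title":"Embeddings","source":"docs\/embeddings.md"}]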
121 changes: 106 additions & 15 deletions buster/formatters/documents.py
@@ -1,4 +1,5 @@
import logging
+from abc import ABC, abstractmethod
from dataclasses import dataclass

import pandas as pd
@@ -9,38 +10,81 @@
logging.basicConfig(level=logging.INFO)


+class DocumentsFormatter(ABC):
+    """
+    Abstract base class for document formatters.
+
+    Subclasses are required to implement the `format` method which transforms the input documents
+    into the desired format.
+    """
+
+    @abstractmethod
+    def format(self, matched_documents: pd.DataFrame) -> tuple[str, pd.DataFrame]:
+        """
+        Abstract method to format matched documents.
+
+        Parameters:
+        - matched_documents (pd.DataFrame): DataFrame containing the matched documents to be formatted.
+
+        Returns:
+        - tuple[str, pd.DataFrame]: A tuple containing the formatted documents as a string and
+          the possibly truncated matched documents DataFrame.
+        """
+        pass


@dataclass
-class DocumentsFormatter:
+class DocumentsFormatterHTML(DocumentsFormatter):
+    """
+    Formatter class to convert matched documents into an HTML format.
+
+    Attributes:
+    - tokenizer (Tokenizer): Tokenizer instance to count tokens in the documents.
+    - max_tokens (int): Maximum allowed tokens for the formatted documents.
+    - formatter (str): String formatter for the document's content.
+    - inner_tag (str): HTML tag that will be used at the document level.
+    - outer_tag (str): HTML tag that will be used at the documents collection level.
+    """
+
    tokenizer: Tokenizer
    max_tokens: int
    formatter: str = "{content}"
+    inner_tag: str = "DOCUMENT"
+    outer_tag: str = "DOCUMENTS"

-    def format(
-        self,
-        matched_documents: pd.DataFrame,
-    ) -> tuple[str, pd.DataFrame]:
-        """Format our matched documents to plaintext.
-        We also make sure they fit in the alloted max_tokens space.
-        """
+    def format(self, matched_documents: pd.DataFrame) -> tuple[str, pd.DataFrame]:
+        """
+        Format the matched documents into an HTML format.
+
+        If the total tokens exceed max_tokens, documents are truncated or omitted to fit within the limit.
+
+        Parameters:
+        - matched_documents (pd.DataFrame): DataFrame containing the matched documents to be formatted.
+
+        Returns:
+        - tuple[str, pd.DataFrame]: A tuple containing the formatted documents as an HTML string and
+          the possibly truncated matched documents DataFrame.
+        """

        documents_str = ""
        total_tokens = 0
        max_tokens = self.max_tokens

        num_total_docs = len(matched_documents)
        num_preserved_docs = 0
+        # TODO: uniformize this logic with the DocumentsFormatterJSON
        for _, row in matched_documents.iterrows():
            doc = self.formatter.format_map(row.to_dict())
            num_preserved_docs += 1
            token_count, encoded = self.tokenizer.num_tokens(doc, return_encoded=True)
            if total_tokens + token_count <= max_tokens:
-                documents_str += f"<DOCUMENT>{doc}<\\DOCUMENT>"
+                documents_str += f"<{self.inner_tag}>{doc}<\\{self.inner_tag}>"
                total_tokens += token_count
            else:
                logger.warning("truncating document to fit...")
                remaining_tokens = max_tokens - total_tokens
                truncated_doc = self.tokenizer.decode(encoded[:remaining_tokens])
-                documents_str += f"<DOCUMENT>{truncated_doc}<\\DOCUMENT>"
+                documents_str += f"<{self.inner_tag}>{truncated_doc}<\\{self.inner_tag}>"
                logger.warning(f"Documents after truncation: {documents_str}")
                break

Expand All @@ -50,12 +94,59 @@ def format(
            )
            matched_documents = matched_documents.iloc[:num_preserved_docs]

+        documents_str = f"<{self.outer_tag}>{documents_str}<\\{self.outer_tag}>"

        return documents_str, matched_documents


-def documents_formatter_factory(tokenizer: Tokenizer, max_tokens: int, formatter: str) -> DocumentsFormatter:
-    return DocumentsFormatter(
-        tokenizer=tokenizer,
-        max_tokens=max_tokens,
-        formatter=formatter,
-    )
+@dataclass
+class DocumentsFormatterJSON(DocumentsFormatter):
+    """
+    Formatter class to convert matched documents into a JSON format.
+
+    Attributes:
+    - tokenizer (Tokenizer): Tokenizer instance to count tokens in the documents.
+    - max_tokens (int): Maximum allowed tokens for the formatted documents.
+    - columns (list[str]): List of columns to include in the JSON format.
+    """
+
+    tokenizer: Tokenizer
+    max_tokens: int
+    columns: list[str]
+
+    def format(self, matched_documents: pd.DataFrame) -> tuple[str, pd.DataFrame]:
+        """
+        Format the matched documents into a JSON format.
+
+        If the total tokens exceed max_tokens, documents are omitted one at a time until it fits the limit.
+
+        Parameters:
+        - matched_documents (pd.DataFrame): DataFrame containing the matched documents to be formatted.
+
+        Returns:
+        - tuple[str, pd.DataFrame]: A tuple containing the formatted documents as a JSON string and
+          the possibly truncated matched documents DataFrame.
+        """
+
+        max_tokens = self.max_tokens
+        documents_str = matched_documents[self.columns].to_json(orient="records")
+        token_count, _ = self.tokenizer.num_tokens(documents_str, return_encoded=True)
+
+        while token_count > max_tokens:
+            # Truncated too much, no documents left, raise an error
+            if len(matched_documents) == 0:
+                raise ValueError(
+                    f"Could not truncate documents to fit {max_tokens=}. Consider increasing max_tokens or decreasing chunk lengths."
+                )
+
+            # Too many tokens, drop a document and try again.
+            matched_documents = matched_documents.iloc[:-1]
+            documents_str = matched_documents[self.columns].to_json(orient="records")
+            token_count, _ = self.tokenizer.num_tokens(documents_str, return_encoded=True)
+
+            # Log a warning with more details
+            logger.warning(
+                f"Truncating documents to fit. Remaining documents after truncation: {len(matched_documents)}"
+            )
+
+        return documents_str, matched_documents
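A quick usage sketch of the new JSON formatter. WordTokenizer is a hypothetical stand-in for buster's GPTTokenizer (one token per word), only there to keep the example self-contained; the JSON formatter needs just num_tokens, while decode is used by the HTML path.

import pandas as pd

from buster.formatters.documents import DocumentsFormatterJSON


class WordTokenizer:
    # Hypothetical toy tokenizer: one token per whitespace-separated word.
    def num_tokens(self, text: str, return_encoded: bool = False):
        encoded = text.split()
        return (len(encoded), encoded) if return_encoded else len(encoded)

    def decode(self, encoded) -> str:
        return " ".join(encoded)


docs = pd.DataFrame(
    {
        "content": ["Buster answers questions about AI.", "It can cite its sources."],
        "title": ["Intro", "Citations"],
        "source": ["docs/intro.md", "docs/citations.md"],
    }
)

formatter = DocumentsFormatterJSON(
    tokenizer=WordTokenizer(),
    max_tokens=40,
    columns=["content", "title", "source"],
)
documents_str, kept_documents = formatter.format(docs)
# documents_str is a JSON array of records. Had it exceeded max_tokens,
# format() would drop trailing rows one at a time (logging a warning each
# time) and raise ValueError only once no documents remain.

Note the design choice: the JSON formatter drops whole documents rather than truncating mid-document, so the context never contains a record cut into invalid JSON; the HTML formatter, by contrast, truncates the last document to spend the full token budget.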
8 changes: 4 additions & 4 deletions tests/test_chatbot.py
@@ -10,7 +10,7 @@
from buster.busterbot import Buster, BusterConfig
from buster.completers import ChatGPTCompleter, Completer, Completion, DocumentAnswerer
from buster.documents_manager import DeepLakeDocumentsManager
-from buster.formatters.documents import DocumentsFormatter
+from buster.formatters.documents import DocumentsFormatterHTML
from buster.formatters.prompts import PromptFormatter
from buster.retriever import DeepLakeRetriever, Retriever
from buster.tokenizers.gpt import GPTTokenizer
@@ -184,7 +184,7 @@ def test_chatbot_real_data__chatGPT(vector_store_path):
    tokenizer = GPTTokenizer(**buster_cfg.tokenizer_cfg)
    document_answerer = DocumentAnswerer(
        completer=ChatGPTCompleter(**buster_cfg.completion_cfg),
-        documents_formatter=DocumentsFormatter(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
+        documents_formatter=DocumentsFormatterHTML(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
        prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
    )
    validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg)
@@ -221,7 +221,7 @@ def test_chatbot_real_data__chatGPT_OOD(vector_store_path):
    tokenizer = GPTTokenizer(**buster_cfg.tokenizer_cfg)
    document_answerer = DocumentAnswerer(
        completer=ChatGPTCompleter(**buster_cfg.completion_cfg),
-        documents_formatter=DocumentsFormatter(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
+        documents_formatter=DocumentsFormatterHTML(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
        prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
    )
    validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg)
@@ -251,7 +251,7 @@ def test_chatbot_real_data__no_docs_found(vector_store_path):
    tokenizer = GPTTokenizer(**buster_cfg.tokenizer_cfg)
    document_answerer = DocumentAnswerer(
        completer=ChatGPTCompleter(**buster_cfg.completion_cfg),
-        documents_formatter=DocumentsFormatter(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
+        documents_formatter=DocumentsFormatterHTML(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
        prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
        **buster_cfg.documents_answerer_cfg,
    )
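The tests keep their previous behavior by switching to the renamed DocumentsFormatterHTML, whose wrapping tags are now configurable. Continuing the hypothetical sketch from the documents.py section (same toy WordTokenizer and docs frame):

from buster.formatters.documents import DocumentsFormatterHTML

html_formatter = DocumentsFormatterHTML(
    tokenizer=WordTokenizer(),  # hypothetical toy tokenizer from the earlier sketch
    max_tokens=100,
)
documents_str, kept_documents = html_formatter.format(docs)
# With the default formatter "{content}" and default tags, documents_str is:
# <DOCUMENTS><DOCUMENT>Buster answers questions about AI.<\DOCUMENT><DOCUMENT>It can cite its sources.<\DOCUMENT><\DOCUMENTS>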