Skip to content

Commit

Permalink
better token count management (#92)
Browse files Browse the repository at this point in the history
* Add truncation on token count and not on word count

* decouple system prompt formatting from document formatting

* add check in prompt formatter for token length

* update tests
  • Loading branch information
jerpint authored May 1, 2023
1 parent a5af3f8 commit 5f33f79
Show file tree
Hide file tree
Showing 10 changed files with 290 additions and 34 deletions.
24 changes: 21 additions & 3 deletions buster/busterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

from buster.completers import completer_factory
from buster.completers.base import Completion
from buster.formatters.documents import document_formatter_factory
from buster.formatters.prompts import prompt_formatter_factory
from buster.retriever import Retriever
from buster.tokenizers import tokenizer_factory

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand All @@ -31,15 +33,21 @@ class BusterConfig:
unknown_threshold: float = 0.85
unknown_prompt: str = "I Don't know how to answer your question."
document_source: str = ""
tokenizer_cfg: dict = field(
default_factory=lambda: {
"model_name": "gpt-3.5-turbo",
}
)
retriever_cfg: dict = field(
default_factory=lambda: {
"max_tokens": 3000,
"top_k": 3,
"thresh": 0.7,
}
)
prompt_cfg: dict = field(
default_factory=lambda: {
"max_words": 3000,
"max_tokens": 3500,
"text_before_documents": "You are a chatbot answering questions.\n",
"text_before_prompt": "Answer the following question:\n",
}
Expand Down Expand Up @@ -88,13 +96,20 @@ def update_cfg(self, cfg: BusterConfig):
self.retriever_cfg = cfg.retriever_cfg
self.completion_cfg = cfg.completion_cfg
self.prompt_cfg = cfg.prompt_cfg
self.tokenizer_cfg = cfg.tokenizer_cfg

# set the unk. embedding
self.unk_embedding = self.get_embedding(self.unknown_prompt, engine=self.embedding_model)

# update completer and formatter cfg
self.tokenizer = tokenizer_factory(self.tokenizer_cfg)
self.completer = completer_factory(self.completion_cfg)
self.prompt_formatter = prompt_formatter_factory(self.prompt_cfg)
self.documents_formatter = document_formatter_factory(
tokenizer=self.tokenizer,
max_tokens=self.retriever_cfg["max_tokens"]
# TODO: move max_tokens from retriever_cfg to somewhere more logical
)
self.prompt_formatter = prompt_formatter_factory(tokenizer=self.tokenizer, prompt_cfg=self.prompt_cfg)

logger.info(f"Config Updated.")

Expand Down Expand Up @@ -185,8 +200,11 @@ def process_input(self, user_input: str) -> Response:
)
return response

# format the matched documents, (will truncate them if too long)
documents_str, matched_documents = self.documents_formatter.format(matched_documents)

# prepare the prompt
system_prompt = self.prompt_formatter.format(matched_documents)
system_prompt = self.prompt_formatter.format(documents_str)
completion: Completion = self.completer.generate_response(user_input=user_input, system_prompt=system_prompt)
logger.info(f"GPT Response:\n{completion.text}")

Expand Down
2 changes: 1 addition & 1 deletion buster/completers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def generate_response(self, system_prompt, user_input) -> Completion:
try:
completion = self.complete(prompt=prompt, **self.completion_kwargs)
except openai.error.InvalidRequestError:
logger.exception("Error connecting to OpenAI API. See traceback:")
logger.exception("Invalid request to OpenAI API. See traceback:")
return Completion("Something went wrong, try again soon!", True, "Invalid request made to openai.")
except Exception as e:
# log the error and return a generic response instead.
Expand Down
6 changes: 5 additions & 1 deletion buster/examples/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@
retriever_cfg={
"top_k": 3,
"thresh": 0.7,
"max_tokens": 2000,
},
completion_cfg={
"name": "ChatGPT",
"completion_kwargs": {
"model": "gpt-3.5-turbo",
},
},
tokenizer_cfg={
"model_name": "gpt-3.5-turbo",
},
prompt_cfg={
"max_words": 3000,
"max_tokens": 3500,
"text_before_documents": (
"You are a chatbot assistant answering technical questions about artificial intelligence (AI)."
"You can only respond to a question if the content necessary to answer the question is contained in the following provided documentation. "
Expand Down
58 changes: 58 additions & 0 deletions buster/formatters/documents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import logging
from dataclasses import dataclass

import pandas as pd

from buster.tokenizers import Tokenizer

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


@dataclass
class DocumentsFormatter:
    """Formats matched documents into a single plaintext string under a token budget.

    Each document is wrapped in ``<DOCUMENT> ... <\\DOCUMENT>`` tags and appended
    until ``max_tokens`` would be exceeded; the first document that does not fit
    is truncated to the remaining budget and iteration stops.
    """

    # Used to count tokens and to encode/decode for truncation.
    tokenizer: Tokenizer
    # Total token budget shared by all formatted documents.
    max_tokens: int

    def format(
        self,
        matched_documents: pd.DataFrame,
    ) -> tuple[str, pd.DataFrame]:
        """Format our matched documents to plaintext.

        We also make sure they fit in the allotted max_tokens space.

        Args:
            matched_documents: DataFrame with a ``content`` column of document texts.

        Returns:
            Tuple of (formatted documents string, DataFrame restricted to the
            documents that were at least partially preserved).
        """
        documents_str = ""
        total_tokens = 0
        max_tokens = self.max_tokens

        num_total_docs = len(matched_documents)
        num_preserved_docs = 0
        for doc in matched_documents.content.to_list():
            num_preserved_docs += 1
            token_count, encoded = self.tokenizer.num_tokens(doc, return_encoded=True)
            if total_tokens + token_count <= max_tokens:
                documents_str += f"<DOCUMENT> {doc} <\\DOCUMENT>"
                total_tokens += token_count
            else:
                remaining_tokens = max_tokens - total_tokens
                if remaining_tokens > 0:
                    # Partially fits: truncate this document to the leftover budget.
                    logger.warning("truncating document to fit...")
                    truncated_doc = self.tokenizer.decode(encoded[:remaining_tokens])
                    documents_str += f"<DOCUMENT> {truncated_doc} <\\DOCUMENT>"
                    logger.warning(f"Documents after truncation: {documents_str}")
                else:
                    # No budget left at all: drop this document entirely instead of
                    # appending an empty <DOCUMENT> pair and counting it as preserved.
                    num_preserved_docs -= 1
                break

        if num_preserved_docs < num_total_docs:
            logger.warning(
                f"{num_preserved_docs}/{num_total_docs} documents were preserved from the matched documents due to truncation."
            )
            matched_documents = matched_documents.iloc[:num_preserved_docs]

        return documents_str, matched_documents


def document_formatter_factory(tokenizer: Tokenizer, max_tokens: int) -> DocumentsFormatter:
    """Build a DocumentsFormatter from a tokenizer and a total token budget."""
    return DocumentsFormatter(
        tokenizer=tokenizer,
        max_tokens=max_tokens,
    )
44 changes: 19 additions & 25 deletions buster/formatters/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,42 @@

import pandas as pd

from buster.tokenizers import Tokenizer

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


@dataclass
class SystemPromptFormatter:
text_before_docs: str = ""
text_after_docs: str = ""
max_words: int = 4000

def format_documents(self, matched_documents: pd.DataFrame, max_words: int) -> str:
# gather the documents in one large plaintext variable
documents_list = matched_documents.content.to_list()
documents_str = ""
for idx, doc in enumerate(documents_list):
documents_str += f"<DOCUMENT> {doc} <\\DOCUMENT>"

# truncate the documents to fit
# TODO: increase to actual token count
word_count = len(documents_str.split(" "))
if word_count > max_words:
logger.warning("truncating documents to fit...")
documents_str = " ".join(documents_str.split(" ")[0:max_words])
logger.warning(f"Documents after truncation: {documents_str}")

return documents_str
tokenizer: Tokenizer
max_tokens: int
text_before_docs: str
text_after_docs: str
formatter: str = "{text_before_docs}\n{documents}\n{text_after_docs}"

def format(
self,
matched_documents: str,
documents: str,
) -> str:
"""
Prepare the system prompt with prompt engineering.
Joins the text before and after documents with
"""
documents = self.format_documents(matched_documents, max_words=self.max_words)
system_prompt = self.text_before_docs + documents + self.text_after_docs
system_prompt = self.formatter.format(
text_before_docs=self.text_before_docs, documents=documents, text_after_docs=self.text_after_docs
)

if self.tokenizer.num_tokens(system_prompt) > self.max_tokens:
raise ValueError(f"System prompt tokens > {self.max_tokens=}")
return system_prompt


def prompt_formatter_factory(prompt_cfg):
def prompt_formatter_factory(tokenizer: Tokenizer, prompt_cfg):
return SystemPromptFormatter(
tokenizer=tokenizer,
max_tokens=prompt_cfg["max_tokens"],
text_before_docs=prompt_cfg["text_before_documents"],
text_after_docs=prompt_cfg["text_before_prompt"],
max_words=prompt_cfg["max_words"],
)
13 changes: 13 additions & 0 deletions buster/tokenizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from .base import Tokenizer
from .gpt import GPTTokenizer


def tokenizer_factory(tokenizer_cfg: dict) -> Tokenizer:
    """Instantiate a Tokenizer from a config dict.

    Args:
        tokenizer_cfg: must contain "model_name", the OpenAI model whose
            tokenizer should be used.

    Raises:
        ValueError: if no tokenizer is implemented for the given model.
    """
    model_name = tokenizer_cfg["model_name"]
    if model_name in ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"]:
        return GPTTokenizer(model_name)

    raise ValueError(f"Tokenizer not implemented for {model_name=}")


# __all__ must be a list of *strings*; listing the objects themselves makes
# `from buster.tokenizers import *` raise "TypeError: Item in __all__ must be str".
__all__ = ["Tokenizer", "GPTTokenizer", "tokenizer_factory"]
23 changes: 23 additions & 0 deletions buster/tokenizers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from abc import ABC, abstractmethod
from typing import Union


class Tokenizer(ABC):
    """Abstract base class for a tokenizer.

    Subclasses provide ``encode``/``decode``; token counting is derived
    from ``encode``.
    """

    def __init__(self, model_name: str):
        # Name of the model whose vocabulary this tokenizer targets.
        self.model_name = model_name

    @abstractmethod
    def encode(self, string: str) -> list[int]:
        """Convert a string into a list of token ids."""
        ...

    @abstractmethod
    def decode(self, encoded: list[int]) -> str:
        """Convert a list of token ids back into a string."""
        ...

    def num_tokens(self, string: str, return_encoded: bool = False) -> Union[int, tuple[int, list[int]]]:
        """Count the tokens in ``string``; optionally also return the encoding."""
        encoded = self.encode(string)
        count = len(encoded)
        return (count, encoded) if return_encoded else count
17 changes: 17 additions & 0 deletions buster/tokenizers/gpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import tiktoken

from buster.tokenizers import Tokenizer


class GPTTokenizer(Tokenizer):
    """Tokenizer from openai, supports most GPT models."""

    def __init__(self, model_name: str):
        super().__init__(model_name)
        # tiktoken selects the byte-pair encoding that matches the given OpenAI model.
        self.encoder = tiktoken.encoding_for_model(model_name=model_name)

    def encode(self, string: str) -> list[int]:
        # Convert a string into a list of token ids.
        return self.encoder.encode(string)

    def decode(self, encoded: list[int]) -> str:
        # Convert a list of token ids back into a string.
        return self.encoder.decode(encoded)
16 changes: 12 additions & 4 deletions tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def test_chatbot_mock_data(tmp_path, monkeypatch):
retriever_cfg={
"top_k": 3,
"thresh": 0.7,
"max_tokens": 2000,
},
document_source="fake source",
completion_cfg={
Expand All @@ -84,7 +85,7 @@ def test_chatbot_mock_data(tmp_path, monkeypatch):
},
},
prompt_cfg={
"max_words": 2000,
"max_tokens": 3500,
"text_before_documents": "",
"text_before_prompt": (
"""You are a slack chatbot assistant answering technical questions about huggingface transformers, a library to train transformers in python.\n"""
Expand Down Expand Up @@ -113,8 +114,13 @@ def test_chatbot_real_data__chatGPT():
"temperature": 0,
},
},
retriever_cfg={
"top_k": 3,
"thresh": 0.7,
"max_tokens": 2000,
},
prompt_cfg={
"max_words": 2000,
"max_tokens": 3500,
"text_before_documents": "",
"text_before_prompt": (
"""You are a slack chatbot assistant answering technical questions about huggingface transformers, a library to train transformers in python.\n"""
Expand Down Expand Up @@ -144,9 +150,10 @@ def test_chatbot_real_data__chatGPT_OOD():
retriever_cfg={
"top_k": 3,
"thresh": 0.7,
"max_tokens": 2000,
},
prompt_cfg={
"max_words": 3000,
"max_tokens": 3500,
"text_before_prompt": (
"""You are a chatbot assistant answering technical questions about huggingface transformers, a library to train transformers in python. """
"""Make sure to format your answers in Markdown format, including code block and snippets. """
Expand Down Expand Up @@ -183,9 +190,10 @@ def test_chatbot_real_data__GPT():
retriever_cfg={
"top_k": 3,
"thresh": 0.7,
"max_tokens": 3000,
},
prompt_cfg={
"max_words": 3000,
"max_tokens": 3500,
"text_before_prompt": (
"""You are a chatbot assistant answering technical questions about huggingface transformers, a library to train transformers in python. """
"""Make sure to format your answers in Markdown format, including code block and snippets. """
Expand Down
Loading

0 comments on commit 5f33f79

Please sign in to comment.