Skip to content

Commit

Permalink
Remove pickle retriever (#103)
Browse files Browse the repository at this point in the history
* remove PickleDB support

* update tests

* remove unused functions from utils

* allow passing a connection directly to sqlite

* remove .base when importing from completers and validators
  • Loading branch information
jerpint authored Jun 1, 2023
1 parent 8a5dc6c commit 59ba24f
Show file tree
Hide file tree
Showing 16 changed files with 95 additions and 159 deletions.
4 changes: 2 additions & 2 deletions buster/busterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from dataclasses import dataclass, field
from typing import Any

from buster.completers.base import Completer, Completion
from buster.completers import Completer, Completion
from buster.retriever import Retriever
from buster.validators.base import Validator
from buster.validators import Validator

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand Down
10 changes: 9 additions & 1 deletion buster/completers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from .base import ChatGPTCompleter, GPT3Completer, completer_factory
from .base import (
ChatGPTCompleter,
Completer,
Completion,
GPT3Completer,
completer_factory,
)

__all__ = [
completer_factory,
GPT3Completer,
ChatGPTCompleter,
Completer,
Completion,
]
3 changes: 1 addition & 2 deletions buster/documents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .base import DocumentsManager
from .pickle import DocumentsPickle
from .service import DocumentsService
from .sqlite import DocumentsDB

__all__ = [DocumentsManager, DocumentsPickle, DocumentsDB, DocumentsService]
__all__ = [DocumentsManager, DocumentsDB, DocumentsService]
37 changes: 0 additions & 37 deletions buster/documents/pickle.py

This file was deleted.

4 changes: 2 additions & 2 deletions buster/examples/gradio_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import pandas as pd

from buster.busterbot import Buster
from buster.completers.base import ChatGPTCompleter, Completer
from buster.completers import ChatGPTCompleter, Completer
from buster.formatters.documents import DocumentsFormatter
from buster.formatters.prompts import PromptFormatter
from buster.retriever import Retriever, SQLiteRetriever
from buster.tokenizers import GPTTokenizer
from buster.validators.base import Validator
from buster.validators import Validator

# initialize buster with the config in cfg.py (adapt to your needs) ...
buster_cfg = cfg.buster_cfg
Expand Down
3 changes: 1 addition & 2 deletions buster/retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .base import Retriever
from .pickle import PickleRetriever
from .service import ServiceRetriever
from .sqlite import SQLiteRetriever

__all__ = [Retriever, PickleRetriever, SQLiteRetriever, ServiceRetriever]
__all__ = [Retriever, SQLiteRetriever, ServiceRetriever]
35 changes: 0 additions & 35 deletions buster/retriever/pickle.py

This file was deleted.

23 changes: 17 additions & 6 deletions buster/retriever/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import sqlite3
from pathlib import Path

Expand All @@ -18,15 +19,25 @@ class SQLiteRetriever(Retriever):
>>> df = db.get_documents("source")
"""

def __init__(self, **kwargs):
def __init__(self, db_path: str | Path = None, connection: sqlite3.Connection = None, **kwargs):
super().__init__(**kwargs)
db_path = kwargs["db_path"]
if isinstance(db_path, (str, Path)):

match sum([arg is not None for arg in [db_path, connection]]):
# Check that only db_path or connection get specified
case 0:
raise ValueError("At least one of db_path or connection should be specified")
case 2:
raise ValueError("Only one of db_path and connection should be specified.")

if connection is not None:
self.conn = connection

if db_path is not None:
if not os.path.exists(db_path):
raise FileNotFoundError(f"{db_path=} specified, but file does not exist")
self.db_path = db_path
self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False)
else:
self.db_path = None
self.conn = db_path

schema.setup_db(self.conn)

def __del__(self):
Expand Down
28 changes: 0 additions & 28 deletions buster/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
import os
import urllib.request
from typing import Type

from buster.documents import DocumentsDB, DocumentsManager, DocumentsPickle
from buster.retriever import PickleRetriever, Retriever, SQLiteRetriever

PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".tar.xz", ".tar.bz2"]


def get_file_extension(filepath: str) -> str:
Expand All @@ -22,25 +16,3 @@ def download_db(db_url: str, output_dir: str):
else:
print("File already exists. Skipping.")
return fname


def get_documents_manager_from_extension(filepath: str) -> Type[DocumentsManager]:
ext = get_file_extension(filepath)

if ext in PICKLE_EXTENSIONS:
return DocumentsPickle
elif ext == ".db":
return DocumentsDB
else:
raise ValueError(f"Unsupported format: {ext}.")


def get_retriever_from_extension(filepath: str) -> Type[Retriever]:
ext = get_file_extension(filepath)

if ext in PICKLE_EXTENSIONS:
return PickleRetriever
elif ext == ".db":
return SQLiteRetriever
else:
raise ValueError(f"Unsupported format: {ext}.")
3 changes: 3 additions & 0 deletions buster/validators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base import Validator

__all__ = [Validator]
2 changes: 1 addition & 1 deletion buster/validators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
from openai.embeddings_utils import cosine_similarity, get_embedding

from buster.completers.base import Completion
from buster.completers import Completion

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand Down
77 changes: 48 additions & 29 deletions tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@

import numpy as np
import pandas as pd
import pytest

from buster.busterbot import Buster, BusterConfig
from buster.completers.base import ChatGPTCompleter, Completer, Completion
from buster.completers import ChatGPTCompleter, Completer, Completion
from buster.docparser import generate_embeddings
from buster.documents.sqlite.documents import DocumentsDB
from buster.formatters.documents import DocumentsFormatter
from buster.formatters.prompts import PromptFormatter
from buster.retriever import Retriever
from buster.retriever.pickle import PickleRetriever
from buster.retriever import Retriever, SQLiteRetriever
from buster.tokenizers.gpt import GPTTokenizer
from buster.validators.base import Validator
from buster.validators import Validator

logging.basicConfig(level=logging.INFO)


TEST_DATA_DIR = Path(__file__).resolve().parent / "data"
DB_PATH = os.path.join(str(TEST_DATA_DIR), "document_embeddings_huggingface_subset.tar.gz")
DB_SOURCE = "huggingface"
DOCUMENTS_CSV = Path(__file__).resolve().parent.parent / "buster/examples/stackoverflow.csv"
UNKNOWN_PROMPT = "I'm sorry but I don't know how to answer."

# default class used by our tests
Expand All @@ -38,20 +38,24 @@
"use_reranking": True,
},
retriever_cfg={
"db_path": DB_PATH,
# "db_path": to be set using pytest fixture,
"top_k": 3,
"thresh": 0.7,
"max_tokens": 2000,
"embedding_model": "text-embedding-ada-002",
},
prompt_formatter_cfg={
"max_tokens": 3500,
"text_before_docs": "",
"text_after_docs": (
"""You are a slack chatbot assistant answering technical questions about huggingface transformers, a library to train transformers in python.\n"""
"""Make sure to format your answers in Markdown format, including code block and snippets.\n"""
"""Do not include any links to urls or hyperlinks in your answers.\n\n"""
"""Now answer the following question:\n"""
"text_after_docs": ("""Now answer the following question:\n"""),
"text_before_docs": (
"""You are a chatbot assistant answering technical questions about artificial intelligence (AI). """
"""If you do not know the answer to a question, or if it is completely irrelevant to your domain knowledge of AI library usage, let the user know you cannot answer."""
"""Use this response when you cannot answer:\n"""
f"""'{UNKNOWN_PROMPT}'\n"""
"""For example:\n"""
"""What is the meaning of life?\n"""
f"""'{UNKNOWN_PROMPT}'\n"""
"""Only use these prodived documents as reference:\n"""
),
},
documents_formatter_cfg={
Expand Down Expand Up @@ -117,6 +121,21 @@ def validate(self, completion):
return completion


@pytest.fixture(scope="session")
def database_file(tmp_path_factory):
# Create a temporary directory and file for the database
db_file = tmp_path_factory.mktemp("data").joinpath("documents.db")

# Generate the actual embeddings
documents_manager = DocumentsDB(db_file)
documents = pd.read_csv(DOCUMENTS_CSV)
documents = generate_embeddings(documents, documents_manager)
yield db_file

# Teardown: Remove the temporary database file
db_file.unlink()


def test_chatbot_mock_data(tmp_path, monkeypatch):
gpt_expected_answer = "this is GPT answer"

Expand All @@ -137,10 +156,11 @@ def test_chatbot_mock_data(tmp_path, monkeypatch):
assert completion.text.startswith(gpt_expected_answer)


def test_chatbot_real_data__chatGPT():
def test_chatbot_real_data__chatGPT(database_file):
buster_cfg = copy.deepcopy(buster_cfg_template)
buster_cfg.retriever_cfg["db_path"] = database_file

retriever: Retriever = PickleRetriever(**buster_cfg.retriever_cfg)
retriever: Retriever = SQLiteRetriever(**buster_cfg.retriever_cfg)
tokenizer = GPTTokenizer(**buster_cfg.tokenizer_cfg)
completer: Completer = ChatGPTCompleter(
documents_formatter=DocumentsFormatter(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
Expand All @@ -150,33 +170,32 @@ def test_chatbot_real_data__chatGPT():
validator: Validator = Validator(**buster_cfg.validator_cfg)
buster: Buster = Buster(retriever=retriever, completer=completer, validator=validator)

completion = buster.process_input("What is a transformer?", source=DB_SOURCE)
completion = buster.process_input("What is backpropagation?")
completion = buster.postprocess_completion(completion)
assert isinstance(completion.text, str)

assert completion.answer_relevant == True


def test_chatbot_real_data__chatGPT_OOD():
def test_chatbot_real_data__chatGPT_OOD(database_file):
buster_cfg = copy.deepcopy(buster_cfg_template)
buster_cfg.retriever_cfg["db_path"] = database_file
buster_cfg.prompt_formatter_cfg = {
"max_tokens": 3500,
"text_before_docs": (
"""You are a chatbot assistant answering technical questions about huggingface transformers, a library to train transformers in python. """
"""Make sure to format your answers in Markdown format, including code block and snippets. """
"""Do not include any links to urls or hyperlinks in your answers. """
"""If you do not know the answer to a question, or if it is completely irrelevant to the library usage, let the user know you cannot answer. """
"""You are a chatbot assistant answering technical questions about artificial intelligence (AI)."""
"""If you do not know the answer to a question, or if it is completely irrelevant to your domain knowledge of AI library usage, let the user know you cannot answer."""
"""Use this response: """
f"""'{UNKNOWN_PROMPT}'\n"""
"""For example:\n"""
"""What is the meaning of life for huggingface?\n"""
"""What is the meaning of life?\n"""
f"""'{UNKNOWN_PROMPT}'\n"""
"""Now answer the following question:\n"""
),
"text_after_docs": "Only use these documents as reference:\n",
}

retriever: Retriever = PickleRetriever(**buster_cfg.retriever_cfg)
retriever: Retriever = SQLiteRetriever(**buster_cfg.retriever_cfg)
tokenizer = GPTTokenizer(**buster_cfg.tokenizer_cfg)
completer: Completer = ChatGPTCompleter(
documents_formatter=DocumentsFormatter(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
Expand All @@ -186,24 +205,24 @@ def test_chatbot_real_data__chatGPT_OOD():
validator: Validator = Validator(**buster_cfg.validator_cfg)
buster: Buster = Buster(retriever=retriever, completer=completer, validator=validator)

completion = buster.process_input("What is a good recipe for brocolli soup?", source=DB_SOURCE)
completion = buster.process_input("What is a good recipe for brocolli soup?")
completion = buster.postprocess_completion(completion)
assert isinstance(completion.text, str)

assert completion.answer_relevant == False


def test_chatbot_real_data__no_docs_found():
def test_chatbot_real_data__no_docs_found(database_file):
buster_cfg = copy.deepcopy(buster_cfg_template)
buster_cfg.retriever_cfg = {
"db_path": DB_PATH,
"db_path": database_file,
"embedding_model": "text-embedding-ada-002",
"top_k": 3,
"thresh": 1, # Set threshold very high to be sure no docs are matched
"max_tokens": 3000,
}
buster_cfg.completion_cfg["no_documents_message"] = "No documents available."
retriever: Retriever = PickleRetriever(**buster_cfg.retriever_cfg)
retriever: Retriever = SQLiteRetriever(**buster_cfg.retriever_cfg)
tokenizer = GPTTokenizer(**buster_cfg.tokenizer_cfg)
completer: Completer = ChatGPTCompleter(
documents_formatter=DocumentsFormatter(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
Expand All @@ -213,7 +232,7 @@ def test_chatbot_real_data__no_docs_found():
validator: Validator = Validator(**buster_cfg.validator_cfg)
buster: Buster = Buster(retriever=retriever, completer=completer, validator=validator)

completion = buster.process_input("What is a transformer?", source=DB_SOURCE)
completion = buster.process_input("What is backpropagation?")
completion = buster.postprocess_completion(completion)
assert isinstance(completion.text, str)

Expand Down
Loading

0 comments on commit 59ba24f

Please sign in to comment.