From 59ba24f718ecbd387fac309528472ac63be04bff Mon Sep 17 00:00:00 2001 From: Jeremy Pinto Date: Thu, 1 Jun 2023 14:53:58 -0400 Subject: [PATCH] Remove pickle retriever (#103) * remove PickleDB support * update tests * remove unused functions from utils * allow passing a connection directly to sqlite * remove .base when importing from completers and validators --- buster/busterbot.py | 4 +- buster/completers/__init__.py | 10 ++++- buster/documents/__init__.py | 3 +- buster/documents/pickle.py | 37 ----------------- buster/examples/gradio_app.py | 4 +- buster/retriever/__init__.py | 3 +- buster/retriever/pickle.py | 35 ---------------- buster/retriever/sqlite.py | 23 ++++++++--- buster/utils.py | 28 ------------- buster/validators/__init__.py | 3 ++ buster/validators/base.py | 2 +- tests/test_chatbot.py | 77 ++++++++++++++++++++++------------- tests/test_docparser.py | 12 +++--- tests/test_documents.py | 8 ++-- tests/test_read_write.py | 2 +- tests/test_validator.py | 3 +- 16 files changed, 95 insertions(+), 159 deletions(-) delete mode 100644 buster/documents/pickle.py delete mode 100644 buster/retriever/pickle.py create mode 100644 buster/validators/__init__.py diff --git a/buster/busterbot.py b/buster/busterbot.py index 61742cd..90a80cd 100644 --- a/buster/busterbot.py +++ b/buster/busterbot.py @@ -2,9 +2,9 @@ from dataclasses import dataclass, field from typing import Any -from buster.completers.base import Completer, Completion +from buster.completers import Completer, Completion from buster.retriever import Retriever -from buster.validators.base import Validator +from buster.validators import Validator logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/buster/completers/__init__.py b/buster/completers/__init__.py index 8e39bcc..90865c5 100644 --- a/buster/completers/__init__.py +++ b/buster/completers/__init__.py @@ -1,7 +1,15 @@ -from .base import ChatGPTCompleter, GPT3Completer, completer_factory +from .base import ( + ChatGPTCompleter, + Completer, + Completion, + GPT3Completer, + completer_factory, +) __all__ = [ completer_factory, GPT3Completer, ChatGPTCompleter, + Completer, + Completion, ] diff --git a/buster/documents/__init__.py b/buster/documents/__init__.py index 2f3293c..ba32a13 100644 --- a/buster/documents/__init__.py +++ b/buster/documents/__init__.py @@ -1,6 +1,5 @@ from .base import DocumentsManager -from .pickle import DocumentsPickle from .service import DocumentsService from .sqlite import DocumentsDB -__all__ = [DocumentsManager, DocumentsPickle, DocumentsDB, DocumentsService] +__all__ = [DocumentsManager, DocumentsDB, DocumentsService] diff --git a/buster/documents/pickle.py b/buster/documents/pickle.py deleted file mode 100644 index 2aafa81..0000000 --- a/buster/documents/pickle.py +++ /dev/null @@ -1,37 +0,0 @@ -import os - -import pandas as pd - -from buster.documents.base import DocumentsManager - - -class DocumentsPickle(DocumentsManager): - def __init__(self, filepath: str): - self.filepath = filepath - - if os.path.exists(filepath): - self.documents = pd.read_pickle(filepath) - else: - self.documents = None - - def __repr__(self): - return "DocumentsPickle" - - def add(self, source: str, df: pd.DataFrame): - """Write all documents from the dataframe into the db as a new version.""" - if source is not None: - df["source"] = source - - df["current"] = 1 - - if self.documents is not None: - self.documents.loc[self.documents.source == source, "current"] = 0 - self.documents = pd.concat([self.documents, df]) - else: - self.documents = df - - self.documents.to_pickle(self.filepath) - - def update_source(self, source: str, display_name: str = None, note: str = None): - """Update the display name and/or note of a source. Also create the source if it does not exist.""" - print("If you need this function, please switch your backend to DocumentsDB.") diff --git a/buster/examples/gradio_app.py b/buster/examples/gradio_app.py index a72ad8f..4538673 100644 --- a/buster/examples/gradio_app.py +++ b/buster/examples/gradio_app.py @@ -3,12 +3,12 @@ import pandas as pd from buster.busterbot import Buster -from buster.completers.base import ChatGPTCompleter, Completer +from buster.completers import ChatGPTCompleter, Completer from buster.formatters.documents import DocumentsFormatter from buster.formatters.prompts import PromptFormatter from buster.retriever import Retriever, SQLiteRetriever from buster.tokenizers import GPTTokenizer -from buster.validators.base import Validator +from buster.validators import Validator # initialize buster with the config in cfg.py (adapt to your needs) ... buster_cfg = cfg.buster_cfg diff --git a/buster/retriever/__init__.py b/buster/retriever/__init__.py index 5ec1048..0484c67 100644 --- a/buster/retriever/__init__.py +++ b/buster/retriever/__init__.py @@ -1,6 +1,5 @@ from .base import Retriever -from .pickle import PickleRetriever from .service import ServiceRetriever from .sqlite import SQLiteRetriever -__all__ = [Retriever, PickleRetriever, SQLiteRetriever, ServiceRetriever] +__all__ = [Retriever, SQLiteRetriever, ServiceRetriever] diff --git a/buster/retriever/pickle.py b/buster/retriever/pickle.py deleted file mode 100644 index eeccbe9..0000000 --- a/buster/retriever/pickle.py +++ /dev/null @@ -1,35 +0,0 @@ -import pandas as pd - -from buster.retriever.base import ALL_SOURCES, Retriever - - -class PickleRetriever(Retriever): - def __init__(self, db_path: str, **kwargs): - super().__init__(**kwargs) - self.db_path = db_path - self.documents = pd.read_pickle(db_path) - - def get_documents(self, source: str) -> pd.DataFrame: - """Get all current documents from a given source.""" - if self.documents is None: - raise FileNotFoundError(f"No documents found at {self.db_path}. Are you sure this is the correct path?") - - documents = self.documents.copy() - # The `current` column exists when multiple versions of a document exist - if "current" in documents.columns: - documents = documents[documents.current == 1] - - # Drop the `current` column - documents.drop(columns=["current"], inplace=True) - - if source not in [None, ""] and "source" in documents.columns: - documents = documents[documents.source == source] - - return documents - - def get_source_display_name(self, source: str) -> str: - """Get the display name of a source.""" - if source is None: - return ALL_SOURCES - else: - return source diff --git a/buster/retriever/sqlite.py b/buster/retriever/sqlite.py index fd077c6..fbcec3c 100644 --- a/buster/retriever/sqlite.py +++ b/buster/retriever/sqlite.py @@ -1,3 +1,4 @@ +import os import sqlite3 from pathlib import Path @@ -18,15 +19,25 @@ class SQLiteRetriever(Retriever): >>> df = db.get_documents("source") """ - def __init__(self, **kwargs): + def __init__(self, db_path: str | Path = None, connection: sqlite3.Connection = None, **kwargs): super().__init__(**kwargs) - db_path = kwargs["db_path"] - if isinstance(db_path, (str, Path)): + + match sum([arg is not None for arg in [db_path, connection]]): + # Check that only db_path or connection get specified + case 0: + raise ValueError("At least one of db_path or connection should be specified") + case 2: + raise ValueError("Only one of db_path and connection should be specified.") + + if connection is not None: + self.conn = connection + + if db_path is not None: + if not os.path.exists(db_path): + raise FileNotFoundError(f"{db_path=} specified, but file does not exist") self.db_path = db_path self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False) - else: - self.db_path = None - self.conn = db_path + schema.setup_db(self.conn) def __del__(self): diff --git a/buster/utils.py b/buster/utils.py index 9376cce..85b3aae 100644 --- a/buster/utils.py +++ b/buster/utils.py @@ -1,11 +1,5 @@ import os import urllib.request -from typing import Type - -from buster.documents import DocumentsDB, DocumentsManager, DocumentsPickle -from buster.retriever import PickleRetriever, Retriever, SQLiteRetriever - -PICKLE_EXTENSIONS = [".gz", ".bz2", ".zip", ".xz", ".zst", ".tar", ".tar.gz", ".tar.xz", ".tar.bz2"] def get_file_extension(filepath: str) -> str: @@ -22,25 +16,3 @@ def download_db(db_url: str, output_dir: str): else: print("File already exists. Skipping.") return fname - - -def get_documents_manager_from_extension(filepath: str) -> Type[DocumentsManager]: - ext = get_file_extension(filepath) - - if ext in PICKLE_EXTENSIONS: - return DocumentsPickle - elif ext == ".db": - return DocumentsDB - else: - raise ValueError(f"Unsupported format: {ext}.") - - -def get_retriever_from_extension(filepath: str) -> Type[Retriever]: - ext = get_file_extension(filepath) - - if ext in PICKLE_EXTENSIONS: - return PickleRetriever - elif ext == ".db": - return SQLiteRetriever - else: - raise ValueError(f"Unsupported format: {ext}.") diff --git a/buster/validators/__init__.py b/buster/validators/__init__.py new file mode 100644 index 0000000..808e714 --- /dev/null +++ b/buster/validators/__init__.py @@ -0,0 +1,3 @@ +from .base import Validator + +__all__ = [Validator] diff --git a/buster/validators/base.py b/buster/validators/base.py index 3e073a5..826c37c 100644 --- a/buster/validators/base.py +++ b/buster/validators/base.py @@ -4,7 +4,7 @@ import pandas as pd from openai.embeddings_utils import cosine_similarity, get_embedding -from buster.completers.base import Completion +from buster.completers import Completion logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index 17752e2..15c1c74 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -5,22 +5,22 @@ import numpy as np import pandas as pd +import pytest from buster.busterbot import Buster, BusterConfig -from buster.completers.base import ChatGPTCompleter, Completer, Completion +from buster.completers import ChatGPTCompleter, Completer, Completion +from buster.docparser import generate_embeddings +from buster.documents.sqlite.documents import DocumentsDB from buster.formatters.documents import DocumentsFormatter from buster.formatters.prompts import PromptFormatter -from buster.retriever import Retriever -from buster.retriever.pickle import PickleRetriever +from buster.retriever import Retriever, SQLiteRetriever from buster.tokenizers.gpt import GPTTokenizer -from buster.validators.base import Validator +from buster.validators import Validator logging.basicConfig(level=logging.INFO) -TEST_DATA_DIR = Path(__file__).resolve().parent / "data" -DB_PATH = os.path.join(str(TEST_DATA_DIR), "document_embeddings_huggingface_subset.tar.gz") -DB_SOURCE = "huggingface" +DOCUMENTS_CSV = Path(__file__).resolve().parent.parent / "buster/examples/stackoverflow.csv" UNKNOWN_PROMPT = "I'm sorry but I don't know how to answer." # default class used by our tests @@ -38,7 +38,7 @@ "use_reranking": True, }, retriever_cfg={ - "db_path": DB_PATH, + # "db_path": to be set using pytest fixture, "top_k": 3, "thresh": 0.7, "max_tokens": 2000, @@ -46,12 +46,16 @@ }, prompt_formatter_cfg={ "max_tokens": 3500, - "text_before_docs": "", - "text_after_docs": ( - """You are a slack chatbot assistant answering technical questions about huggingface transformers, a library to train transformers in python.\n""" - """Make sure to format your answers in Markdown format, including code block and snippets.\n""" - """Do not include any links to urls or hyperlinks in your answers.\n\n""" - """Now answer the following question:\n""" + "text_after_docs": ("""Now answer the following question:\n"""), + "text_before_docs": ( + """You are a chatbot assistant answering technical questions about artificial intelligence (AI). """ + """If you do not know the answer to a question, or if it is completely irrelevant to your domain knowledge of AI library usage, let the user know you cannot answer.""" + """Use this response when you cannot answer:\n""" + f"""'{UNKNOWN_PROMPT}'\n""" + """For example:\n""" + """What is the meaning of life?\n""" + f"""'{UNKNOWN_PROMPT}'\n""" + """Only use these prodived documents as reference:\n""" ), }, documents_formatter_cfg={ @@ -117,6 +121,21 @@ def validate(self, completion): return completion +@pytest.fixture(scope="session") +def database_file(tmp_path_factory): + # Create a temporary directory and file for the database + db_file = tmp_path_factory.mktemp("data").joinpath("documents.db") + + # Generate the actual embeddings + documents_manager = DocumentsDB(db_file) + documents = pd.read_csv(DOCUMENTS_CSV) + documents = generate_embeddings(documents, documents_manager) + yield db_file + + # Teardown: Remove the temporary database file + db_file.unlink() + + def test_chatbot_mock_data(tmp_path, monkeypatch): gpt_expected_answer = "this is GPT answer" @@ -137,10 +156,11 @@ def test_chatbot_mock_data(tmp_path, monkeypatch): assert completion.text.startswith(gpt_expected_answer) -def test_chatbot_real_data__chatGPT(): +def test_chatbot_real_data__chatGPT(database_file): buster_cfg = copy.deepcopy(buster_cfg_template) + buster_cfg.retriever_cfg["db_path"] = database_file - retriever: Retriever = PickleRetriever(**buster_cfg.retriever_cfg) + retriever: Retriever = SQLiteRetriever(**buster_cfg.retriever_cfg) tokenizer = GPTTokenizer(**buster_cfg.tokenizer_cfg) completer: Completer = ChatGPTCompleter( documents_formatter=DocumentsFormatter(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg), @@ -150,33 +170,32 @@ def test_chatbot_real_data__chatGPT(): validator: Validator = Validator(**buster_cfg.validator_cfg) buster: Buster = Buster(retriever=retriever, completer=completer, validator=validator) - completion = buster.process_input("What is a transformer?", source=DB_SOURCE) + completion = buster.process_input("What is backpropagation?") completion = buster.postprocess_completion(completion) assert isinstance(completion.text, str) assert completion.answer_relevant == True -def test_chatbot_real_data__chatGPT_OOD(): +def test_chatbot_real_data__chatGPT_OOD(database_file): buster_cfg = copy.deepcopy(buster_cfg_template) + buster_cfg.retriever_cfg["db_path"] = database_file buster_cfg.prompt_formatter_cfg = { "max_tokens": 3500, "text_before_docs": ( - """You are a chatbot assistant answering technical questions about huggingface transformers, a library to train transformers in python. """ - """Make sure to format your answers in Markdown format, including code block and snippets. """ - """Do not include any links to urls or hyperlinks in your answers. """ - """If you do not know the answer to a question, or if it is completely irrelevant to the library usage, let the user know you cannot answer. """ + """You are a chatbot assistant answering technical questions about artificial intelligence (AI).""" + """If you do not know the answer to a question, or if it is completely irrelevant to your domain knowledge of AI library usage, let the user know you cannot answer.""" """Use this response: """ f"""'{UNKNOWN_PROMPT}'\n""" """For example:\n""" - """What is the meaning of life for huggingface?\n""" + """What is the meaning of life?\n""" f"""'{UNKNOWN_PROMPT}'\n""" """Now answer the following question:\n""" ), "text_after_docs": "Only use these documents as reference:\n", } - retriever: Retriever = PickleRetriever(**buster_cfg.retriever_cfg) + retriever: Retriever = SQLiteRetriever(**buster_cfg.retriever_cfg) tokenizer = GPTTokenizer(**buster_cfg.tokenizer_cfg) completer: Completer = ChatGPTCompleter( documents_formatter=DocumentsFormatter(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg), @@ -186,24 +205,24 @@ def test_chatbot_real_data__chatGPT_OOD(): validator: Validator = Validator(**buster_cfg.validator_cfg) buster: Buster = Buster(retriever=retriever, completer=completer, validator=validator) - completion = buster.process_input("What is a good recipe for brocolli soup?", source=DB_SOURCE) + completion = buster.process_input("What is a good recipe for brocolli soup?") completion = buster.postprocess_completion(completion) assert isinstance(completion.text, str) assert completion.answer_relevant == False -def test_chatbot_real_data__no_docs_found(): +def test_chatbot_real_data__no_docs_found(database_file): buster_cfg = copy.deepcopy(buster_cfg_template) buster_cfg.retriever_cfg = { - "db_path": DB_PATH, + "db_path": database_file, "embedding_model": "text-embedding-ada-002", "top_k": 3, "thresh": 1, # Set threshold very high to be sure no docs are matched "max_tokens": 3000, } buster_cfg.completion_cfg["no_documents_message"] = "No documents available." - retriever: Retriever = PickleRetriever(**buster_cfg.retriever_cfg) + retriever: Retriever = SQLiteRetriever(**buster_cfg.retriever_cfg) tokenizer = GPTTokenizer(**buster_cfg.tokenizer_cfg) completer: Completer = ChatGPTCompleter( documents_formatter=DocumentsFormatter(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg), @@ -213,7 +232,7 @@ def test_chatbot_real_data__no_docs_found(): validator: Validator = Validator(**buster_cfg.validator_cfg) buster: Buster = Buster(retriever=retriever, completer=completer, validator=validator) - completion = buster.process_input("What is a transformer?", source=DB_SOURCE) + completion = buster.process_input("What is backpropagation?") completion = buster.postprocess_completion(completion) assert isinstance(completion.text, str) diff --git a/tests/test_docparser.py b/tests/test_docparser.py index 16a417a..380e72b 100644 --- a/tests/test_docparser.py +++ b/tests/test_docparser.py @@ -3,13 +3,11 @@ import pytest from buster.docparser import generate_embeddings -from buster.utils import ( - get_documents_manager_from_extension, - get_retriever_from_extension, -) +from buster.documents import DocumentsDB +from buster.retriever.sqlite import SQLiteRetriever -@pytest.mark.parametrize("extension", ["db", "tar.gz"]) +@pytest.mark.parametrize("extension", ["db"]) def test_generate_embeddings(tmp_path, monkeypatch, extension): # Create fake data data = pd.DataFrame.from_dict( @@ -22,7 +20,7 @@ def test_generate_embeddings(tmp_path, monkeypatch, extension): # Generate embeddings, store in a file output_file = tmp_path / f"test_document_embeddings.{extension}" - manager = get_documents_manager_from_extension(output_file)(output_file) + manager = DocumentsDB(output_file) df = generate_embeddings(data, manager) # Read the embeddings from the file @@ -34,7 +32,7 @@ def test_generate_embeddings(tmp_path, monkeypatch, extension): "max_tokens": 3000, "embedding_model": "text-embedding-ada-002", } - read_df = get_retriever_from_extension(output_file)(**retriever_cfg).get_documents("my_source") + read_df = SQLiteRetriever(**retriever_cfg).get_documents("my_source") # Check all the values are correct across the files assert df["title"].iloc[0] == data["title"].iloc[0] == read_df["title"].iloc[0] diff --git a/tests/test_documents.py b/tests/test_documents.py index 4159760..cbd447d 100644 --- a/tests/test_documents.py +++ b/tests/test_documents.py @@ -2,13 +2,13 @@ import pandas as pd import pytest -from buster.documents import DocumentsDB, DocumentsPickle -from buster.retriever import PickleRetriever, SQLiteRetriever +from buster.documents import DocumentsDB +from buster.retriever import SQLiteRetriever @pytest.mark.parametrize( "documents_manager, retriever, extension", - [(DocumentsDB, SQLiteRetriever, "db"), (DocumentsPickle, PickleRetriever, "tar.gz")], + [(DocumentsDB, SQLiteRetriever, "db")], ) def test_write_read(tmp_path, documents_manager, retriever, extension): db_path = tmp_path / f"test.{extension}" @@ -43,7 +43,7 @@ def test_write_read(tmp_path, documents_manager, retriever, extension): @pytest.mark.parametrize( "documents_manager, retriever, extension", - [(DocumentsDB, SQLiteRetriever, "db"), (DocumentsPickle, PickleRetriever, "tar.gz")], + [(DocumentsDB, SQLiteRetriever, "db")], ) def test_write_write_read(tmp_path, documents_manager, retriever, extension): db_path = tmp_path / f"test.{extension}" diff --git a/tests/test_read_write.py b/tests/test_read_write.py index c2297b2..2e1a3ab 100644 --- a/tests/test_read_write.py +++ b/tests/test_read_write.py @@ -1,6 +1,6 @@ import pandas as pd -from buster.completers.base import Completion +from buster.completers import Completion class MockValidator: diff --git a/tests/test_validator.py b/tests/test_validator.py index e0439bf..2eae2ab 100644 --- a/tests/test_validator.py +++ b/tests/test_validator.py @@ -1,8 +1,7 @@ import pandas as pd from openai.embeddings_utils import get_embedding -from buster.completers.base import Completion -from buster.validators.base import Validator +from buster.validators import Validator def test_validator_check_answer_relevance():