PR: adding service db (#85)
* working services

* black

* tests

* CR
hbertrand authored Apr 19, 2023
1 parent 7c2b9bd commit 17e3538
Showing 9 changed files with 158 additions and 13 deletions.
17 changes: 9 additions & 8 deletions buster/docparser.py
@@ -11,8 +11,8 @@
 from bs4 import BeautifulSoup
 from openai.embeddings_utils import get_embedding

+from buster.documents import DocumentsManager
 from buster.parser import HuggingfaceParser, Parser, SphinxParser
-from buster.utils import get_documents_manager_from_extension

 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
@@ -148,23 +148,24 @@ def generate_embeddings_parser(root_dir: str, output_filepath: str, source: str)
     return generate_embeddings(documents, output_filepath)


-def documents_to_db(documents: pd.DataFrame, output_filepath: str):
+def documents_to_db(
+    documents: pd.DataFrame,
+    documents_manager: DocumentsManager,
+):
     logger.info("Preparing database...")
-    documents_manager = get_documents_manager_from_extension(output_filepath)(output_filepath)
     sources = documents["source"].unique()
     for source in sources:
         documents_manager.add(source, documents)
-    logger.info(f"Documents saved to: {output_filepath}")
+    logger.info(f"Documents saved to documents manager: {documents_manager}")


-def update_source(source: str, output_filepath: str, display_name: str = None, note: str = None):
-    documents_manager = get_documents_manager_from_extension(output_filepath)(output_filepath)
+def update_source(source: str, documents_manager: DocumentsManager, display_name: str = None, note: str = None):
     documents_manager.update_source(source, display_name, note)


 def generate_embeddings(
     documents: pd.DataFrame,
-    output_filepath: str = "documents.db",
+    documents_manager: DocumentsManager,
     max_words=500,
     embedding_engine: str = EMBEDDING_MODEL,
 ) -> pd.DataFrame:
@@ -180,7 +181,7 @@ def generate_embeddings(
     documents = compute_embeddings(documents, engine=embedding_engine)

     # save the documents to a db for later use
-    documents_to_db(documents, output_filepath)
+    documents_to_db(documents, documents_manager)

     return documents

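With this refactor, callers construct a DocumentsManager themselves and hand it to generate_embeddings, instead of passing a raw output path. A minimal sketch of the new call pattern, mirroring the updated test further down; the input file name here is a placeholder:

    import pandas as pd

    from buster.docparser import generate_embeddings
    from buster.utils import get_documents_manager_from_extension

    # Placeholder input: a dataframe of parsed documents with a "source" column.
    documents = pd.read_csv("documents.csv")

    # The output extension selects the backend (e.g. SQLite for ".db");
    # the factory returns a manager class, which is then called with the path.
    manager = get_documents_manager_from_extension("documents.db")("documents.db")
    df = generate_embeddings(documents, manager)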
3 changes: 2 additions & 1 deletion buster/documents/__init__.py
@@ -1,5 +1,6 @@
 from .base import DocumentsManager
 from .pickle import DocumentsPickle
+from .service import DocumentsService
 from .sqlite import DocumentsDB

-__all__ = [DocumentsManager, DocumentsPickle, DocumentsDB]
+__all__ = [DocumentsManager, DocumentsPickle, DocumentsDB, DocumentsService]
3 changes: 3 additions & 0 deletions buster/documents/pickle.py
@@ -14,6 +14,9 @@ def __init__(self, filepath: str):
         else:
             self.documents = None

+    def __repr__(self):
+        return "DocumentsPickle"
+
     def add(self, source: str, df: pd.DataFrame):
         """Write all documents from the dataframe into the db as a new version."""
         if source is not None:
75 changes: 75 additions & 0 deletions buster/documents/service.py
@@ -0,0 +1,75 @@
+import os
+
+import pandas as pd
+import pinecone
+from pymongo.mongo_client import MongoClient
+from pymongo.server_api import ServerApi
+
+from buster.documents.base import DocumentsManager
+
+
+class DocumentsService(DocumentsManager):
+    """Manager to use in production. Mixed Pinecone and MongoDB backend."""
+
+    def __init__(
+        self,
+        pinecone_api_key: str,
+        pinecone_env: str,
+        pinecone_index: str,
+        mongo_uri: str,
+        mongo_db_name: str,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        pinecone.init(api_key=pinecone_api_key, environment=pinecone_env)
+
+        self.index = pinecone.Index(pinecone_index)
+
+        self.client = MongoClient(mongo_uri, server_api=ServerApi("1"))
+        self.db = self.client[mongo_db_name]
+
+    def __repr__(self):
+        return "DocumentsService"
+
+    def get_source_id(self, source: str) -> str:
+        """Get the id of a source."""
+        return str(self.db.sources.find_one({"name": source})["_id"])
+
+    def add(self, source: str, df: pd.DataFrame):
+        """Write all documents from the dataframe into the db as a new version."""
+        source_exists = self.db.sources.find_one({"name": source})
+        if source_exists is None:
+            self.db.sources.insert_one({"name": source})
+
+        source_id = self.get_source_id(source)
+
+        for _, row in df.iterrows():
+            document = {
+                "title": row["title"],
+                "url": row["url"],
+                "content": row["content"],
+                "n_tokens": row["n_tokens"],
+                "source_id": source_id,
+            }
+            document_id = str(self.db.documents.insert_one(document).inserted_id)
+            self.index.upsert([(document_id, row["embedding"].tolist(), {"source": source})])
+
+    def update_source(self, source: str, display_name: str = None, note: str = None):
+        """Update the display name and/or note of a source. Also create the source if it does not exist."""
+        self.db.sources.update_one(
+            {"name": source}, {"$set": {"display_name": display_name, "note": note}}, upsert=True
+        )
+
+    def delete_source(self, source: str) -> tuple[int, int]:
+        """Delete a source and all its documents. Return if the source was deleted and the number of deleted documents."""
+        source_id = self.get_source_id(source)
+
+        # MongoDB
+        source_deleted = self.db.sources.delete_one({"name": source}).deleted_count
+        documents_deleted = self.db.documents.delete_many({"source_id": source_id}).deleted_count
+
+        # Pinecone
+        self.index.delete(filter={"source": source})
+
+        return source_deleted, documents_deleted
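For context, a rough usage sketch of the new production backend; the environment variable names, Pinecone environment/index, and database name below are placeholders, not values pinned by this diff:

    import os

    import numpy as np
    import pandas as pd

    from buster.documents import DocumentsService

    manager = DocumentsService(
        pinecone_api_key=os.environ["PINECONE_API_KEY"],  # placeholder credentials
        pinecone_env="us-east1-gcp",
        pinecone_index="buster",
        mongo_uri=os.environ["MONGO_URI"],
        mongo_db_name="buster",
    )

    # One document row; 1536 dims assumes the ada-002 embedding model.
    df = pd.DataFrame(
        [
            {
                "title": "Intro",
                "url": "https://example.com/intro",
                "content": "Some documentation text.",
                "n_tokens": 4,
                "embedding": np.zeros(1536),
            }
        ]
    )

    # Text and metadata land in MongoDB; each embedding is upserted to Pinecone
    # under the Mongo document id, tagged with its source for filtered retrieval.
    manager.add("my_source", df)

The split keeps documents and metadata queryable in MongoDB while Pinecone holds only the vectors, joined back through the Mongo _id.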
5 changes: 4 additions & 1 deletion buster/documents/sqlite/documents.py
@@ -38,7 +38,7 @@ class DocumentsDB(DocumentsManager):
     def __init__(self, db_path: sqlite3.Connection | str):
         if isinstance(db_path, (str, Path)):
             self.db_path = db_path
-            self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False)
+            self.conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES)
         else:
             self.db_path = None
             self.conn = db_path
@@ -49,6 +49,9 @@ def __del__(self):
         if self.db_path is not None:
             self.conn.close()

+    def __repr__(self):
+        return f"DocumentsDB({self.db_path})"
+
     def get_current_version(self, source: str) -> tuple[int, int]:
         """Get the current version of a source."""
         cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
3 changes: 2 additions & 1 deletion buster/retriever/__init__.py
@@ -1,5 +1,6 @@
 from .base import Retriever
 from .pickle import PickleRetriever
+from .service import ServiceRetriever
 from .sqlite import SQLiteRetriever

-__all__ = [Retriever, PickleRetriever, SQLiteRetriever]
+__all__ = [Retriever, PickleRetriever, SQLiteRetriever, ServiceRetriever]
55 changes: 55 additions & 0 deletions buster/retriever/service.py
@@ -0,0 +1,55 @@
+import pandas as pd
+import pinecone
+from bson.objectid import ObjectId
+from pymongo.mongo_client import MongoClient
+from pymongo.server_api import ServerApi
+
+from buster.retriever.base import ALL_SOURCES, Retriever
+
+
+class ServiceRetriever(Retriever):
+    def __init__(
+        self,
+        pinecone_api_key: str,
+        pinecone_env: str,
+        pinecone_index: str,
+        mongo_uri: str,
+        mongo_db_name: str,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        pinecone.init(api_key=pinecone_api_key, environment=pinecone_env)
+
+        self.index = pinecone.Index(pinecone_index)
+
+        self.client = MongoClient(mongo_uri, server_api=ServerApi("1"))
+        self.db = self.client[mongo_db_name]
+
+    def get_documents(self, source: str) -> pd.DataFrame:
+        """Get all current documents from a given source."""
+        return self.db.documents.find({"source_id": source})
+
+    def get_source_display_name(self, source: str) -> str:
+        """Get the display name of a source."""
+        if source == "":
+            return ALL_SOURCES
+        else:
+            display_name = self.db.sources.find_one({"name": source})["display_name"]
+            return display_name
+
+    def retrieve(self, query_embedding: list[float], top_k: int, source: str = None) -> pd.DataFrame:
+        # Pinecone retrieval
+        matches = self.index.query(query_embedding, top_k=top_k, filter={"source": {"$eq": source}})["matches"]
+        matching_ids = [ObjectId(match.id) for match in matches]
+        matching_scores = {match.id: match.score for match in matches}
+
+        if len(matching_ids) == 0:
+            return pd.DataFrame()
+
+        # MongoDB retrieval
+        matched_documents = self.db.documents.find({"_id": {"$in": matching_ids}})
+        matched_documents = pd.DataFrame(list(matched_documents))
+        matched_documents["similarity"] = matched_documents["_id"].apply(lambda x: matching_scores[str(x)])
+
+        return matched_documents
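The matching read path, again with placeholder credentials; the embedding engine name is an assumption (this diff does not pin a model):

    import os

    from openai.embeddings_utils import get_embedding

    from buster.retriever import ServiceRetriever

    retriever = ServiceRetriever(
        pinecone_api_key=os.environ["PINECONE_API_KEY"],  # placeholder credentials
        pinecone_env="us-east1-gcp",
        pinecone_index="buster",
        mongo_uri=os.environ["MONGO_URI"],
        mongo_db_name="buster",
    )

    # Embed the query, let Pinecone rank document ids by similarity,
    # then pull the matching rows from MongoDB with scores attached.
    query_embedding = get_embedding("How do I add a new source?", engine="text-embedding-ada-002")
    matched_documents = retriever.retrieve(query_embedding, top_k=3, source="my_source")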
2 changes: 2 additions & 0 deletions requirements.txt
@@ -9,6 +9,8 @@ promptlayer
 pytest
 openai
 click
+pinecone-client
+pymongo

 # all openai[embeddings] deps, their list breaks our CI, see: https://github.com/openai/openai-python/issues/210

8 changes: 6 additions & 2 deletions tests/test_docparser.py
@@ -3,7 +3,10 @@
 import pytest

 from buster.docparser import generate_embeddings
-from buster.utils import get_retriever_from_extension
+from buster.utils import (
+    get_documents_manager_from_extension,
+    get_retriever_from_extension,
+)


 @pytest.mark.parametrize("extension", ["db", "tar.gz"])
@@ -19,7 +22,8 @@ def test_generate_embeddings(tmp_path, monkeypatch, extension):

     # Generate embeddings, store in a file
     output_file = tmp_path / f"test_document_embeddings.{extension}"
-    df = generate_embeddings(data, output_file)
+    manager = get_documents_manager_from_extension(output_file)(output_file)
+    df = generate_embeddings(data, manager)

     # Read the embeddings from the file
     read_df = get_retriever_from_extension(output_file)(output_file).get_documents("my_source")