Skip to content

Commit

Permalink
feat:chroma store refactor (eosphoros-ai#1508)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aries-ckt authored May 11, 2024
1 parent bc9ce3c commit d313155
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 40 deletions.
2 changes: 1 addition & 1 deletion dbgpt/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "0.5.5"
version = "0.5.6"
12 changes: 7 additions & 5 deletions dbgpt/app/knowledge/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,13 @@ def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest)
doc_ids = sync_request.doc_ids
self.model_name = sync_request.model_name or CFG.LLM_MODEL
for doc_id in doc_ids:
query = KnowledgeDocumentEntity(
id=doc_id,
space=space_name,
)
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
query = KnowledgeDocumentEntity(id=doc_id)
docs = knowledge_document_dao.get_documents(query)
if len(docs) == 0:
raise Exception(
f"there are document called, doc_id: {sync_request.doc_id}"
)
doc = docs[0]
if (
doc.status == SyncStatus.RUNNING.name
or doc.status == SyncStatus.FINISHED.name
Expand Down
30 changes: 30 additions & 0 deletions dbgpt/storage/vector_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,36 @@ def load_document_with_limit(
)
return ids

def filter_by_score_threshold(
self, chunks: List[Chunk], score_threshold: float
) -> List[Chunk]:
"""Filter chunks by score threshold.
Args:
chunks(List[Chunks]): The chunks to filter.
score_threshold(float): The score threshold.
Return:
List[Chunks]: The filtered chunks.
"""
candidates_chunks = chunks
if score_threshold is not None:
candidates_chunks = [
Chunk(
metadata=chunk.metadata,
content=chunk.content,
score=chunk.score,
chunk_id=str(id),
)
for chunk in chunks
if chunk.score >= score_threshold
]
if len(candidates_chunks) == 0:
logger.warning(
"No relevant docs were retrieved using the relevance score"
f" threshold {score_threshold}"
)
return candidates_chunks

@abstractmethod
def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
Expand Down
150 changes: 119 additions & 31 deletions dbgpt/storage/vector_store/chroma_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Chroma vector store."""
import logging
import os
from typing import List, Optional
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union

from chromadb import PersistentClient
from chromadb.config import Settings
Expand All @@ -17,6 +17,7 @@

logger = logging.getLogger(__name__)

CHROMA_COLLECTION_NAME = "langchain"

@register_resource(
_("Chroma Vector Store"),
Expand Down Expand Up @@ -55,9 +56,11 @@ class ChromaStore(VectorStoreBase):
"""Chroma vector store."""

def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
"""Create a ChromaStore instance."""
from langchain.vectorstores import Chroma
"""Create a ChromaStore instance.
Args:
vector_store_config(ChromaVectorConfig): vector store config.
"""
chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
chroma_path = chroma_vector_config.get(
"persist_path", os.path.join(PILOT_PATH, "data")
Expand All @@ -71,31 +74,35 @@ def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
persist_directory=self.persist_dir,
anonymized_telemetry=False,
)
client = PersistentClient(path=self.persist_dir, settings=chroma_settings)
self._chroma_client = PersistentClient(
path=self.persist_dir, settings=chroma_settings
)

collection_metadata = chroma_vector_config.get("collection_metadata") or {
"hnsw:space": "cosine"
}
self.vector_store_client = Chroma(
persist_directory=self.persist_dir,
embedding_function=self.embeddings,
# client_settings=chroma_settings,
client=client,
collection_metadata=collection_metadata,
) # type: ignore
self._collection = self._chroma_client.get_or_create_collection(
name=CHROMA_COLLECTION_NAME,
embedding_function=None,
metadata=collection_metadata,
)

def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Search similar documents."""
logger.info("ChromaStore similar search")
where_filters = self.convert_metadata_filters(filters) if filters else None
lc_documents = self.vector_store_client.similarity_search(
text, topk, filter=where_filters
chroma_results = self._query(
text=text,
topk=topk,
filters=filters,
)
return [
Chunk(content=doc.page_content, metadata=doc.metadata)
for doc in lc_documents
Chunk(content=chroma_result[0], metadata=chroma_result[1] or {}, score=0.0)
for chroma_result in zip(
chroma_results["documents"][0],
chroma_results["metadatas"][0],
)
]

def similar_search_with_scores(
Expand All @@ -114,19 +121,26 @@ def similar_search_with_scores(
filters(MetadataFilters): metadata filters, defaults to None
"""
logger.info("ChromaStore similar search with scores")
where_filters = self.convert_metadata_filters(filters) if filters else None
docs_and_scores = (
self.vector_store_client.similarity_search_with_relevance_scores(
query=text,
k=topk,
score_threshold=score_threshold,
filter=where_filters,
)
chroma_results = self._query(
text=text,
topk=topk,
filters=filters,
)
return [
Chunk(content=doc.page_content, metadata=doc.metadata, score=score)
for doc, score in docs_and_scores
chunks = [
(
Chunk(
content=chroma_result[0],
metadata=chroma_result[1] or {},
score=chroma_result[2],
)
)
for chroma_result in zip(
chroma_results["documents"][0],
chroma_results["metadatas"][0],
chroma_results["distances"][0],
)
]
return self.filter_by_score_threshold(chunks, score_threshold)

def vector_name_exists(self) -> bool:
"""Whether vector name exists."""
Expand All @@ -138,19 +152,24 @@ def vector_name_exists(self) -> bool:
files = list(filter(lambda f: f != "chroma.sqlite3", files))
return len(files) > 0


def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document to vector store."""
logger.info("ChromaStore load document")
texts = [chunk.content for chunk in chunks]
metadatas = [chunk.metadata for chunk in chunks]
ids = [chunk.chunk_id for chunk in chunks]
self.vector_store_client.add_texts(texts=texts, metadatas=metadatas, ids=ids)
chroma_metadatas = [
_transform_chroma_metadata(metadata) for metadata in metadatas
]
self._add_texts(texts=texts, metadatas=chroma_metadatas, ids=ids)
return ids

def delete_vector_name(self, vector_name: str):
"""Delete vector name."""
logger.info(f"chroma vector_name:{vector_name} begin delete...")
self.vector_store_client.delete_collection()
# self.vector_store_client.delete_collection()
self._chroma_client.delete_collection(self._collection.name)
self._clean_persist_folder()
return True

Expand All @@ -159,8 +178,7 @@ def delete_by_ids(self, ids):
logger.info(f"begin delete chroma ids: {ids}")
ids = ids.split(",")
if len(ids) > 0:
collection = self.vector_store_client._collection
collection.delete(ids=ids)
self._collection.delete(ids=ids)

def convert_metadata_filters(
self,
Expand Down Expand Up @@ -198,6 +216,65 @@ def convert_metadata_filters(
where_filters[chroma_condition] = filters_list
return where_filters

def _add_texts(
self,
texts: Iterable[str],
ids: List[str],
metadatas: Optional[List[Mapping[str, Union[str, int, float, bool]]]] = None,
) -> List[str]:
"""Add texts to Chroma collection.
Args:
texts(Iterable[str]): texts.
metadatas(Optional[List[dict]]): metadatas.
ids(Optional[List[str]]): ids.
Returns:
List[str]: ids.
"""
embeddings = None
texts = list(texts)
if self.embeddings is not None:
embeddings = self.embeddings.embed_documents(texts)
if metadatas:
try:
self._collection.upsert(
metadatas=metadatas,
embeddings=embeddings, # type: ignore
documents=texts,
ids=ids,
)
except ValueError as e:
logger.error(f"Error upsert chromadb with metadata: {e}")
else:
self._collection.upsert(
embeddings=embeddings, # type: ignore
documents=texts,
ids=ids,
)
return ids

def _query(self, text: str, topk: int, filters: Optional[MetadataFilters] = None):
"""Query Chroma collection.
Args:
text(str): query text.
topk(int): topk.
filters(MetadataFilters): metadata filters.
Returns:
dict: query result.
"""
if not text:
return {}
where_filters = self.convert_metadata_filters(filters) if filters else None
if self.embeddings is None:
raise ValueError("Chroma Embeddings is None")
query_embedding = self.embeddings.embed_query(text)
return self._collection.query(
query_embeddings=query_embedding,
n_results=topk,
where=where_filters,
)

def _clean_persist_folder(self):
"""Clean persist folder."""
for root, dirs, files in os.walk(self.persist_dir, topdown=False):
Expand Down Expand Up @@ -230,3 +307,14 @@ def _convert_chroma_filter_operator(operator: str) -> str:
return "$lte"
else:
raise ValueError(f"Chroma Where operator {operator} not supported")


def _transform_chroma_metadata(
metadata: Dict[str, Any]
) -> Mapping[str, str | int | float | bool]:
"""Transform metadata to Chroma metadata."""
transformed = {}
for key, value in metadata.items():
if isinstance(value, (str, int, float, bool)):
transformed[key] = value
return transformed
2 changes: 1 addition & 1 deletion dbgpt/storage/vector_store/pgvector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, vector_store_config: PGVectorConfig) -> None:
embedding_function=self.embeddings,
collection_name=self.collection_name,
connection_string=self.connection_string,
)
) # mypy: ignore

def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
# If you modify the version, please modify the version in the following files:
# dbgpt/_version.py
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.5")
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.6")

BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
LLAMA_CPP_GPU_ACCELERATION = (
Expand Down Expand Up @@ -499,7 +499,6 @@ def knowledge_requires():
pip install "dbgpt[rag]"
"""
setup_spec.extras["rag"] = setup_spec.extras["vstore"] + [
"langchain>=0.0.286",
"spacy>=3.7",
"markdown",
"bs4",
Expand Down

0 comments on commit d313155

Please sign in to comment.