Skip to content

Commit

Permalink
fix:graph retrieve bug (eosphoros-ai#1884)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aries-ckt authored Aug 28, 2024
1 parent 1cb7e35 commit bb5d2d1
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 14 deletions.
1 change: 1 addition & 0 deletions dbgpt/app/scene/chat_knowledge/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(self, chat_param: Dict):
top_k=retriever_top_k,
query_rewrite=query_rewrite,
rerank=reranker,
llm_model=self.llm_model,
)

self.prompt_template.template_is_strict = False
Expand Down
23 changes: 18 additions & 5 deletions dbgpt/serve/rag/retriever/knowledge_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dbgpt.component import ComponentType
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import Chunk
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker
from dbgpt.rag.retriever.base import BaseRetriever
Expand All @@ -26,6 +28,7 @@ def __init__(
top_k: Optional[int] = 4,
query_rewrite: Optional[QueryRewrite] = None,
rerank: Optional[Ranker] = None,
llm_model: Optional[str] = None,
):
"""
Args:
Expand All @@ -40,6 +43,7 @@ def __init__(
self._top_k = top_k
self._query_rewrite = query_rewrite
self._rerank = rerank
self._llm_model = llm_model
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
Expand All @@ -50,9 +54,19 @@ def __init__(

space_dao = KnowledgeSpaceDao()
space = space_dao.get_one({"id": space_id})
config = VectorStoreConfig(name=space.name, embedding_fn=embedding_fn)
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
llm_client = DefaultLLMClient(worker_manager=worker_manager)
config = VectorStoreConfig(
name=space.name,
embedding_fn=embedding_fn,
llm_client=llm_client,
llm_model=self._llm_model,
)

self._vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_type=space.vector_type,
vector_store_config=config,
)
self._executor = CFG.SYSTEM_APP.get_component(
Expand Down Expand Up @@ -141,7 +155,6 @@ async def _aretrieve_with_score(
Return:
List[Chunk]: list of chunks with score.
"""
candidates_with_score = await blocking_func_to_async(
self._executor, self._retrieve_with_score, query, score_threshold, filters
return await self._retriever_chain.aretrieve_with_scores(
query, score_threshold, filters
)
return candidates_with_score
24 changes: 15 additions & 9 deletions dbgpt/serve/rag/retriever/retriever_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,20 @@ def _retrieve(
async def _aretrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Retrieve knowledge chunks.
"""Async retrieve knowledge chunks.
Args:
query (str): query text
filters: (Optional[MetadataFilters]) metadata filters.
Return:
List[Chunk]: list of chunks
"""
candidates = await blocking_func_to_async(
self._executor, self._retrieve, query, filters
)
return candidates
for retriever in self._retrievers:
candidates = await retriever.aretrieve(
query=query, filters=filters
)
if candidates:
return candidates
return []

def _retrieve_with_score(
self,
Expand Down Expand Up @@ -85,7 +88,10 @@ async def _aretrieve_with_score(
Return:
List[Chunk]: list of chunks with score
"""
candidates_with_score = await blocking_func_to_async(
self._executor, self._retrieve_with_score, query, score_threshold, filters
)
return candidates_with_score
for retriever in self._retrievers:
candidates_with_scores = await retriever.aretrieve_with_scores(
query=query, score_threshold=score_threshold, filters=filters
)
if candidates_with_scores:
return candidates_with_scores
return []

0 comments on commit bb5d2d1

Please sign in to comment.