Feat rdb summary wide table (#2035)
Co-authored-by: dongzhancai1 <[email protected]>
Co-authored-by: dong <[email protected]>
3 people authored Dec 18, 2024
1 parent 7f4b5e7 commit 9b0161e
Showing 17 changed files with 948 additions and 243 deletions.
1 change: 1 addition & 0 deletions .env.template
@@ -66,6 +66,7 @@ QUANTIZE_8bit=True
#** EMBEDDING SETTINGS **#
#*******************************************************************#
EMBEDDING_MODEL=text2vec
EMBEDDING_MODEL_MAX_SEQ_LEN=512
#EMBEDDING_MODEL=m3e-large
#EMBEDDING_MODEL=bge-large-en
#EMBEDDING_MODEL=bge-large-zh
3 changes: 3 additions & 0 deletions dbgpt/_private/config.py
@@ -264,6 +264,9 @@ def __init__(self) -> None:

# EMBEDDING Configuration
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
self.EMBEDDING_MODEL_MAX_SEQ_LEN = int(
os.getenv("MEMBEDDING_MODEL_MAX_SEQ_LEN", 512)
)
# Rerank model configuration
self.RERANK_MODEL = os.getenv("RERANK_MODEL")
self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH")
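A minimal sketch of how the new setting can be consumed downstream, assuming only the Config class defined in this file; the calling code is illustrative and not part of this commit:

from dbgpt._private.config import Config

CFG = Config()
# Reads EMBEDDING_MODEL_MAX_SEQ_LEN from the environment (.env), falling back to 512.
max_seq_length = CFG.EMBEDDING_MODEL_MAX_SEQ_LEN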
3 changes: 0 additions & 3 deletions dbgpt/app/scene/chat_db/professional_qa/chat.py
@@ -55,9 +55,6 @@ async def generate_input_values(self) -> Dict:
if self.db_name:
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try:
# table_infos = client.get_db_summary(
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
# )
table_infos = await blocking_func_to_async(
self._executor,
client.get_db_summary,
92 changes: 81 additions & 11 deletions dbgpt/rag/assembler/db_schema.py
@@ -1,12 +1,15 @@
"""DBSchemaAssembler."""
import os
from typing import Any, List, Optional

from dbgpt.core import Chunk
from dbgpt.core import Chunk, Embeddings
from dbgpt.datasource.base import BaseConnector

from ...serve.rag.connector import VectorStoreConnector
from ...storage.vector_store.base import VectorStoreConfig
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..index.base import IndexStoreBase
from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..knowledge.datasource import DatasourceKnowledge
from ..retriever.db_schema import DBSchemaRetriever

@@ -35,23 +38,64 @@ class DBSchemaAssembler(BaseAssembler):
def __init__(
self,
connector: BaseConnector,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
max_seq_length: int = 512,
**kwargs: Any,
) -> None:
"""Initialize with Embedding Assembler arguments.
Args:
connector: (BaseConnector) BaseConnector connection.
index_store: (IndexStoreBase) IndexStoreBase to use.
table_vector_store_connector: VectorStoreConnector to load
and retrieve table info.
field_vector_store_connector: VectorStoreConnector to load
and retrieve field info.
chunk_parameters: (Optional[ChunkParameters]) ChunkParameters to use for chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
"""
knowledge = DatasourceKnowledge(connector)
self._connector = connector
self._index_store = index_store
self._table_vector_store_connector = table_vector_store_connector

field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)

self._embedding_model = embedding_model
if self._embedding_model and not embeddings:
embeddings = DefaultEmbeddingFactory(
default_model_name=self._embedding_model
).create(self._embedding_model)

if (
embeddings
and self._table_vector_store_connector.vector_store_config.embedding_fn
is None
):
self._table_vector_store_connector.vector_store_config.embedding_fn = (
embeddings
)
if (
embeddings
and self._field_vector_store_connector.vector_store_config.embedding_fn
is None
):
self._field_vector_store_connector.vector_store_config.embedding_fn = (
embeddings
)
knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length)
super().__init__(
knowledge=knowledge,
chunk_parameters=chunk_parameters,
@@ -62,23 +106,36 @@ def __init__(
def load_from_connection(
cls,
connector: BaseConnector,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
max_seq_length: int = 512,
) -> "DBSchemaAssembler":
"""Load document embedding into vector store from path.
Args:
connector: (BaseConnector) BaseConnector connection.
index_store: (IndexStoreBase) IndexStoreBase to use.
table_vector_store_connector: used to load table chunks.
field_vector_store_connector: used to load field chunks
when the table has too many fields.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
max_seq_length: Embedding model max sequence length
Returns:
DBSchemaAssembler
"""
return cls(
connector=connector,
index_store=index_store,
table_vector_store_connector=table_vector_store_connector,
field_vector_store_connector=field_vector_store_connector,
embedding_model=embedding_model,
chunk_parameters=chunk_parameters,
embeddings=embeddings,
max_seq_length=max_seq_length,
)

def get_chunks(self) -> List[Chunk]:
@@ -91,7 +148,19 @@ def persist(self, **kwargs: Any) -> List[str]:
Returns:
List[str]: List of chunk ids.
"""
return self._index_store.load_document(self._chunks)
table_chunks, field_chunks = [], []
for chunk in self._chunks:
metadata = chunk.metadata
if metadata.get("separated"):
if metadata.get("part") == "table":
table_chunks.append(chunk)
else:
field_chunks.append(chunk)
else:
table_chunks.append(chunk)

self._field_vector_store_connector.load_document(field_chunks)
return self._table_vector_store_connector.load_document(table_chunks)

def _extract_info(self, chunks) -> List[Chunk]:
"""Extract info from chunks."""
@@ -110,5 +179,6 @@ def as_retriever(self, top_k: int = 4, **kwargs) -> DBSchemaRetriever:
top_k=top_k,
connector=self._connector,
is_embeddings=True,
index_store=self._index_store,
table_vector_store_connector=self._table_vector_store_connector,
field_vector_store_connector=self._field_vector_store_connector,
)
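For orientation, a minimal end-to-end sketch of the new two-connector API, put together only from calls that appear in this diff and its tests; the store name, embedding model name, and sample table are illustrative assumptions, not part of the commit:

import os

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig

# Illustrative temporary SQLite database (mirrors the removed test fixture).
connector = SQLiteTempConnector.create_temporary_db()
connector.create_temp_tables(
    {
        "user": {
            "columns": {"id": "INTEGER PRIMARY KEY", "name": "TEXT", "age": "INTEGER"},
            "data": [(1, "Tom", 10), (2, "Jerry", 16)],
        }
    }
)

# Table-level vector store; when field_vector_store_connector is omitted, the
# assembler derives one with the same store type and the name suffixed "_field".
embeddings = DefaultEmbeddingFactory(default_model_name="text2vec").create("text2vec")
table_store = VectorStoreConnector.from_default(
    os.getenv("VECTOR_STORE_TYPE", "Chroma"),
    embeddings,
    vector_store_config=VectorStoreConfig(name="db_profile"),
)

assembler = DBSchemaAssembler.load_from_connection(
    connector=connector,
    table_vector_store_connector=table_store,
    embeddings=embeddings,
    max_seq_length=512,
)
assembler.persist()  # table chunks and field chunks are written to separate stores
retriever = assembler.as_retriever(top_k=4)

At query time the retriever re-joins a table summary with its field chunks using the --table-field-separator-- marker, as exercised by the test file below.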
151 changes: 96 additions & 55 deletions dbgpt/rag/assembler/tests/test_db_struct_assembler.py
@@ -1,76 +1,117 @@
from unittest.mock import MagicMock
from typing import List
from unittest.mock import MagicMock, patch

import pytest

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.storage.vector_store.chroma_store import ChromaStore
import dbgpt
from dbgpt.core import Chunk
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary


@pytest.fixture
def mock_db_connection():
"""Create a temporary database connection for testing."""
connect = SQLiteTempConnector.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
}
}
)
return connect
return MagicMock()


@pytest.fixture
def mock_chunk_parameters():
return MagicMock(spec=ChunkParameters)
def mock_table_vector_store_connector():
mock_connector = MagicMock()
mock_connector.vector_store_config.name = "table_name"
chunk = Chunk(
content="table_name: user\ncomment: user about dbgpt",
metadata={
"field_num": 6,
"part": "table",
"separated": 1,
"table_name": "user",
},
)
mock_connector.similar_search_with_scores = MagicMock(return_value=[chunk])
return mock_connector


@pytest.fixture
def mock_embedding_factory():
return MagicMock(spec=EmbeddingFactory)
def mock_field_vector_store_connector():
mock_connector = MagicMock()
chunk1 = Chunk(
content="name,age",
metadata={
"field_num": 6,
"part": "field",
"part_index": 0,
"separated": 1,
"table_name": "user",
},
)
chunk2 = Chunk(
content="address,gender",
metadata={
"field_num": 6,
"part": "field",
"part_index": 1,
"separated": 1,
"table_name": "user",
},
)
chunk3 = Chunk(
content="mail,phone",
metadata={
"field_num": 6,
"part": "field",
"part_index": 2,
"separated": 1,
"table_name": "user",
},
)
mock_connector.similar_search_with_scores = MagicMock(
return_value=[chunk1, chunk2, chunk3]
)
return mock_connector


@pytest.fixture
def mock_vector_store_connector():
return MagicMock(spec=ChromaStore)
def dbstruct_retriever(
mock_db_connection,
mock_table_vector_store_connector,
mock_field_vector_store_connector,
):
return DBSchemaRetriever(
connector=mock_db_connection,
table_vector_store_connector=mock_table_vector_store_connector,
field_vector_store_connector=mock_field_vector_store_connector,
separator="--table-field-separator--",
)


@pytest.fixture
def mock_knowledge():
return MagicMock(spec=Knowledge)
def mock_parse_db_summary() -> str:
"""Patch _parse_db_summary method."""
return (
"table_name: user\ncomment: user about dbgpt\n"
"--table-field-separator--\n"
"name,age\naddress,gender\nmail,phone"
)


def test_load_knowledge(
mock_db_connection,
mock_knowledge,
mock_chunk_parameters,
mock_embedding_factory,
mock_vector_store_connector,
):
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
assembler = EmbeddingAssembler(
knowledge=mock_knowledge,
chunk_parameters=mock_chunk_parameters,
embeddings=mock_embedding_factory.create(),
index_store=mock_vector_store_connector,
# Patch _parse_db_summary so the retriever sees a deterministic summary
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(dbstruct_retriever):
query = "Table summary"
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == (
"table_name: user\ncomment: user about dbgpt\n"
"--table-field-separator--\n"
"name,age\naddress,gender\nmail,phone"
)


def async_mock_parse_db_summary() -> str:
"""Asynchronous patch for _parse_db_summary method."""
return (
"table_name: user\ncomment: user about dbgpt\n"
"--table-field-separator--\n"
"name,age\naddress,gender\nmail,phone"
)
assembler.load_knowledge(knowledge=mock_knowledge)
assert len(assembler._chunks) == 0