From 9b0161e521e2e24cf96248e01c8b13583437a1c9 Mon Sep 17 00:00:00 2001 From: Cooper <42096311+FOkvj@users.noreply.github.com> Date: Wed, 18 Dec 2024 20:34:21 +0800 Subject: [PATCH] Feat rdb summary wide table (#2035) Co-authored-by: dongzhancai1 Co-authored-by: dong --- .env.template | 1 + dbgpt/_private/config.py | 3 + .../app/scene/chat_db/professional_qa/chat.py | 3 - dbgpt/rag/assembler/db_schema.py | 92 ++++- .../tests/test_db_struct_assembler.py | 151 ++++++--- .../tests/test_embedding_assembler.py | 48 ++- dbgpt/rag/knowledge/datasource.py | 34 +- dbgpt/rag/operators/db_schema.py | 35 +- dbgpt/rag/retriever/db_schema.py | 167 +++++---- dbgpt/rag/retriever/tests/test_db_struct.py | 45 ++- dbgpt/rag/summary/db_summary_client.py | 55 ++- dbgpt/rag/summary/rdbms_db_summary.py | 130 ++++++- dbgpt/rag/text_splitter/text_splitter.py | 39 +++ dbgpt/util/chat_util.py | 42 ++- .../case_3_order_wide_table_sqlite_wide.sql | 317 ++++++++++++++++++ examples/rag/db_schema_rag_example.py | 22 +- scripts/examples/load_examples.sh | 7 + 17 files changed, 948 insertions(+), 243 deletions(-) create mode 100644 docker/examples/sqls/case_3_order_wide_table_sqlite_wide.sql diff --git a/.env.template b/.env.template index 221b7ee87..d8a2d384a 100644 --- a/.env.template +++ b/.env.template @@ -66,6 +66,7 @@ QUANTIZE_8bit=True #** EMBEDDING SETTINGS **# #*******************************************************************# EMBEDDING_MODEL=text2vec +EMBEDDING_MODEL_MAX_SEQ_LEN=512 #EMBEDDING_MODEL=m3e-large #EMBEDDING_MODEL=bge-large-en #EMBEDDING_MODEL=bge-large-zh diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 82ccae656..a2ee9fe0f 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -264,6 +264,9 @@ def __init__(self) -> None: # EMBEDDING Configuration self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec") + self.EMBEDDING_MODEL_MAX_SEQ_LEN = int( + os.getenv("MEMBEDDING_MODEL_MAX_SEQ_LEN", 512) + ) # Rerank model configuration self.RERANK_MODEL = os.getenv("RERANK_MODEL") self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH") diff --git a/dbgpt/app/scene/chat_db/professional_qa/chat.py b/dbgpt/app/scene/chat_db/professional_qa/chat.py index 52ed56abc..6049bf76c 100644 --- a/dbgpt/app/scene/chat_db/professional_qa/chat.py +++ b/dbgpt/app/scene/chat_db/professional_qa/chat.py @@ -55,9 +55,6 @@ async def generate_input_values(self) -> Dict: if self.db_name: client = DBSummaryClient(system_app=CFG.SYSTEM_APP) try: - # table_infos = client.get_db_summary( - # dbname=self.db_name, query=self.current_user_input, topk=self.top_k - # ) table_infos = await blocking_func_to_async( self._executor, client.get_db_summary, diff --git a/dbgpt/rag/assembler/db_schema.py b/dbgpt/rag/assembler/db_schema.py index b55add20d..99507a2b2 100644 --- a/dbgpt/rag/assembler/db_schema.py +++ b/dbgpt/rag/assembler/db_schema.py @@ -1,12 +1,15 @@ """DBSchemaAssembler.""" +import os from typing import Any, List, Optional -from dbgpt.core import Chunk +from dbgpt.core import Chunk, Embeddings from dbgpt.datasource.base import BaseConnector +from ...serve.rag.connector import VectorStoreConnector +from ...storage.vector_store.base import VectorStoreConfig from ..assembler.base import BaseAssembler from ..chunk_manager import ChunkParameters -from ..index.base import IndexStoreBase +from ..embedding.embedding_factory import DefaultEmbeddingFactory from ..knowledge.datasource import DatasourceKnowledge from ..retriever.db_schema import DBSchemaRetriever @@ -35,23 +38,64 @@ class DBSchemaAssembler(BaseAssembler): def __init__( self, connector: BaseConnector, - index_store: IndexStoreBase, + table_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector = None, chunk_parameters: Optional[ChunkParameters] = None, + embedding_model: Optional[str] = None, + embeddings: Optional[Embeddings] = None, + max_seq_length: int = 512, **kwargs: Any, ) -> None: """Initialize with Embedding Assembler arguments. Args: connector: (BaseConnector) BaseConnector connection. - index_store: (IndexStoreBase) IndexStoreBase to use. + table_vector_store_connector: VectorStoreConnector to load + and retrieve table info. + field_vector_store_connector: VectorStoreConnector to load + and retrieve field info. chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking. embedding_model: (Optional[str]) Embedding model to use. embeddings: (Optional[Embeddings]) Embeddings to use. """ - knowledge = DatasourceKnowledge(connector) self._connector = connector - self._index_store = index_store + self._table_vector_store_connector = table_vector_store_connector + field_vector_store_config = VectorStoreConfig( + name=table_vector_store_connector.vector_store_config.name + "_field" + ) + self._field_vector_store_connector = ( + field_vector_store_connector + or VectorStoreConnector.from_default( + os.getenv("VECTOR_STORE_TYPE", "Chroma"), + self._table_vector_store_connector.current_embeddings, + vector_store_config=field_vector_store_config, + ) + ) + + self._embedding_model = embedding_model + if self._embedding_model and not embeddings: + embeddings = DefaultEmbeddingFactory( + default_model_name=self._embedding_model + ).create(self._embedding_model) + + if ( + embeddings + and self._table_vector_store_connector.vector_store_config.embedding_fn + is None + ): + self._table_vector_store_connector.vector_store_config.embedding_fn = ( + embeddings + ) + if ( + embeddings + and self._field_vector_store_connector.vector_store_config.embedding_fn + is None + ): + self._field_vector_store_connector.vector_store_config.embedding_fn = ( + embeddings + ) + knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length) super().__init__( knowledge=knowledge, chunk_parameters=chunk_parameters, @@ -62,23 +106,36 @@ def __init__( def load_from_connection( cls, connector: BaseConnector, - index_store: IndexStoreBase, + table_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector = None, chunk_parameters: Optional[ChunkParameters] = None, + embedding_model: Optional[str] = None, + embeddings: Optional[Embeddings] = None, + max_seq_length: int = 512, ) -> "DBSchemaAssembler": """Load document embedding into vector store from path. Args: connector: (BaseConnector) BaseConnector connection. - index_store: (IndexStoreBase) IndexStoreBase to use. + table_vector_store_connector: used to load table chunks. + field_vector_store_connector: used to load field chunks + if field in table is too much. chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for chunking. + embedding_model: (Optional[str]) Embedding model to use. + embeddings: (Optional[Embeddings]) Embeddings to use. + max_seq_length: Embedding model max sequence length Returns: DBSchemaAssembler """ return cls( connector=connector, - index_store=index_store, + table_vector_store_connector=table_vector_store_connector, + field_vector_store_connector=field_vector_store_connector, + embedding_model=embedding_model, chunk_parameters=chunk_parameters, + embeddings=embeddings, + max_seq_length=max_seq_length, ) def get_chunks(self) -> List[Chunk]: @@ -91,7 +148,19 @@ def persist(self, **kwargs: Any) -> List[str]: Returns: List[str]: List of chunk ids. """ - return self._index_store.load_document(self._chunks) + table_chunks, field_chunks = [], [] + for chunk in self._chunks: + metadata = chunk.metadata + if metadata.get("separated"): + if metadata.get("part") == "table": + table_chunks.append(chunk) + else: + field_chunks.append(chunk) + else: + table_chunks.append(chunk) + + self._field_vector_store_connector.load_document(field_chunks) + return self._table_vector_store_connector.load_document(table_chunks) def _extract_info(self, chunks) -> List[Chunk]: """Extract info from chunks.""" @@ -110,5 +179,6 @@ def as_retriever(self, top_k: int = 4, **kwargs) -> DBSchemaRetriever: top_k=top_k, connector=self._connector, is_embeddings=True, - index_store=self._index_store, + table_vector_store_connector=self._table_vector_store_connector, + field_vector_store_connector=self._field_vector_store_connector, ) diff --git a/dbgpt/rag/assembler/tests/test_db_struct_assembler.py b/dbgpt/rag/assembler/tests/test_db_struct_assembler.py index 84638b692..598160374 100644 --- a/dbgpt/rag/assembler/tests/test_db_struct_assembler.py +++ b/dbgpt/rag/assembler/tests/test_db_struct_assembler.py @@ -1,76 +1,117 @@ -from unittest.mock import MagicMock +from typing import List +from unittest.mock import MagicMock, patch import pytest -from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector -from dbgpt.rag.assembler.embedding import EmbeddingAssembler -from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType -from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory -from dbgpt.rag.knowledge.base import Knowledge -from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter -from dbgpt.storage.vector_store.chroma_store import ChromaStore +import dbgpt +from dbgpt.core import Chunk +from dbgpt.rag.retriever.db_schema import DBSchemaRetriever +from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary @pytest.fixture def mock_db_connection(): - """Create a temporary database connection for testing.""" - connect = SQLiteTempConnector.create_temporary_db() - connect.create_temp_tables( - { - "user": { - "columns": { - "id": "INTEGER PRIMARY KEY", - "name": "TEXT", - "age": "INTEGER", - }, - "data": [ - (1, "Tom", 10), - (2, "Jerry", 16), - (3, "Jack", 18), - (4, "Alice", 20), - (5, "Bob", 22), - ], - } - } - ) - return connect + return MagicMock() @pytest.fixture -def mock_chunk_parameters(): - return MagicMock(spec=ChunkParameters) +def mock_table_vector_store_connector(): + mock_connector = MagicMock() + mock_connector.vector_store_config.name = "table_name" + chunk = Chunk( + content="table_name: user\ncomment: user about dbgpt", + metadata={ + "field_num": 6, + "part": "table", + "separated": 1, + "table_name": "user", + }, + ) + mock_connector.similar_search_with_scores = MagicMock(return_value=[chunk]) + return mock_connector @pytest.fixture -def mock_embedding_factory(): - return MagicMock(spec=EmbeddingFactory) +def mock_field_vector_store_connector(): + mock_connector = MagicMock() + chunk1 = Chunk( + content="name,age", + metadata={ + "field_num": 6, + "part": "field", + "part_index": 0, + "separated": 1, + "table_name": "user", + }, + ) + chunk2 = Chunk( + content="address,gender", + metadata={ + "field_num": 6, + "part": "field", + "part_index": 1, + "separated": 1, + "table_name": "user", + }, + ) + chunk3 = Chunk( + content="mail,phone", + metadata={ + "field_num": 6, + "part": "field", + "part_index": 2, + "separated": 1, + "table_name": "user", + }, + ) + mock_connector.similar_search_with_scores = MagicMock( + return_value=[chunk1, chunk2, chunk3] + ) + return mock_connector @pytest.fixture -def mock_vector_store_connector(): - return MagicMock(spec=ChromaStore) +def dbstruct_retriever( + mock_db_connection, + mock_table_vector_store_connector, + mock_field_vector_store_connector, +): + return DBSchemaRetriever( + connector=mock_db_connection, + table_vector_store_connector=mock_table_vector_store_connector, + field_vector_store_connector=mock_field_vector_store_connector, + separator="--table-field-separator--", + ) -@pytest.fixture -def mock_knowledge(): - return MagicMock(spec=Knowledge) +def mock_parse_db_summary() -> str: + """Patch _parse_db_summary method.""" + return ( + "table_name: user\ncomment: user about dbgpt\n" + "--table-field-separator--\n" + "name,age\naddress,gender\nmail,phone" + ) -def test_load_knowledge( - mock_db_connection, - mock_knowledge, - mock_chunk_parameters, - mock_embedding_factory, - mock_vector_store_connector, -): - mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE" - mock_chunk_parameters.text_splitter = CharacterTextSplitter() - mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE - assembler = EmbeddingAssembler( - knowledge=mock_knowledge, - chunk_parameters=mock_chunk_parameters, - embeddings=mock_embedding_factory.create(), - index_store=mock_vector_store_connector, +# Mocking the _parse_db_summary method in your test function +@patch.object( + dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary +) +def test_retrieve_with_mocked_summary(dbstruct_retriever): + query = "Table summary" + chunks: List[Chunk] = dbstruct_retriever._retrieve(query) + assert isinstance(chunks[0], Chunk) + assert chunks[0].content == ( + "table_name: user\ncomment: user about dbgpt\n" + "--table-field-separator--\n" + "name,age\naddress,gender\nmail,phone" + ) + + +def async_mock_parse_db_summary() -> str: + """Asynchronous patch for _parse_db_summary method.""" + return ( + "table_name: user\ncomment: user about dbgpt\n" + "--table-field-separator--\n" + "name,age\naddress,gender\nmail,phone" ) - assembler.load_knowledge(knowledge=mock_knowledge) - assert len(assembler._chunks) == 0 diff --git a/dbgpt/rag/assembler/tests/test_embedding_assembler.py b/dbgpt/rag/assembler/tests/test_embedding_assembler.py index 350ccad39..f2ac1577e 100644 --- a/dbgpt/rag/assembler/tests/test_embedding_assembler.py +++ b/dbgpt/rag/assembler/tests/test_embedding_assembler.py @@ -5,9 +5,9 @@ from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.rag.assembler.db_schema import DBSchemaAssembler from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType -from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory -from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter -from dbgpt.storage.vector_store.chroma_store import ChromaStore +from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddings, EmbeddingFactory +from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter +from dbgpt.serve.rag.connector import VectorStoreConnector @pytest.fixture @@ -21,14 +21,22 @@ def mock_db_connection(): "id": "INTEGER PRIMARY KEY", "name": "TEXT", "age": "INTEGER", - }, - "data": [ - (1, "Tom", 10), - (2, "Jerry", 16), - (3, "Jack", 18), - (4, "Alice", 20), - (5, "Bob", 22), - ], + "address": "TEXT", + "phone": "TEXT", + "email": "TEXT", + "gender": "TEXT", + "birthdate": "TEXT", + "occupation": "TEXT", + "education": "TEXT", + "marital_status": "TEXT", + "nationality": "TEXT", + "height": "REAL", + "weight": "REAL", + "blood_type": "TEXT", + "emergency_contact": "TEXT", + "created_at": "TEXT", + "updated_at": "TEXT", + } } } ) @@ -46,23 +54,29 @@ def mock_embedding_factory(): @pytest.fixture -def mock_vector_store_connector(): - return MagicMock(spec=ChromaStore) +def mock_table_vector_store_connector(): + mock_connector = MagicMock(spec=VectorStoreConnector) + mock_connector.vector_store_config.name = "table_vector_store_name" + mock_connector.current_embeddings = DefaultEmbeddings() + return mock_connector def test_load_knowledge( mock_db_connection, mock_chunk_parameters, mock_embedding_factory, - mock_vector_store_connector, + mock_table_vector_store_connector, ): mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE" - mock_chunk_parameters.text_splitter = CharacterTextSplitter() + mock_chunk_parameters.text_splitter = RDBTextSplitter( + separator="--table-field-separator--" + ) mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE assembler = DBSchemaAssembler( connector=mock_db_connection, chunk_parameters=mock_chunk_parameters, embeddings=mock_embedding_factory.create(), - index_store=mock_vector_store_connector, + table_vector_store_connector=mock_table_vector_store_connector, + max_seq_length=10, ) - assert len(assembler._chunks) == 1 + assert len(assembler._chunks) > 1 diff --git a/dbgpt/rag/knowledge/datasource.py b/dbgpt/rag/knowledge/datasource.py index 78ae045dd..504e18dc4 100644 --- a/dbgpt/rag/knowledge/datasource.py +++ b/dbgpt/rag/knowledge/datasource.py @@ -5,7 +5,7 @@ from dbgpt.datasource import BaseConnector from ..summary.gdbms_db_summary import _parse_db_summary as _parse_gdb_summary -from ..summary.rdbms_db_summary import _parse_db_summary +from ..summary.rdbms_db_summary import _parse_db_summary_with_metadata from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType @@ -15,9 +15,11 @@ class DatasourceKnowledge(Knowledge): def __init__( self, connector: BaseConnector, - summary_template: str = "{table_name}({columns})", + summary_template: str = "table_name: {table_name}", + separator: str = "--table-field-separator--", knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, metadata: Optional[Dict[str, Union[str, List[str]]]] = None, + model_dimension: int = 512, **kwargs: Any, ) -> None: """Create Datasource Knowledge with Knowledge arguments. @@ -25,11 +27,17 @@ def __init__( Args: connector(BaseConnector): connector summary_template(str, optional): summary template + separator(str, optional): separator used to separate + table's basic info and fields. + defaults `-- table-field-separator--` knowledge_type(KnowledgeType, optional): knowledge type metadata(Dict[str, Union[str, List[str]], optional): metadata + model_dimension(int, optional): The threshold for splitting field string """ + self._separator = separator self._connector = connector self._summary_template = summary_template + self._model_dimension = model_dimension super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs) def _load(self) -> List[Document]: @@ -37,13 +45,23 @@ def _load(self) -> List[Document]: docs = [] if self._connector.is_graph_type(): db_summary = _parse_gdb_summary(self._connector, self._summary_template) + for table_summary in db_summary: + metadata = {"source": "database"} + docs.append(Document(content=table_summary, metadata=metadata)) else: - db_summary = _parse_db_summary(self._connector, self._summary_template) - for table_summary in db_summary: - metadata = {"source": "database"} - if self._metadata: - metadata.update(self._metadata) # type: ignore - docs.append(Document(content=table_summary, metadata=metadata)) + db_summary_with_metadata = _parse_db_summary_with_metadata( + self._connector, + self._summary_template, + self._separator, + self._model_dimension, + ) + for summary, table_metadata in db_summary_with_metadata: + metadata = {"source": "database"} + + if self._metadata: + metadata.update(self._metadata) # type: ignore + table_metadata.update(metadata) + docs.append(Document(content=summary, metadata=table_metadata)) return docs @classmethod diff --git a/dbgpt/rag/operators/db_schema.py b/dbgpt/rag/operators/db_schema.py index d0a7c0d9f..4e642c549 100644 --- a/dbgpt/rag/operators/db_schema.py +++ b/dbgpt/rag/operators/db_schema.py @@ -1,14 +1,15 @@ """The DBSchema Retriever Operator.""" - +import os from typing import List, Optional from dbgpt.core import Chunk from dbgpt.core.interface.operators.retriever import RetrieverOperator from dbgpt.datasource.base import BaseConnector +from dbgpt.serve.rag.connector import VectorStoreConnector +from ...storage.vector_store.base import VectorStoreConfig from ..assembler.db_schema import DBSchemaAssembler from ..chunk_manager import ChunkParameters -from ..index.base import IndexStoreBase from ..retriever.db_schema import DBSchemaRetriever from .assembler import AssemblerOperator @@ -19,13 +20,14 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]): Args: connector (BaseConnector): The connection. top_k (int, optional): The top k. Defaults to 4. - index_store (IndexStoreBase, optional): The vector store + vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None. """ def __init__( self, - index_store: IndexStoreBase, + table_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector, top_k: int = 4, connector: Optional[BaseConnector] = None, **kwargs @@ -35,7 +37,8 @@ def __init__( self._retriever = DBSchemaRetriever( top_k=top_k, connector=connector, - index_store=index_store, + table_vector_store_connector=table_vector_store_connector, + field_vector_store_connector=field_vector_store_connector, ) def retrieve(self, query: str) -> List[Chunk]: @@ -53,7 +56,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]): def __init__( self, connector: BaseConnector, - index_store: IndexStoreBase, + table_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector = None, chunk_parameters: Optional[ChunkParameters] = None, **kwargs ): @@ -61,14 +65,26 @@ def __init__( Args: connector (BaseConnector): The connection. - index_store (IndexStoreBase): The Storage IndexStoreBase. + vector_store_connector (VectorStoreConnector): The vector store connector. chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. """ if not chunk_parameters: chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE") self._chunk_parameters = chunk_parameters - self._index_store = index_store + self._table_vector_store_connector = table_vector_store_connector + + field_vector_store_config = VectorStoreConfig( + name=table_vector_store_connector.vector_store_config.name + "_field" + ) + self._field_vector_store_connector = ( + field_vector_store_connector + or VectorStoreConnector.from_default( + os.getenv("VECTOR_STORE_TYPE", "Chroma"), + self._table_vector_store_connector.current_embeddings, + vector_store_config=field_vector_store_config, + ) + ) self._connector = connector super().__init__(**kwargs) @@ -84,7 +100,8 @@ def assemble(self, dummy_value) -> List[Chunk]: assembler = DBSchemaAssembler.load_from_connection( connector=self._connector, chunk_parameters=self._chunk_parameters, - index_store=self._index_store, + table_vector_store_connector=self._table_vector_store_connector, + field_vector_store_connector=self._field_vector_store_connector, ) assembler.persist() return assembler.get_chunks() diff --git a/dbgpt/rag/retriever/db_schema.py b/dbgpt/rag/retriever/db_schema.py index 9bced9267..1326e2385 100644 --- a/dbgpt/rag/retriever/db_schema.py +++ b/dbgpt/rag/retriever/db_schema.py @@ -1,18 +1,23 @@ """DBSchema retriever.""" +import logging +import os +from typing import List, Optional -from functools import reduce -from typing import List, Optional, cast - +from dbgpt._private.config import Config from dbgpt.core import Chunk from dbgpt.datasource.base import BaseConnector -from dbgpt.rag.index.base import IndexStoreBase from dbgpt.rag.retriever.base import BaseRetriever from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker -from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary -from dbgpt.storage.vector_store.filters import MetadataFilters -from dbgpt.util.chat_util import run_async_tasks +from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary +from dbgpt.serve.rag.connector import VectorStoreConnector +from dbgpt.storage.vector_store.base import VectorStoreConfig +from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters +from dbgpt.util.chat_util import run_tasks from dbgpt.util.executor_utils import blocking_func_to_async_no_executor -from dbgpt.util.tracer import root_tracer + +logger = logging.getLogger(__name__) + +CFG = Config() class DBSchemaRetriever(BaseRetriever): @@ -20,7 +25,9 @@ class DBSchemaRetriever(BaseRetriever): def __init__( self, - index_store: IndexStoreBase, + table_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector = None, + separator: str = "--table-field-separator--", top_k: int = 4, connector: Optional[BaseConnector] = None, query_rewrite: bool = False, @@ -30,7 +37,11 @@ def __init__( """Create DBSchemaRetriever. Args: - index_store(IndexStore): index connector + table_vector_store_connector: VectorStoreConnector + to load and retrieve table info. + field_vector_store_connector: VectorStoreConnector + to load and retrieve field info. + separator: field/table separator top_k (int): top k connector (Optional[BaseConnector]): RDBMSConnector. query_rewrite (bool): query rewrite @@ -70,34 +81,42 @@ def _create_temporary_connection(): connector = _create_temporary_connection() + vector_store_config = ChromaVectorConfig(name="vector_store_name") + embedding_model_path = "{your_embedding_model_path}" embedding_fn = embedding_factory.create(model_name=embedding_model_path) - config = ChromaVectorConfig( - persist_path=PILOT_PATH, - name="dbschema_rag_test", - embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join( - MODEL_PATH, "text2vec-large-chinese" - ), - ).create(), + vector_connector = VectorStoreConnector.from_default( + "Chroma", + vector_store_config=vector_store_config, + embedding_fn=embedding_fn, ) - - vector_store = ChromaStore(config) # get db struct retriever retriever = DBSchemaRetriever( top_k=3, - index_store=vector_store, + vector_store_connector=vector_connector, connector=connector, ) chunks = retriever.retrieve("show columns from table") result = [chunk.content for chunk in chunks] print(f"db struct rag example results:{result}") """ + self._separator = separator self._top_k = top_k self._connector = connector self._query_rewrite = query_rewrite - self._index_store = index_store + self._table_vector_store_connector = table_vector_store_connector + field_vector_store_config = VectorStoreConfig( + name=table_vector_store_connector.vector_store_config.name + "_field" + ) + self._field_vector_store_connector = ( + field_vector_store_connector + or VectorStoreConnector.from_default( + os.getenv("VECTOR_STORE_TYPE", "Chroma"), + self._table_vector_store_connector.current_embeddings, + vector_store_config=field_vector_store_config, + ) + ) self._need_embeddings = False - if self._index_store: + if self._table_vector_store_connector: self._need_embeddings = True self._rerank = rerank or DefaultRanker(self._top_k) @@ -114,15 +133,8 @@ def _retrieve( List[Chunk]: list of chunks """ if self._need_embeddings: - queries = [query] - candidates = [ - self._index_store.similar_search(query, self._top_k, filters) - for query in queries - ] - return cast(List[Chunk], reduce(lambda x, y: x + y, candidates)) + return self._similarity_search(query, filters) else: - if not self._connector: - raise RuntimeError("RDBMSConnector connection is required.") table_summaries = _parse_db_summary(self._connector) return [Chunk(content=table_summary) for table_summary in table_summaries] @@ -156,30 +168,11 @@ async def _aretrieve( Returns: List[Chunk]: list of chunks """ - if self._need_embeddings: - queries = [query] - candidates = [ - self._similarity_search( - query, filters, root_tracer.get_current_span_id() - ) - for query in queries - ] - result_candidates = await run_async_tasks( - tasks=candidates, concurrency_limit=1 - ) - return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates)) - else: - from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401 - _parse_db_summary, - ) - - table_summaries = await run_async_tasks( - tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())], - concurrency_limit=1, - ) - return [ - Chunk(content=table_summary) for table_summary in table_summaries[0] - ] + return await blocking_func_to_async_no_executor( + func=self._retrieve, + query=query, + filters=filters, + ) async def _aretrieve_with_score( self, @@ -196,34 +189,40 @@ async def _aretrieve_with_score( """ return await self._aretrieve(query, filters) - async def _similarity_search( - self, - query, - filters: Optional[MetadataFilters] = None, - parent_span_id: Optional[str] = None, + def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk: + metadata = table_chunk.metadata + metadata["part"] = "field" + filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()] + field_chunks = self._field_vector_store_connector.similar_search_with_scores( + query, self._top_k, 0, MetadataFilters(filters=filters) + ) + field_contents = [chunk.content for chunk in field_chunks] + table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents) + return table_chunk + + def _similarity_search( + self, query, filters: Optional[MetadataFilters] = None ) -> List[Chunk]: """Similar search.""" - with root_tracer.start_span( - "dbgpt.rag.retriever.db_schema._similarity_search", - parent_span_id, - metadata={"query": query}, - ): - return await blocking_func_to_async_no_executor( - self._index_store.similar_search, query, self._top_k, filters - ) - - async def _aparse_db_summary( - self, parent_span_id: Optional[str] = None - ) -> List[str]: - """Similar search.""" - from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary - - if not self._connector: - raise RuntimeError("RDBMSConnector connection is required.") - with root_tracer.start_span( - "dbgpt.rag.retriever.db_schema._aparse_db_summary", - parent_span_id, - ): - return await blocking_func_to_async_no_executor( - _parse_db_summary, self._connector - ) + table_chunks = self._table_vector_store_connector.similar_search_with_scores( + query, self._top_k, 0, filters + ) + + not_sep_chunks = [ + chunk for chunk in table_chunks if not chunk.metadata.get("separated") + ] + separated_chunks = [ + chunk for chunk in table_chunks if chunk.metadata.get("separated") + ] + if not separated_chunks: + return not_sep_chunks + + # Create tasks list + tasks = [ + lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks + ] + # Run tasks concurrently + separated_result = run_tasks(tasks, concurrency_limit=3) + + # Combine and return results + return not_sep_chunks + separated_result diff --git a/dbgpt/rag/retriever/tests/test_db_struct.py b/dbgpt/rag/retriever/tests/test_db_struct.py index 0b667f69e..f34a4070a 100644 --- a/dbgpt/rag/retriever/tests/test_db_struct.py +++ b/dbgpt/rag/retriever/tests/test_db_struct.py @@ -15,42 +15,53 @@ def mock_db_connection(): @pytest.fixture -def mock_vector_store_connector(): +def mock_table_vector_store_connector(): mock_connector = MagicMock() - mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4 + mock_connector.vector_store_config.name = "table_name" + mock_connector.similar_search_with_scores.return_value = [ + Chunk(content="Table summary") + ] * 4 return mock_connector @pytest.fixture -def db_struct_retriever(mock_db_connection, mock_vector_store_connector): +def mock_field_vector_store_connector(): + mock_connector = MagicMock() + mock_connector.similar_search_with_scores.return_value = [ + Chunk(content="Field summary") + ] * 4 + return mock_connector + + +@pytest.fixture +def dbstruct_retriever( + mock_db_connection, + mock_table_vector_store_connector, + mock_field_vector_store_connector, +): return DBSchemaRetriever( connector=mock_db_connection, - index_store=mock_vector_store_connector, + table_vector_store_connector=mock_table_vector_store_connector, + field_vector_store_connector=mock_field_vector_store_connector, ) -def mock_parse_db_summary(conn) -> List[str]: +def mock_parse_db_summary() -> str: """Patch _parse_db_summary method.""" - return ["Table summary"] + return "Table summary" # Mocking the _parse_db_summary method in your test function @patch.object( dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary ) -def test_retrieve_with_mocked_summary(db_struct_retriever): +def test_retrieve_with_mocked_summary(dbstruct_retriever): query = "Table summary" - chunks: List[Chunk] = db_struct_retriever._retrieve(query) + chunks: List[Chunk] = dbstruct_retriever._retrieve(query) assert isinstance(chunks[0], Chunk) assert chunks[0].content == "Table summary" -@pytest.mark.asyncio -@patch.object( - dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary -) -async def test_aretrieve_with_mocked_summary(db_struct_retriever): - query = "Table summary" - chunks: List[Chunk] = await db_struct_retriever._aretrieve(query) - assert isinstance(chunks[0], Chunk) - assert chunks[0].content == "Table summary" +async def async_mock_parse_db_summary() -> str: + """Asynchronous patch for _parse_db_summary method.""" + return "Table summary" diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index 8ce9a79e6..073c072bc 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -2,13 +2,15 @@ import logging import traceback -from typing import List from dbgpt._private.config import Config from dbgpt.component import SystemApp from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG +from dbgpt.rag import ChunkParameters from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary +from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter +from dbgpt.serve.rag.connector import VectorStoreConnector logger = logging.getLogger(__name__) @@ -47,22 +49,26 @@ def db_summary_embedding(self, dbname, db_type): logger.info("db summary embedding success") - def get_db_summary(self, dbname, query, topk) -> List[str]: + def get_db_summary(self, dbname, query, topk): """Get user query related tables info.""" - from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.storage.vector_store.base import VectorStoreConfig - vector_store_config = VectorStoreConfig(name=dbname + "_profile") - vector_connector = VectorStoreConnector.from_default( + vector_store_name = dbname + "_profile" + table_vector_store_config = VectorStoreConfig(name=vector_store_name) + table_vector_connector = VectorStoreConnector.from_default( CFG.VECTOR_STORE_TYPE, - embedding_fn=self.embeddings, - vector_store_config=vector_store_config, + self.embeddings, + vector_store_config=table_vector_store_config, ) + from dbgpt.rag.retriever.db_schema import DBSchemaRetriever retriever = DBSchemaRetriever( - top_k=topk, index_store=vector_connector.index_client + top_k=topk, + table_vector_store_connector=table_vector_connector, + separator="--table-field-separator--", ) + table_docs = retriever.retrieve(query) ans = [d.content for d in table_docs] return ans @@ -92,18 +98,23 @@ def init_db_profile(self, db_summary_client, dbname): from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.storage.vector_store.base import VectorStoreConfig - vector_store_config = VectorStoreConfig(name=vector_store_name) - vector_connector = VectorStoreConnector.from_default( + table_vector_store_config = VectorStoreConfig(name=vector_store_name) + table_vector_connector = VectorStoreConnector.from_default( CFG.VECTOR_STORE_TYPE, self.embeddings, - vector_store_config=vector_store_config, + vector_store_config=table_vector_store_config, ) - if not vector_connector.vector_name_exists(): + if not table_vector_connector.vector_name_exists(): from dbgpt.rag.assembler.db_schema import DBSchemaAssembler + chunk_parameters = ChunkParameters( + text_splitter=RDBTextSplitter(separator="--table-field-separator--") + ) db_assembler = DBSchemaAssembler.load_from_connection( connector=db_summary_client.db, - index_store=vector_connector.index_client, + table_vector_store_connector=table_vector_connector, + chunk_parameters=chunk_parameters, + max_seq_length=CFG.EMBEDDING_MODEL_MAX_SEQ_LEN, ) if len(db_assembler.get_chunks()) > 0: @@ -115,16 +126,26 @@ def init_db_profile(self, db_summary_client, dbname): def delete_db_profile(self, dbname): """Delete db profile.""" vector_store_name = dbname + "_profile" + table_vector_store_name = dbname + "_profile" + field_vector_store_name = dbname + "_profile_field" from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.storage.vector_store.base import VectorStoreConfig - vector_store_config = VectorStoreConfig(name=vector_store_name) - vector_connector = VectorStoreConnector.from_default( + table_vector_store_config = VectorStoreConfig(name=vector_store_name) + field_vector_store_config = VectorStoreConfig(name=field_vector_store_name) + table_vector_connector = VectorStoreConnector.from_default( CFG.VECTOR_STORE_TYPE, self.embeddings, - vector_store_config=vector_store_config, + vector_store_config=table_vector_store_config, ) - vector_connector.delete_vector_name(vector_store_name) + field_vector_connector = VectorStoreConnector.from_default( + CFG.VECTOR_STORE_TYPE, + self.embeddings, + vector_store_config=field_vector_store_config, + ) + + table_vector_connector.delete_vector_name(table_vector_store_name) + field_vector_connector.delete_vector_name(field_vector_store_name) logger.info(f"delete db profile {dbname} success") @staticmethod diff --git a/dbgpt/rag/summary/rdbms_db_summary.py b/dbgpt/rag/summary/rdbms_db_summary.py index 337d3851b..c2786da51 100644 --- a/dbgpt/rag/summary/rdbms_db_summary.py +++ b/dbgpt/rag/summary/rdbms_db_summary.py @@ -1,6 +1,6 @@ """Summary for rdbms database.""" import re -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from dbgpt._private.config import Config from dbgpt.datasource import BaseConnector @@ -80,6 +80,134 @@ def _parse_db_summary( return table_info_summaries +def _parse_db_summary_with_metadata( + conn: BaseConnector, + summary_template: str = "table_name: {table_name}", + separator: str = "--table-field-separator--", + model_dimension: int = 512, +) -> List[Tuple[str, Dict[str, Any]]]: + """Get db summary for database. + + Args: + conn (BaseConnector): database connection + summary_template (str): summary template + separator(str, optional): separator used to separate table's + basic info and fields. defaults to `-- table-field-separator--` + model_dimension(int, optional): The threshold for splitting field string + """ + tables = conn.get_table_names() + table_info_summaries = [ + _parse_table_summary_with_metadata( + conn, summary_template, separator, table_name, model_dimension + ) + for table_name in tables + ] + return table_info_summaries + + +def _split_columns_str(columns: List[str], model_dimension: int): + """Split columns str. + + Args: + columns (List[str]): fields string + model_dimension (int, optional): The threshold for splitting field string. + """ + result = [] + current_string = "" + current_length = 0 + + for element_str in columns: + element_length = len(element_str) + + # If adding the current element's length would exceed the threshold, + # add the current string to results and reset + if current_length + element_length > model_dimension: + result.append(current_string.strip()) # Remove trailing spaces + current_string = element_str + current_length = element_length + else: + # If current string is empty, add element directly + if current_string: + current_string += "," + element_str + else: + current_string = element_str + current_length += element_length + 1 # Add length of space + + # Handle the last string segment + if current_string: + result.append(current_string.strip()) + + return result + + +def _parse_table_summary_with_metadata( + conn: BaseConnector, + summary_template: str, + separator, + table_name: str, + model_dimension=512, +) -> Tuple[str, Dict[str, Any]]: + """Get table summary for table. + + Args: + conn (BaseConnector): database connection + summary_template (str): summary template + separator(str, optional): separator used to separate table's + basic info and fields. defaults to `-- table-field-separator--` + model_dimension(int, optional): The threshold for splitting field string + + Examples: + metadata: {'table_name': 'asd', 'separated': 0/1} + + table_name: table1 + table_comment: comment + index_keys: keys + --table-field-separator-- + (column1,comment), (column2, comment), (column3, comment) + (column4,comment), (column5, comment), (column6, comment) + """ + columns = [] + metadata = {"table_name": table_name, "separated": 0} + for column in conn.get_columns(table_name): + if column.get("comment"): + columns.append(f"{column['name']} ({column.get('comment')})") + else: + columns.append(f"{column['name']}") + metadata.update({"field_num": len(columns)}) + separated_columns = _split_columns_str(columns, model_dimension=model_dimension) + if len(separated_columns) > 1: + metadata["separated"] = 1 + column_str = "\n".join(separated_columns) + # Obtain index information + index_keys = [] + raw_indexes = conn.get_indexes(table_name) + for index in raw_indexes: + if isinstance(index, tuple): # Process tuple type index information + index_name, index_creation_command = index + # Extract column names using re + matched_columns = re.findall(r"\(([^)]+)\)", index_creation_command) + if matched_columns: + key_str = ", ".join(matched_columns) + index_keys.append(f"{index_name}(`{key_str}`) ") + else: + key_str = ", ".join(index["column_names"]) + index_keys.append(f"{index['name']}(`{key_str}`) ") + table_str = summary_template.format(table_name=table_name) + + try: + comment = conn.get_table_comment(table_name) + except Exception: + comment = dict(text=None) + if comment.get("text"): + table_str += f"\ntable_comment: {comment.get('text')}" + + if len(index_keys) > 0: + index_key_str = ", ".join(index_keys) + table_str += f"\nindex_keys: {index_key_str}" + table_str += f"\n{separator}\n{column_str}" + return table_str, metadata + + def _parse_table_summary( conn: BaseConnector, summary_template: str, table_name: str ) -> str: diff --git a/dbgpt/rag/text_splitter/text_splitter.py b/dbgpt/rag/text_splitter/text_splitter.py index f4374e65e..9f9a882b6 100644 --- a/dbgpt/rag/text_splitter/text_splitter.py +++ b/dbgpt/rag/text_splitter/text_splitter.py @@ -912,3 +912,42 @@ def create_documents( new_doc = Chunk(content=text, metadata=copy.deepcopy(_metadatas[i])) chunks.append(new_doc) return chunks + + +class RDBTextSplitter(TextSplitter): + """Split relational database tables and fields.""" + + def __init__(self, **kwargs): + """Create a new TextSplitter.""" + super().__init__(**kwargs) + + def split_text(self, text: str, **kwargs): + """Split text into a couple of parts.""" + pass + + def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]: + """Split document into chunks.""" + chunks = [] + for doc in documents: + metadata = doc.metadata + content = doc.content + if metadata.get("separated"): + # separate table and field + parts = content.split(self._separator) + table_part, field_part = parts[0], parts[1] + table_metadata, field_metadata = copy.deepcopy(metadata), copy.deepcopy( + metadata + ) + table_metadata["part"] = "table" # identify of table_chunk + field_metadata["part"] = "field" # identify of field_chunk + table_chunk = Chunk(content=table_part, metadata=table_metadata) + chunks.append(table_chunk) + field_parts = field_part.split("\n") + for i, sub_part in enumerate(field_parts): + sub_metadata = copy.deepcopy(field_metadata) + sub_metadata["part_index"] = i + field_chunk = Chunk(content=sub_part, metadata=sub_metadata) + chunks.append(field_chunk) + else: + chunks.append(Chunk(content=content, metadata=metadata)) + return chunks diff --git a/dbgpt/util/chat_util.py b/dbgpt/util/chat_util.py index 490f21a5f..ffb170093 100644 --- a/dbgpt/util/chat_util.py +++ b/dbgpt/util/chat_util.py @@ -1,5 +1,6 @@ import asyncio -from typing import Any, Coroutine, List +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any, Callable, Coroutine, List async def llm_chat_response_nostream(chat_scene: str, **chat_param): @@ -47,13 +48,34 @@ async def _execute_task(task): def run_tasks( - tasks: List[Coroutine], + tasks: List[Callable], + concurrency_limit: int = None, ) -> List[Any]: - """Run a list of async tasks.""" - tasks_to_execute: List[Any] = tasks - - async def _gather() -> List[Any]: - return await asyncio.gather(*tasks_to_execute) - - outputs: List[Any] = asyncio.run(_gather()) - return outputs + """ + Run a list of tasks concurrently using a thread pool. + + Args: + tasks: List of callable functions to execute + concurrency_limit: Maximum number of concurrent threads (optional) + + Returns: + List of results from all tasks in the order they were submitted + """ + max_workers = concurrency_limit if concurrency_limit else None + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all tasks and get futures + futures = [executor.submit(task) for task in tasks] + + # Collect results in order, raising any exceptions + results = [] + for future in futures: + try: + results.append(future.result()) + except Exception as e: + # Cancel any pending futures + for f in futures: + f.cancel() + raise e + + return results diff --git a/docker/examples/sqls/case_3_order_wide_table_sqlite_wide.sql b/docker/examples/sqls/case_3_order_wide_table_sqlite_wide.sql new file mode 100644 index 000000000..094d7f1f0 --- /dev/null +++ b/docker/examples/sqls/case_3_order_wide_table_sqlite_wide.sql @@ -0,0 +1,317 @@ +CREATE TABLE order_wide_table ( + + -- order_base + order_id TEXT, -- 订单ID + order_no TEXT, -- 订单编号 + parent_order_no TEXT, -- 父订单编号 + order_type INTEGER, -- 订单类型:1实物2虚拟3混合 + order_status INTEGER, -- 订单状态 + order_source TEXT, -- 订单来源 + order_source_detail TEXT, -- 订单来源详情 + create_time DATETIME, -- 创建时间 + pay_time DATETIME, -- 支付时间 + finish_time DATETIME, -- 完成时间 + close_time DATETIME, -- 关闭时间 + cancel_time DATETIME, -- 取消时间 + cancel_reason TEXT, -- 取消原因 + order_remark TEXT, -- 订单备注 + seller_remark TEXT, -- 卖家备注 + buyer_remark TEXT, -- 买家备注 + is_deleted INTEGER, -- 是否删除 + delete_time DATETIME, -- 删除时间 + order_ip TEXT, -- 下单IP + order_platform TEXT, -- 下单平台 + order_device TEXT, -- 下单设备 + order_app_version TEXT, -- APP版本号 + + -- order_amount + currency TEXT, -- 货币类型 + exchange_rate REAL, -- 汇率 + original_amount REAL, -- 原始金额 + discount_amount REAL, -- 优惠金额 + coupon_amount REAL, -- 优惠券金额 + points_amount REAL, -- 积分抵扣金额 + shipping_amount REAL, -- 运费 + insurance_amount REAL, -- 保价费 + tax_amount REAL, -- 税费 + tariff_amount REAL, -- 关税 + payment_amount REAL, -- 实付金额 + commission_amount REAL, -- 佣金金额 + platform_fee REAL, -- 平台费用 + seller_income REAL, -- 卖家实收 + payment_currency TEXT, -- 支付货币 + payment_exchange_rate REAL, -- 支付汇率 + + -- user_info + user_id TEXT, -- 用户ID + user_name TEXT, -- 用户名 + user_nickname TEXT, -- 用户昵称 + user_level INTEGER, -- 用户等级 + user_type INTEGER, -- 用户类型 + register_time DATETIME, -- 注册时间 + register_source TEXT, -- 注册来源 + mobile TEXT, -- 手机号 + mobile_area TEXT, -- 手机号区号 + email TEXT, -- 邮箱 + is_vip INTEGER, -- 是否VIP + vip_level INTEGER, -- VIP等级 + vip_expire_time DATETIME, -- VIP过期时间 + user_age INTEGER, -- 用户年龄 + user_gender INTEGER, -- 用户性别 + user_birthday DATE, -- 用户生日 + user_avatar TEXT, -- 用户头像 + user_province TEXT, -- 用户所在省 + user_city TEXT, -- 用户所在市 + user_district TEXT, -- 用户所在区 + last_login_time DATETIME, -- 最后登录时间 + last_login_ip TEXT, -- 最后登录IP + user_credit_score INTEGER, -- 用户信用分 + total_order_count INTEGER, -- 历史订单数 + total_order_amount REAL, -- 历史订单金额 + + -- product_info + product_id TEXT, -- 商品ID + product_code TEXT, -- 商品编码 + product_name TEXT, -- 商品名称 + product_short_name TEXT, -- 商品短名称 + product_type INTEGER, -- 商品类型 + product_status INTEGER, -- 商品状态 + category_id TEXT, -- 类目ID + category_name TEXT, -- 类目名称 + category_path TEXT, -- 类目路径 + brand_id TEXT, -- 品牌ID + brand_name TEXT, -- 品牌名称 + brand_english_name TEXT, -- 品牌英文名 + seller_id TEXT, -- 卖家ID + seller_name TEXT, -- 卖家名称 + seller_type INTEGER, -- 卖家类型 + shop_id TEXT, -- 店铺ID + shop_name TEXT, -- 店铺名称 + product_price REAL, -- 商品价格 + market_price REAL, -- 市场价 + cost_price REAL, -- 成本价 + wholesale_price REAL, -- 批发价 + product_quantity INTEGER, -- 商品数量 + product_unit TEXT, -- 商品单位 + product_weight REAL, -- 商品重量(克) + product_volume REAL, -- 商品体积(cm³) + product_spec TEXT, -- 商品规格 + product_color TEXT, -- 商品颜色 + product_size TEXT, -- 商品尺寸 + product_material TEXT, -- 商品材质 + product_origin TEXT, -- 商品产地 + product_shelf_life INTEGER, -- 保质期(天) + manufacture_date DATE, -- 生产日期 + expiry_date DATE, -- 过期日期 + batch_number TEXT, -- 批次号 + product_barcode TEXT, -- 商品条码 + warehouse_id TEXT, -- 发货仓库ID + warehouse_name TEXT, -- 发货仓库名称 + + -- address_info + receiver_name TEXT, -- 收货人姓名 + receiver_mobile TEXT, -- 收货人手机 + receiver_tel TEXT, -- 收货人电话 + receiver_email TEXT, -- 收货人邮箱 + receiver_country TEXT, -- 国家 + receiver_province TEXT, -- 省份 + receiver_city TEXT, -- 城市 + receiver_district TEXT, -- 区县 + receiver_street TEXT, -- 街道 + receiver_address TEXT, -- 详细地址 + receiver_zip TEXT, -- 邮编 + address_type INTEGER, -- 地址类型 + is_default INTEGER, -- 是否默认地址 + longitude REAL, -- 经度 + latitude REAL, -- 纬度 + address_label TEXT, -- 地址标签 + + -- shipping_info + shipping_type INTEGER, -- 配送方式 + shipping_method TEXT, -- 配送方式名称 + shipping_company TEXT, -- 快递公司 + shipping_company_code TEXT, -- 快递公司编码 + shipping_no TEXT, -- 快递单号 + shipping_time DATETIME, -- 发货时间 + shipping_remark TEXT, -- 发货备注 + expect_receive_time DATETIME, -- 预计送达时间 + receive_time DATETIME, -- 收货时间 + sign_type INTEGER, -- 签收类型 + shipping_status INTEGER, -- 物流状态 + tracking_url TEXT, -- 物流跟踪URL + is_free_shipping INTEGER, -- 是否包邮 + shipping_insurance REAL, -- 运费险金额 + shipping_distance REAL, -- 配送距离 + delivered_time DATETIME, -- 送达时间 + delivery_staff_id TEXT, -- 配送员ID + delivery_staff_name TEXT, -- 配送员姓名 + delivery_staff_mobile TEXT, -- 配送员电话 + + -- payment_info + payment_id TEXT, -- 支付ID + payment_no TEXT, -- 支付单号 + payment_type INTEGER, -- 支付方式 + payment_method TEXT, -- 支付方式名称 + payment_status INTEGER, -- 支付状态 + payment_platform TEXT, -- 支付平台 + transaction_id TEXT, -- 交易流水号 + payment_time DATETIME, -- 支付时间 + payment_account TEXT, -- 支付账号 + payment_bank TEXT, -- 支付银行 + payment_card_type TEXT, -- 支付卡类型 + payment_card_no TEXT, -- 支付卡号 + payment_scene TEXT, -- 支付场景 + payment_client_ip TEXT, -- 支付IP + payment_device TEXT, -- 支付设备 + payment_remark TEXT, -- 支付备注 + payment_voucher TEXT, -- 支付凭证 + + -- promotion_info + promotion_id TEXT, -- 活动ID + promotion_name TEXT, -- 活动名称 + promotion_type INTEGER, -- 活动类型 + promotion_desc TEXT, -- 活动描述 + promotion_start_time DATETIME, -- 活动开始时间 + promotion_end_time DATETIME, -- 活动结束时间 + coupon_id TEXT, -- 优惠券ID + coupon_code TEXT, -- 优惠券码 + coupon_type INTEGER, -- 优惠券类型 + coupon_name TEXT, -- 优惠券名称 + coupon_desc TEXT, -- 优惠券描述 + points_used INTEGER, -- 使用积分 + points_gained INTEGER, -- 获得积分 + points_multiple REAL, -- 积分倍率 + is_first_order INTEGER, -- 是否首单 + is_new_customer INTEGER, -- 是否新客 + marketing_channel TEXT, -- 营销渠道 + marketing_source TEXT, -- 营销来源 + referral_code TEXT, -- 推荐码 + referral_user_id TEXT, -- 推荐人ID + + -- after_sale_info + refund_id TEXT, -- 退款ID + refund_no TEXT, -- 退款单号 + refund_type INTEGER, -- 退款类型 + refund_status INTEGER, -- 退款状态 + refund_reason TEXT, -- 退款原因 + refund_desc TEXT, -- 退款描述 + refund_time DATETIME, -- 退款时间 + refund_amount REAL, -- 退款金额 + return_shipping_no TEXT, -- 退货快递单号 + return_shipping_company TEXT, -- 退货快递公司 + return_shipping_time DATETIME, -- 退货时间 + refund_evidence TEXT, -- 退款凭证 + complaint_id TEXT, -- 投诉ID + complaint_type INTEGER, -- 投诉类型 + complaint_status INTEGER, -- 投诉状态 + complaint_content TEXT, -- 投诉内容 + complaint_time DATETIME, -- 投诉时间 + complaint_handle_time DATETIME, -- 投诉处理时间 + complaint_handle_result TEXT, -- 投诉处理结果 + evaluation_score INTEGER, -- 评价分数 + evaluation_content TEXT, -- 评价内容 + evaluation_time DATETIME, -- 评价时间 + evaluation_reply TEXT, -- 评价回复 + evaluation_reply_time DATETIME, -- 评价回复时间 + evaluation_images TEXT, -- 评价图片 + evaluation_videos TEXT, -- 评价视频 + is_anonymous INTEGER, -- 是否匿名评价 + + -- invoice_info + invoice_type INTEGER, -- 发票类型 + invoice_title TEXT, -- 发票抬头 + invoice_content TEXT, -- 发票内容 + tax_no TEXT, -- 税号 + invoice_amount REAL, -- 发票金额 + invoice_status INTEGER, -- 发票状态 + invoice_time DATETIME, -- 开票时间 + invoice_number TEXT, -- 发票号码 + invoice_code TEXT, -- 发票代码 + company_name TEXT, -- 单位名称 + company_address TEXT, -- 单位地址 + company_tel TEXT, -- 单位电话 + company_bank TEXT, -- 开户银行 + company_account TEXT, -- 银行账号 + + -- delivery_time_info + expect_delivery_time DATETIME, -- 期望配送时间 + delivery_period_type INTEGER, -- 配送时段类型 + delivery_period_start TEXT, -- 配送时段开始 + delivery_period_end TEXT, -- 配送时段结束 + delivery_priority INTEGER, -- 配送优先级 + + -- tag_info + order_tags TEXT, -- 订单标签 + user_tags TEXT, -- 用户标签 + product_tags TEXT, -- 商品标签 + risk_level INTEGER, -- 风险等级 + risk_tags TEXT, -- 风险标签 + business_tags TEXT, -- 业务标签 + + -- commercial_info + gross_profit REAL, -- 毛利 + gross_profit_rate REAL, -- 毛利率 + settlement_amount REAL, -- 结算金额 + settlement_time DATETIME, -- 结算时间 + settlement_cycle INTEGER, -- 结算周期 + settlement_status INTEGER, -- 结算状态 + commission_rate REAL, -- 佣金比例 + platform_service_fee REAL, -- 平台服务费 + ad_cost REAL, -- 广告费用 + promotion_cost REAL -- 推广费用 +); + +-- 插入示例数据 +INSERT INTO order_wide_table ( + -- 基础订单信息 + order_id, order_no, order_type, order_status, create_time, order_source, + -- 订单金额 + original_amount, payment_amount, shipping_amount, + -- 用户信息 + user_id, user_name, user_level, mobile, + -- 商品信息 + product_id, product_name, product_quantity, product_price, + -- 收货信息 + receiver_name, receiver_mobile, receiver_address, + -- 物流信息 + shipping_no, shipping_status, + -- 支付信息 + payment_type, payment_status, + -- 营销信息 + promotion_id, coupon_amount, + -- 发票信息 + invoice_type, invoice_title +) VALUES +( + 'ORD20240101001', 'NO20240101001', 1, 2, '2024-01-01 10:00:00', 'APP', + 199.99, 188.88, 10.00, + 'USER001', '张三', 2, '13800138000', + 'PRD001', 'iPhone 15 手机壳', 2, 89.99, + '李四', '13900139000', '北京市朝阳区XX路XX号', + 'SF123456789', 1, + 1, 1, + 'PROM001', 20.00, + 1, '个人' +), +( + 'ORD20240101002', 'NO20240101002', 1, 1, '2024-01-01 11:00:00', 'H5', + 299.99, 279.99, 0.00, + 'USER002', '王五', 3, '13700137000', + 'PRD002', 'AirPods Pro 保护套', 1, 299.99, + '赵六', '13600136000', '上海市浦东新区XX路XX号', + 'YT987654321', 2, + 2, 2, + 'PROM002', 10.00, + 2, '上海科技有限公司' +), +( + 'ORD20240101003', 'NO20240101003', 2, 3, '2024-01-01 12:00:00', 'WEB', + 1999.99, 1899.99, 0.00, + 'USER003', '陈七', 4, '13500135000', + 'PRD003', 'MacBook Pro 电脑包', 1, 1999.99, + '孙八', '13400134000', '广州市天河区XX路XX号', + 'JD123123123', 3, + 3, 1, + 'PROM003', 100.00, + 1, '个人' +); diff --git a/examples/rag/db_schema_rag_example.py b/examples/rag/db_schema_rag_example.py index 1524634fa..7cfbf62d8 100644 --- a/examples/rag/db_schema_rag_example.py +++ b/examples/rag/db_schema_rag_example.py @@ -4,7 +4,8 @@ from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.rag.assembler import DBSchemaAssembler from dbgpt.rag.embedding import DefaultEmbeddingFactory -from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig +from dbgpt.serve.rag.connector import VectorStoreConnector +from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig """DB struct rag example. pre-requirements: @@ -12,7 +13,7 @@ ``` embedding_model_path = "{your_embedding_model_path}" ``` - + Examples: ..code-block:: shell python examples/rag/db_schema_rag_example.py @@ -45,27 +46,26 @@ def _create_temporary_connection(): def _create_vector_connector(): """Create vector connector.""" - config = ChromaVectorConfig( - persist_path=PILOT_PATH, - name="dbschema_rag_test", + return VectorStoreConnector.from_default( + "Chroma", + vector_store_config=ChromaVectorConfig( + name="db_schema_vector_store_name", + persist_path=os.path.join(PILOT_PATH, "data"), + ), embedding_fn=DefaultEmbeddingFactory( default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), ).create(), ) - return ChromaStore(config) - if __name__ == "__main__": connection = _create_temporary_connection() - index_store = _create_vector_connector() + vector_connector = _create_vector_connector() assembler = DBSchemaAssembler.load_from_connection( - connector=connection, - index_store=index_store, + connector=connection, table_vector_store_connector=vector_connector ) assembler.persist() # get db schema retriever retriever = assembler.as_retriever(top_k=1) chunks = retriever.retrieve("show columns from user") print(f"db schema rag example results:{[chunk.content for chunk in chunks]}") - index_store.delete_vector_name("dbschema_rag_test") diff --git a/scripts/examples/load_examples.sh b/scripts/examples/load_examples.sh index 0a829bdac..b01dedba3 100755 --- a/scripts/examples/load_examples.sh +++ b/scripts/examples/load_examples.sh @@ -15,6 +15,7 @@ fi DEFAULT_DB_FILE="DB-GPT/pilot/data/default_sqlite.db" DEFAULT_SQL_FILE="DB-GPT/docker/examples/sqls/*_sqlite.sql" DB_FILE="$WORK_DIR/pilot/data/default_sqlite.db" +WIDE_DB_FILE="$WORK_DIR/pilot/data/wide_sqlite.db" SQL_FILE="" usage () { @@ -61,6 +62,12 @@ if [ -n $SQL_FILE ];then sqlite3 $DB_FILE < "$file" done + for file in $WORK_DIR/docker/examples/sqls/*_sqlite_wide.sql + do + echo "execute sql file: $file" + sqlite3 $WIDE_DB_FILE < "$file" + done + else echo "Execute SQL file ${SQL_FILE}" sqlite3 $DB_FILE < $SQL_FILE