-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: dongzhancai1 <[email protected]> Co-authored-by: dong <[email protected]>
- Loading branch information
1 parent
7f4b5e7
commit 9b0161e
Showing
17 changed files
with
948 additions
and
243 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,76 +1,117 @@ | ||
from unittest.mock import MagicMock | ||
from typing import List | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
|
||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector | ||
from dbgpt.rag.assembler.embedding import EmbeddingAssembler | ||
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType | ||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory | ||
from dbgpt.rag.knowledge.base import Knowledge | ||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter | ||
from dbgpt.storage.vector_store.chroma_store import ChromaStore | ||
import dbgpt | ||
from dbgpt.core import Chunk | ||
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever | ||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary | ||
|
||
|
||
@pytest.fixture | ||
def mock_db_connection(): | ||
"""Create a temporary database connection for testing.""" | ||
connect = SQLiteTempConnector.create_temporary_db() | ||
connect.create_temp_tables( | ||
{ | ||
"user": { | ||
"columns": { | ||
"id": "INTEGER PRIMARY KEY", | ||
"name": "TEXT", | ||
"age": "INTEGER", | ||
}, | ||
"data": [ | ||
(1, "Tom", 10), | ||
(2, "Jerry", 16), | ||
(3, "Jack", 18), | ||
(4, "Alice", 20), | ||
(5, "Bob", 22), | ||
], | ||
} | ||
} | ||
) | ||
return connect | ||
return MagicMock() | ||
|
||
|
||
@pytest.fixture | ||
def mock_chunk_parameters(): | ||
return MagicMock(spec=ChunkParameters) | ||
def mock_table_vector_store_connector(): | ||
mock_connector = MagicMock() | ||
mock_connector.vector_store_config.name = "table_name" | ||
chunk = Chunk( | ||
content="table_name: user\ncomment: user about dbgpt", | ||
metadata={ | ||
"field_num": 6, | ||
"part": "table", | ||
"separated": 1, | ||
"table_name": "user", | ||
}, | ||
) | ||
mock_connector.similar_search_with_scores = MagicMock(return_value=[chunk]) | ||
return mock_connector | ||
|
||
|
||
@pytest.fixture | ||
def mock_embedding_factory(): | ||
return MagicMock(spec=EmbeddingFactory) | ||
def mock_field_vector_store_connector(): | ||
mock_connector = MagicMock() | ||
chunk1 = Chunk( | ||
content="name,age", | ||
metadata={ | ||
"field_num": 6, | ||
"part": "field", | ||
"part_index": 0, | ||
"separated": 1, | ||
"table_name": "user", | ||
}, | ||
) | ||
chunk2 = Chunk( | ||
content="address,gender", | ||
metadata={ | ||
"field_num": 6, | ||
"part": "field", | ||
"part_index": 1, | ||
"separated": 1, | ||
"table_name": "user", | ||
}, | ||
) | ||
chunk3 = Chunk( | ||
content="mail,phone", | ||
metadata={ | ||
"field_num": 6, | ||
"part": "field", | ||
"part_index": 2, | ||
"separated": 1, | ||
"table_name": "user", | ||
}, | ||
) | ||
mock_connector.similar_search_with_scores = MagicMock( | ||
return_value=[chunk1, chunk2, chunk3] | ||
) | ||
return mock_connector | ||
|
||
|
||
@pytest.fixture | ||
def mock_vector_store_connector(): | ||
return MagicMock(spec=ChromaStore) | ||
def dbstruct_retriever( | ||
mock_db_connection, | ||
mock_table_vector_store_connector, | ||
mock_field_vector_store_connector, | ||
): | ||
return DBSchemaRetriever( | ||
connector=mock_db_connection, | ||
table_vector_store_connector=mock_table_vector_store_connector, | ||
field_vector_store_connector=mock_field_vector_store_connector, | ||
separator="--table-field-separator--", | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def mock_knowledge(): | ||
return MagicMock(spec=Knowledge) | ||
def mock_parse_db_summary() -> str: | ||
"""Patch _parse_db_summary method.""" | ||
return ( | ||
"table_name: user\ncomment: user about dbgpt\n" | ||
"--table-field-separator--\n" | ||
"name,age\naddress,gender\nmail,phone" | ||
) | ||
|
||
|
||
def test_load_knowledge( | ||
mock_db_connection, | ||
mock_knowledge, | ||
mock_chunk_parameters, | ||
mock_embedding_factory, | ||
mock_vector_store_connector, | ||
): | ||
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE" | ||
mock_chunk_parameters.text_splitter = CharacterTextSplitter() | ||
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE | ||
assembler = EmbeddingAssembler( | ||
knowledge=mock_knowledge, | ||
chunk_parameters=mock_chunk_parameters, | ||
embeddings=mock_embedding_factory.create(), | ||
index_store=mock_vector_store_connector, | ||
# Mocking the _parse_db_summary method in your test function | ||
@patch.object( | ||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary | ||
) | ||
def test_retrieve_with_mocked_summary(dbstruct_retriever): | ||
query = "Table summary" | ||
chunks: List[Chunk] = dbstruct_retriever._retrieve(query) | ||
assert isinstance(chunks[0], Chunk) | ||
assert chunks[0].content == ( | ||
"table_name: user\ncomment: user about dbgpt\n" | ||
"--table-field-separator--\n" | ||
"name,age\naddress,gender\nmail,phone" | ||
) | ||
|
||
|
||
def async_mock_parse_db_summary() -> str: | ||
"""Asynchronous patch for _parse_db_summary method.""" | ||
return ( | ||
"table_name: user\ncomment: user about dbgpt\n" | ||
"--table-field-separator--\n" | ||
"name,age\naddress,gender\nmail,phone" | ||
) | ||
assembler.load_knowledge(knowledge=mock_knowledge) | ||
assert len(assembler._chunks) == 0 |
Oops, something went wrong.