Skip to content

Commit

Permalink
feat: Added support for TuGraph graph database (eosphoros-ai#1451)
Browse files Browse the repository at this point in the history
Co-authored-by: aries_ckt <[email protected]>
  • Loading branch information
KingSkyLi and Aries-ckt authored Apr 26, 2024
1 parent 98ebfdc commit a5666b3
Show file tree
Hide file tree
Showing 30 changed files with 379 additions and 38 deletions.
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ ignore_missing_imports = True
[mypy-fastchat.protocol.api_protocol]
ignore_missing_imports = True

[mypy-neo4j.*]
ignore_missing_imports = True

# Agent
[mypy-seaborn.*]
ignore_missing_imports = True
Expand Down
18 changes: 12 additions & 6 deletions dbgpt/app/scene/chat_db/professional_qa/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,18 @@ def __init__(self, chat_param: Dict):
if self.db_name:
self.database = CFG.local_db_manager.get_connector(self.db_name)
self.tables = self.database.get_table_names()

self.top_k = (
CFG.KNOWLEDGE_SEARCH_TOP_SIZE
if len(self.tables) > CFG.KNOWLEDGE_SEARCH_TOP_SIZE
else len(self.tables)
)
if self.database.is_graph_type():
# When the current graph database retrieves source data from ChatDB, the topk uses the sum of node table and edge table.
self.top_k = len(self.tables["vertex_tables"]) + len(
self.tables["edge_tables"]
)
else:
print(self.database.db_type)
self.top_k = (
CFG.KNOWLEDGE_SEARCH_TOP_SIZE
if len(self.tables) > CFG.KNOWLEDGE_SEARCH_TOP_SIZE
else len(self.tables)
)

@trace()
async def generate_input_values(self) -> Dict:
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/static/404.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dbgpt/app/static/404/index.html

Large diffs are not rendered by default.

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dbgpt/app/static/agent/index.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dbgpt/app/static/app/index.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dbgpt/app/static/chat/index.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dbgpt/app/static/database/index.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dbgpt/app/static/flow/canvas/index.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dbgpt/app/static/flow/index.html

Large diffs are not rendered by default.

Binary file added dbgpt/app/static/icons/tugraph.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion dbgpt/app/static/index.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dbgpt/app/static/knowledge/chunk/index.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dbgpt/app/static/knowledge/index.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dbgpt/app/static/models/index.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dbgpt/app/static/prompt/index.html

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions dbgpt/datasource/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,8 @@ def get_indexes(self, table_name: str) -> List[Dict]:
def is_normal_type(cls) -> bool:
"""Return whether the connector is a normal type."""
return True

@classmethod
def is_graph_type(cls) -> bool:
"""Return whether the connector is a graph database connector."""
return False
127 changes: 127 additions & 0 deletions dbgpt/datasource/conn_tugraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""TuGraph Connector."""
import json
from typing import Any, Dict, List, cast

from .base import BaseConnector


class TuGraphConnector(BaseConnector):
"""TuGraph connector."""

db_type: str = "tugraph"
driver: str = "bolt"
dialect: str = "tugraph"

def __init__(self, session):
"""Initialize the connector with a Neo4j driver."""
self._session = session
self._schema = None

@classmethod
def from_uri_db(
cls, host: str, port: int, user: str, pwd: str, db_name: str, **kwargs: Any
) -> "TuGraphConnector":
"""Create a new TuGraphConnector from host, port, user, pwd, db_name."""
try:
from neo4j import GraphDatabase

db_url = f"{cls.driver}://{host}:{str(port)}"
with GraphDatabase.driver(db_url, auth=(user, pwd)) as client:
client.verify_connectivity()
session = client.session(database=db_name)
return cast(TuGraphConnector, cls(session=session))
except ImportError as err:
raise ImportError("requests package is not installed") from err

def get_table_names(self) -> Dict[str, List[str]]:
"""Get all table names from the TuGraph database using the Neo4j driver."""
# Run the query to get vertex labels
v_result = self._session.run("CALL db.vertexLabels()").data()
v_data = [table_name["label"] for table_name in v_result]

# Run the query to get edge labels
e_result = self._session.run("CALL db.edgeLabels()").data()
e_data = [table_name["label"] for table_name in e_result]
return {"vertex_tables": v_data, "edge_tables": e_data}

def get_grants(self):
"""Get grants."""
return []

def get_collation(self):
"""Get collation."""
return "UTF-8"

def get_charset(self):
"""Get character_set of current database."""
return "UTF-8"

def table_simple_info(self):
"""Get table simple info."""
return []

def close(self):
"""Close the Neo4j driver."""
self._session.close()

def run(self):
"""Run GQL."""
return []

def get_columns(self, table_name: str, table_type: str = "vertex") -> List[Dict]:
"""Get fields about specified graph.
Args:
table_name (str): table name (graph name)
table_type (str): table type (vertex or edge)
Returns:
columns: List[Dict], which contains name: str, type: str,
default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'int', 'default_expression': '',
'is_in_primary_key': True, 'comment': 'id'}, ...]
"""
data = []
result = None
if table_type == "vertex":
result = self._session.run(
f"CALL db.getVertexSchema('{table_name}')"
).data()
else:
result = self._session.run(f"CALL db.getEdgeSchema('{table_name}')").data()
schema_info = json.loads(result[0]["schema"])
for prop in schema_info.get("properties", []):
prop_dict = {
"name": prop["name"],
"type": prop["type"],
"default_expression": "",
"is_in_primary_key": bool(
"primary" in schema_info and prop["name"] == schema_info["primary"]
),
"comment": prop["name"],
}
data.append(prop_dict)
return data

def get_indexes(self, table_name: str, table_type: str = "vertex") -> List[Dict]:
"""Get table indexes about specified table.
Args:
table_name:(str) table name
table_type:(str)'vertex' | 'edge'
Returns:
List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}]
"""
# [{'name':'id','column_names':['id']}]
result = self._session.run(
f"CALL db.listLabelIndexes('{table_name}','{table_type}')"
).data()
transformed_data = []
for item in result:
new_dict = {"name": item["field"], "column_names": [item["field"]]}
transformed_data.append(new_dict)
return transformed_data

@classmethod
def is_graph_type(cls) -> bool:
"""Return whether the connector is a graph database connector."""
return True
1 change: 1 addition & 0 deletions dbgpt/datasource/manages/connector_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def on_init(self):
Load all connector classes.
"""
from dbgpt.datasource.conn_spark import SparkConnector # noqa: F401
from dbgpt.datasource.conn_tugraph import TuGraphConnector # noqa: F401
from dbgpt.datasource.rdbms.base import RDBMSConnector # noqa: F401
from dbgpt.datasource.rdbms.conn_clickhouse import ( # noqa: F401
ClickhouseConnector,
Expand Down
7 changes: 6 additions & 1 deletion dbgpt/rag/knowledge/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dbgpt.core import Document
from dbgpt.datasource import BaseConnector

from ..summary.gdbms_db_summary import _parse_db_summary as _parse_gdb_summary
from ..summary.rdbms_db_summary import _parse_db_summary
from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType

Expand Down Expand Up @@ -34,7 +35,11 @@ def __init__(
def _load(self) -> List[Document]:
"""Load datasource document from data_loader."""
docs = []
for table_summary in _parse_db_summary(self._connector, self._summary_template):
if self._connector.is_graph_type():
db_summary = _parse_gdb_summary(self._connector, self._summary_template)
else:
db_summary = _parse_db_summary(self._connector, self._summary_template)
for table_summary in db_summary:
metadata = {"source": "database"}
if self._metadata:
metadata.update(self._metadata) # type: ignore
Expand Down
17 changes: 16 additions & 1 deletion dbgpt/rag/summary/db_summary_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dbgpt._private.config import Config
from dbgpt.component import SystemApp
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary
from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -39,7 +40,7 @@ def __init__(self, system_app: SystemApp):

def db_summary_embedding(self, dbname, db_type):
"""Put db profile and table profile summary into vector store."""
db_summary_client = RdbmsSummary(dbname, db_type)
db_summary_client = self.create_summary_client(dbname, db_type)

self.init_db_profile(db_summary_client, dbname)

Expand Down Expand Up @@ -122,3 +123,17 @@ def delete_db_profile(self, dbname):
)
vector_connector.delete_vector_name(vector_store_name)
logger.info(f"delete db profile {dbname} success")

@staticmethod
def create_summary_client(dbname: str, db_type: str):
"""
Create a summary client based on the database type.
Args:
dbname (str): The name of the database.
db_type (str): The type of the database.
"""
if "graph" in db_type:
return GdbmsSummary(dbname, db_type)
else:
return RdbmsSummary(dbname, db_type)
134 changes: 134 additions & 0 deletions dbgpt/rag/summary/gdbms_db_summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Summary for rdbms database."""

from typing import TYPE_CHECKING, Dict, List, Optional

from dbgpt._private.config import Config
from dbgpt.datasource import BaseConnector
from dbgpt.datasource.conn_tugraph import TuGraphConnector
from dbgpt.rag.summary.db_summary import DBSummary

if TYPE_CHECKING:
from dbgpt.datasource.manages import ConnectorManager

CFG = Config()


class GdbmsSummary(DBSummary):
"""Get graph db table summary template."""

def __init__(
self, name: str, type: str, manager: Optional["ConnectorManager"] = None
):
"""Create a new RdbmsSummary."""
self.name = name
self.type = type
self.summary_template = "{table_name}({columns})"
# self.v_summary_template = "{table_name}({columns})"
self.tables = {}
# self.tables_info = []
# self.vector_tables_info = []

# TODO: Don't use the global variable.
db_manager = manager or CFG.local_db_manager
if not db_manager:
raise ValueError("Local db manage is not initialized.")
self.db = db_manager.get_connector(name)

self.metadata = """user info :{users}, grant info:{grant}, charset:{charset},
collation:{collation}""".format(
users=self.db.get_users(),
grant=self.db.get_grants(),
charset=self.db.get_charset(),
collation=self.db.get_collation(),
)
tables = self.db.get_table_names()
self.table_info_summaries = {
"vertex_tables": [
self.get_table_summary(table_name, "vertex")
for table_name in tables["vertex_tables"]
],
"edge_tables": [
self.get_table_summary(table_name, "edge")
for table_name in tables["edge_tables"]
],
}

def get_table_summary(self, table_name, table_type):
"""Get table summary for table.
example:
table_name(column1(column1 comment),column2(column2 comment),
column3(column3 comment) and index keys, and table comment: {table_comment})
"""
return _parse_table_summary(
self.db, self.summary_template, table_name, table_type
)

def table_summaries(self):
"""Get table summaries."""
return self.table_info_summaries


def _parse_db_summary(
conn: BaseConnector, summary_template: str = "{table_name}({columns})"
) -> List[str]:
"""Get db summary for database."""
table_info_summaries = None
if isinstance(conn, TuGraphConnector):
table_names = conn.get_table_names()
v_tables = table_names.get("vertex_tables", [])
e_tables = table_names.get("edge_tables", [])
table_info_summaries = [
_parse_table_summary(conn, summary_template, table_name, "vertex")
for table_name in v_tables
] + [
_parse_table_summary(conn, summary_template, table_name, "edge")
for table_name in e_tables
]
else:
table_info_summaries = []

return table_info_summaries


def _format_column(column: Dict) -> str:
"""Format a single column's summary."""
comment = column.get("comment", "")
if column.get("is_in_primary_key"):
comment += " Primary Key" if comment else "Primary Key"
return f"{column['name']} ({comment})" if comment else column["name"]


def _format_indexes(indexes: List[Dict]) -> str:
"""Format index keys for table summary."""
return ", ".join(
f"{index['name']}(`{', '.join(index['column_names'])}`)" for index in indexes
)


def _parse_table_summary(
conn: TuGraphConnector, summary_template: str, table_name: str, table_type: str
) -> str:
"""Enhanced table summary function."""
columns = [
_format_column(column) for column in conn.get_columns(table_name, table_type)
]
column_str = ", ".join(columns)

indexes = conn.get_indexes(table_name, table_type)
index_str = _format_indexes(indexes) if indexes else ""

table_str = summary_template.format(table_name=table_name, columns=column_str)
if index_str:
table_str += f", and index keys: {index_str}"
try:
comment = conn.get_table_comment(table_name)
except Exception:
comment = dict(text=None)
if comment.get("text"):
table_str += (
f", and table comment: {comment.get('text')}, this is a {table_type} table"
)
else:
table_str += f", and table comment: this is a {table_type} table"
return table_str
1 change: 1 addition & 0 deletions dbgpt/storage/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class DBType(Enum):
Spark = DbInfo("spark", True)
Doris = DbInfo("doris")
Hive = DbInfo("hive")
TuGraph = DbInfo("tugraph")

def value(self) -> str:
"""Return the name of the database type."""
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def all_datasource_requires():
"pyhive",
"thrift",
"thrift_sasl",
"neo4j",
]


Expand Down
Loading

0 comments on commit a5666b3

Please sign in to comment.