diff --git a/dbgpt/app/scene/base_chat.py b/dbgpt/app/scene/base_chat.py index d1fc27334..774e915d4 100644 --- a/dbgpt/app/scene/base_chat.py +++ b/dbgpt/app/scene/base_chat.py @@ -513,7 +513,7 @@ def _generate_numbered_list(self) -> str: }, # {"response_data_text":" the default display method, suitable for single-line or simple content display"}, { - "response_scatter_plot": "Suitable for exploring relationships between variables, detecting outliers, etc." + "response_scatter_chart": "Suitable for exploring relationships between variables, detecting outliers, etc." }, { "response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc." @@ -527,6 +527,9 @@ def _generate_numbered_list(self) -> str: { "response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc." }, + { + "response_vector_chart": "Suitable for projecting high-dimensional vector data onto a two-dimensional plot through the PCA algorithm." + }, ] return "\n".join( diff --git a/dbgpt/app/scene/chat_db/auto_execute/out_parser.py b/dbgpt/app/scene/chat_db/auto_execute/out_parser.py index 3dcae4554..557ea71d3 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/out_parser.py +++ b/dbgpt/app/scene/chat_db/auto_execute/out_parser.py @@ -3,6 +3,8 @@ import xml.etree.ElementTree as ET from typing import Dict, NamedTuple +import numpy as np +import pandas as pd import sqlparse from dbgpt._private.config import Config @@ -68,6 +70,52 @@ def parse_prompt_response(self, model_out_text): logger.error(f"json load failed:{clean_str}") return SqlAction("", clean_str, "", "") + def parse_vector_data_with_pca(self, df): + try: + from sklearn.decomposition import PCA + except ImportError: + raise ImportError( + "Could not import scikit-learn package. " + "Please install it with `pip install scikit-learn`." 
+ ) + + nrow, ncol = df.shape + if nrow == 0 or ncol == 0: + return df, False + + vec_col = -1 + for i_col in range(ncol): + if isinstance(df.iloc[:, i_col][0], list): + vec_col = i_col + break + elif isinstance(df.iloc[:, i_col][0], bytes): + sample = df.iloc[:, i_col][0] + if isinstance(json.loads(sample.decode()), list): + vec_col = i_col + break + if vec_col == -1: + return df, False + vec_dim = len(json.loads(df.iloc[:, vec_col][0].decode())) + if min(nrow, vec_dim) < 2: + return df, False + df.iloc[:, vec_col] = df.iloc[:, vec_col].apply( + lambda x: json.loads(x.decode()) + ) + X = np.array(df.iloc[:, vec_col].tolist()) + + pca = PCA(n_components=2) + X_pca = pca.fit_transform(X) + + new_df = pd.DataFrame() + for i_col in range(ncol): + if i_col == vec_col: + continue + col_name = df.columns[i_col] + new_df[col_name] = df[col_name] + new_df["__x"] = [pos[0] for pos in X_pca] + new_df["__y"] = [pos[1] for pos in X_pca] + return new_df, True + def parse_view_response(self, speak, data, prompt_response) -> str: param = {} api_call_element = ET.Element("chart-view") @@ -83,6 +131,11 @@ def parse_view_response(self, speak, data, prompt_response) -> str: if prompt_response.sql: df = data(prompt_response.sql) param["type"] = prompt_response.display + + if param["type"] == "response_vector_chart": + df, visualizable = self.parse_vector_data_with_pca(df) + param["type"] = "response_scatter_chart" if visualizable else "response_table" + param["sql"] = prompt_response.sql param["data"] = json.loads( df.to_json(orient="records", date_format="iso", date_unit="s") diff --git a/dbgpt/datasource/rdbms/dialect/oceanbase/ob_dialect.py b/dbgpt/datasource/rdbms/dialect/oceanbase/ob_dialect.py index a4d920351..d6cc5dd08 100644 --- a/dbgpt/datasource/rdbms/dialect/oceanbase/ob_dialect.py +++ b/dbgpt/datasource/rdbms/dialect/oceanbase/ob_dialect.py @@ -1,12 +1,111 @@ """OB Dialect support.""" +import re + +from sqlalchemy import util from sqlalchemy.dialects import registry from 
sqlalchemy.dialects.mysql import pymysql +from sqlalchemy.dialects.mysql.reflection import MySQLTableDefinitionParser, _re_compile + + +class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser): + """OceanBase table definition parser.""" + + def __init__(self, dialect, preparer, *, default_schema=None): + """Initialize OceanBaseTableDefinitionParser.""" + MySQLTableDefinitionParser.__init__(self, dialect, preparer) + self.default_schema = default_schema + + def _prep_regexes(self): + super()._prep_regexes() + + _final = self.preparer.final_quote + quotes = dict( + zip( + ("iq", "fq", "esc_fq"), + [ + re.escape(s) + for s in ( + self.preparer.initial_quote, + _final, + self.preparer._escape_identifier(_final), + ) + ], + ) + ) + + self._re_key = _re_compile( + r" " + r"(?:(SPATIAL|VECTOR|(?P<type>\S+)) )?KEY" + # r"(?:(?P<type>\S+) )?KEY" + r"(?: +{iq}(?P<name>(?:{esc_fq}|[^{fq}])+){fq})?" + r"(?: +USING +(?P<using_pre>\S+))?" + r" +\((?P<columns>.+?)\)" + r"(?: +USING +(?P<using_post>\S+))?" + r"(?: +(KEY_)?BLOCK_SIZE *[ =]? *(?P<keyblock>\S+) *(LOCAL)?)?" + r"(?: +WITH PARSER +(?P<parser>\S+))?" + r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?" + r"(?: +/\*(?P<version_sql>.+)\*/ *)?" + r",?$".format(iq=quotes["iq"], esc_fq=quotes["esc_fq"], fq=quotes["fq"]) + ) + + kw = quotes.copy() + kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION" + self._re_fk_constraint = _re_compile( + r" " + r"CONSTRAINT +" + r"{iq}(?P<name>(?:{esc_fq}|[^{fq}])+){fq} +" + r"FOREIGN KEY +" + r"\((?P<local>[^\)]+?)\) REFERENCES +" + r"(?P<table>{iq}[^{fq}]+{fq}" + r"(?:\.{iq}[^{fq}]+{fq})?) *" + r"\((?P<foreign>(?:{iq}[^{fq}]+{fq}(?: *, *)?)+)\)" + r"(?: +(?P<match>MATCH \w+))?" + r"(?: +ON UPDATE (?P<onupdate>{on}))?" 
+ r"(?: +ON DELETE (?P<ondelete>{on}))?".format( + iq=quotes["iq"], esc_fq=quotes["esc_fq"], fq=quotes["fq"], on=kw["on"] + ) + ) + + def _parse_constraints(self, line): + """Parse a CONSTRAINT line.""" + ret = super()._parse_constraints(line) + if ret: + tp, spec = ret + if tp == "partition": + # do not handle partition + return ret + # logger.info(f"{tp} {spec}") + if ( + tp == "fk_constraint" + and len(spec["table"]) == 2 + and spec["table"][0] == self.default_schema + ): + spec["table"] = spec["table"][1:] + if spec.get("onupdate", "").lower() == "restrict": + spec["onupdate"] = None + if spec.get("ondelete", "").lower() == "restrict": + spec["ondelete"] = None + return ret class OBDialect(pymysql.MySQLDialect_pymysql): """OBDialect expend.""" + supports_statement_cache = True + + def __init__(self, **kwargs): + """Initialize OBDialect.""" + try: + from pyobvector import VECTOR # type: ignore + except ImportError: + raise ImportError( + "Could not import pyobvector package. " + "Please install it with `pip install pyobvector`." + ) + super().__init__(**kwargs) + self.ischema_names["VECTOR"] = VECTOR + def initialize(self, connection): """Ob dialect initialize.""" super(OBDialect, self).initialize(connection) @@ -22,5 +121,18 @@ def get_isolation_level(self, dbapi_connection): self.server_version_info = (5, 7, 19) return super(OBDialect, self).get_isolation_level(dbapi_connection) + @util.memoized_property + def _tabledef_parser(self): + """Return the MySQLTableDefinitionParser, generate if needed. + + The deferred creation ensures that the dialect has + retrieved server version information first. 
+ """ + preparer = self.identifier_preparer + default_schema = self.default_schema_name + return OceanBaseTableDefinitionParser( + self, preparer, default_schema=default_schema + ) + registry.register("mysql.ob", __name__, "OBDialect") diff --git a/dbgpt/vis/tags/vis_chart.py b/dbgpt/vis/tags/vis_chart.py index 3c62e2cb4..54402bdb8 100644 --- a/dbgpt/vis/tags/vis_chart.py +++ b/dbgpt/vis/tags/vis_chart.py @@ -24,7 +24,7 @@ def default_chart_type_prompt() -> str: "non-numeric columns" }, { - "response_scatter_plot": "Suitable for exploring relationships between " + "response_scatter_chart": "Suitable for exploring relationships between " "variables, detecting outliers, etc." }, {