diff --git a/pyproject.toml b/pyproject.toml index 61637e3..777313c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "sqlgpt-parser" -version = "0.0.1a4" +version = "0.0.1a5" authors = [ { name="luliwjc", email="chenxiaoxi_wjc@163.com" }, { name="Ifffff", email="tingkai.ztk@antgroup.com" }, diff --git a/sqlgpt_parser/format/formatter.py b/sqlgpt_parser/format/formatter.py index 84a5c71..c201602 100644 --- a/sqlgpt_parser/format/formatter.py +++ b/sqlgpt_parser/format/formatter.py @@ -424,10 +424,13 @@ def visit_list_expression(self, node, unmangle_names): def visit_window_func(self, node, unmangle_names): args = ", ".join([self.process(arg, unmangle_names) for arg in node.func_args]) ignore_null = f" {node.ignore_null} NULLS" if node.ignore_null else "" - window_spec = " OVER (" + self.process(node.window_spec, unmangle_names) + ")" + window_spec = " OVER " + self.process(node.window_spec, unmangle_names) return f"{node.func_name.upper()}({args}){ignore_null}{window_spec}" def visit_window_spec(self, node, unmangle_names): + if node.window_name is not None: + return node.window_name + parts = [] if node.partition_by: self.process(node.partition_by, unmangle_names) @@ -435,8 +438,7 @@ def visit_window_spec(self, node, unmangle_names): parts.append("ORDER BY " + format_sort_items(node.order_by, unmangle_names)) if node.frame_clause: parts.append(self.process(node.frame_clause, unmangle_names)) - - return ' '.join(parts) + return '(' + ' '.join(parts) + ')' def visit_partition_by_clause(self, node, unmangle_names): return "PARTITION BY " + self._join_expressions(node.items, unmangle_names) diff --git a/sqlgpt_parser/parser/mysql_parser/parser.py b/sqlgpt_parser/parser/mysql_parser/parser.py index 9d1545b..b5240ca 100644 --- a/sqlgpt_parser/parser/mysql_parser/parser.py +++ b/sqlgpt_parser/parser/mysql_parser/parser.py @@ -1206,7 +1206,7 @@ def p_alias_opt(p): if p.slice[1].type == "alias": p[0] = p[1] else: - p[0] = () + p[0] = [] def p_alias(p): @@ -1215,9 +1215,9 @@ def p_alias(p): | AS string_lit | string_lit""" if len(p) == 3: - p[0] = (p[1], p[2]) + p[0] = [p[1], p[2]] else: - p[0] = p[1] + p[0] = [p[1]] def p_expression(p): diff --git a/sqlgpt_parser/parser/oceanbase_parser/parser.py b/sqlgpt_parser/parser/oceanbase_parser/parser.py index f113d48..a02ee17 100644 --- a/sqlgpt_parser/parser/oceanbase_parser/parser.py +++ b/sqlgpt_parser/parser/oceanbase_parser/parser.py @@ -1203,7 +1203,7 @@ def p_alias_opt(p): if p.slice[1].type == "alias": p[0] = p[1] else: - p[0] = () + p[0] = [] def p_alias(p): @@ -1212,9 +1212,9 @@ def p_alias(p): | AS string_lit | string_lit""" if len(p) == 3: - p[0] = (p[1], p[2]) + p[0] = [p[1], p[2]] else: - p[0] = p[1] + p[0] = [p[1]] def p_expression(p): diff --git a/sqlgpt_parser/parser/odps_parser/parser.py b/sqlgpt_parser/parser/odps_parser/parser.py index 4b30468..f92f871 100644 --- a/sqlgpt_parser/parser/odps_parser/parser.py +++ b/sqlgpt_parser/parser/odps_parser/parser.py @@ -1205,7 +1205,7 @@ def p_alias_opt(p): if p.slice[1].type == "alias": p[0] = p[1] else: - p[0] = () + p[0] = [] def p_alias(p): @@ -1214,7 +1214,7 @@ def p_alias(p): | AS string_lit | string_lit""" if len(p) == 3: - p[0] = (p[1], p[2]) + p[0] = [p[1], p[2]] else: p[0] = p[1] diff --git a/sqlgpt_parser/parser/parser_utils.py b/sqlgpt_parser/parser/parser_utils.py index 80d6695..5356196 100644 --- a/sqlgpt_parser/parser/parser_utils.py +++ b/sqlgpt_parser/parser/parser_utils.py @@ -12,10 +12,8 @@ from sqlgpt_parser.parser.tree.grouping import GroupingSets, SimpleGroupBy from sqlgpt_parser.parser.tree.literal import StringLiteral -from sqlgpt_parser.parser.tree.select_item import SingleColumn from sqlgpt_parser.parser.tree.visitor import DefaultTraversalVisitor from sqlgpt_parser.parser.tree.expression import ( - FunctionCall, InListExpression, QualifiedNameReference, SubqueryExpression, @@ -23,6 +21,13 @@ class ParserUtils(object): + class CollectInfo: + COLLECT_FILTER_COLUMN = 1 + COLLECT_PROJECT_COLUMN = 2 + COLLECT_TABLE = 4 + COLLECT_MIN_MAX_EXPRESSION_COLUMN = 8 + COLLECT_IN_EXPRESSION_COLUMN = 16 + @staticmethod def format_statement(statement): class FormatVisitor(DefaultTraversalVisitor): @@ -50,32 +55,49 @@ def __init__(self): self.limit_number = 0 self.recursion_count = 0 - def visit_table(self, node, context): + def add_project_column(self, project_column): + self.projection_column_list.append(project_column) + + def add_table(self, table_name, alias=''): self.table_list.append( - { - 'table_name': node.name.parts[0] - if len(node.name.parts) == 1 - else node.name.parts[1], - 'alias': '', - 'filter_column_list': [], - } + {'table_name': table_name, 'alias': alias, 'filter_column_list': []} + ) + + def add_filter_column( + self, filter_col, compare_type, table_or_alias_name=None + ): + filter_column_list = None + if table_or_alias_name is not None: + for table in self.table_list: + if ( + table['alias'] == table_or_alias_name + or table['table_name'] == table_or_alias_name + ): + filter_column_list = table['filter_column_list'] + else: + filter_column_list = self.table_list[-1]['filter_column_list'] + filter_column_list.append( + {"column_name": filter_col, 'opt': compare_type} ) + + def visit_table(self, node, context): + if context & ParserUtils.CollectInfo.COLLECT_TABLE: + table_name = node.name.parts[-1] + self.add_table(table_name) return self.visit_query_body(node, context) def visit_aliased_relation(self, node, context): alias = "" if len(node.alias) == 2: alias = node.alias[1] - elif len(node.alias) == 1: + else: alias = node.alias[0] - if not isinstance(node.relation, SubqueryExpression): - self.table_list.append( - { - 'table_name': node.relation.name.parts[0], - 'alias': alias, - 'filter_column_list': [], - } - ) + if ( + not isinstance(node.relation, SubqueryExpression) + and context & ParserUtils.CollectInfo.COLLECT_TABLE + ): + table_name = node.relation.name.parts[-1] + self.add_table(table_name, alias) else: return self.process(node.relation, context) @@ -92,36 +114,23 @@ def visit_logical_binary_expression(self, node, context): def visit_comparison_expression(self, node, context): left = node.left right = node.right - type = node.type - qualified_name_list = [] + + def add_filter_column(name): + table_name = None + if len(name.parts) > 2: + table_name = name.parts[-2] + self.add_filter_column(name.parts[-1], node.type, table_name) if isinstance(right, QualifiedNameReference): - qualified_name_list.append(right.name) + if context & ParserUtils.CollectInfo.COLLECT_FILTER_COLUMN: + add_filter_column(right.name) + else: + self.process(node.right, context) if isinstance(left, QualifiedNameReference): - qualified_name_list.append(left.name) - - for qualified_name in qualified_name_list: - if len(qualified_name.parts) == 2: - table_or_alias_name = qualified_name.parts[0] - for _table in self.table_list: - if ( - _table['alias'] == table_or_alias_name - or _table['table_name'] == table_or_alias_name - ): - filter_column_list = _table['filter_column_list'] - filter_column_list.append( - { - 'column_name': qualified_name.parts[1], - 'opt': type, - } - ) - else: - filter_column_list = self.table_list[-1]['filter_column_list'] - filter_column_list.append( - {'column_name': qualified_name.parts[0], 'opt': type} - ) - - return self.visit_expression(node, context) + if context & ParserUtils.CollectInfo.COLLECT_FILTER_COLUMN: + add_filter_column(left.name) + else: + self.process(node.left, context) def visit_like_predicate(self, node, context): if isinstance(node.value, QualifiedNameReference): @@ -130,12 +139,16 @@ def visit_like_predicate(self, node, context): if isinstance(pattern, StringLiteral): if not pattern.value.startswith('%'): can_query_range = True - if can_query_range: - self.add_filter_column_with_qualified_name_reference( - node.value, 'like' - ) + if ( + can_query_range + and context & ParserUtils.CollectInfo.COLLECT_FILTER_COLUMN + ): + value, table_name = node.value, None + if len(value.name.parts) > 2: + table_name = value.name.parts[-2] + self.add_filter_column(value.name.parts[-1], 'like', table_name) - return self.visit_expression(node, context) + self.process(node.pattern, context) def visit_not_expression(self, node, context): return self.process(node.value, "not") @@ -144,57 +157,45 @@ def visit_in_predicate(self, node, context): value = node.value if not node.is_not: - if isinstance(node.value_list, InListExpression): + if ( + isinstance(node.value_list, InListExpression) + and context + & ParserUtils.CollectInfo.COLLECT_IN_EXPRESSION_COLUMN + ): self.in_count_list.append(len(node.value_list.values)) - if isinstance(value, QualifiedNameReference): - self.add_filter_column_with_qualified_name_reference( - value, 'in' - ) + if ( + isinstance(value, QualifiedNameReference) + and context & ParserUtils.CollectInfo.COLLECT_FILTER_COLUMN + ): + table_name = None + if len(value.name.parts) > 2: + table_name = value.name.parts[-2] + self.add_filter_column(value.name.parts[-1], 'in', table_name) - self.process(node.value, None) - self.process(node.value_list, None) + self.process(node.value, context) + self.process(node.value_list, context) return None def visit_select(self, node, context): for item in node.select_items: - if isinstance(item, SingleColumn): - expression = item.expression - if isinstance(expression, QualifiedNameReference): - name = expression.name - if len(name.parts) == 2: - self.projection_column_list.append(name.parts[1]) - else: - self.projection_column_list.append(name.parts[0]) - if isinstance(expression, FunctionCall): - arguments = expression.arguments - if len(arguments) > 0: - for argument in arguments: - if isinstance(argument, QualifiedNameReference): - name = argument.name - _column_name = '' - if len(name.parts) == 2: - _column_name = name.parts[1] - else: - _column_name = name.parts[0] - - if expression.name == 'max': - self.min_max_list.append(_column_name) - - self.projection_column_list.append(_column_name) - - if argument == "*": - name = expression.name - if name == 'count': - self.projection_column_list.append( - 'count(*)' - ) - - else: - name = expression.name - if name == 'count': - self.projection_column_list.append('count(*)') self.process(item, context) + def visit_qualified_name_reference(self, node, context): + if context & ParserUtils.CollectInfo.COLLECT_PROJECT_COLUMN: + self.add_project_column(node.name.parts[-1]) + + def visit_aggregate_func(self, node, context): + if node.name == "count" and node.arguments[0] == "*": + if context & ParserUtils.CollectInfo.COLLECT_PROJECT_COLUMN: + self.add_project_column("count(*)") + else: + for arg in node.arguments: + self.process(arg, context) + if context & ParserUtils.CollectInfo.COLLECT_MIN_MAX_EXPRESSION_COLUMN: + if node.name == 'max' or node.name == 'min': + # min or max only has one argument + self.min_max_list.append(node.arguments[0]) + def visit_sort_item(self, node, context): sort_key = node.sort_key ordering = node.ordering @@ -212,9 +213,18 @@ def visit_sort_item(self, node, context): def visit_query_specification(self, node, context): self.limit_number = node.limit + context = ( + ParserUtils.CollectInfo.COLLECT_PROJECT_COLUMN + | ParserUtils.CollectInfo.COLLECT_MIN_MAX_EXPRESSION_COLUMN + ) self.process(node.select, context) + context = ParserUtils.CollectInfo.COLLECT_TABLE if node.from_: self.process(node.from_, context) + context = ( + ParserUtils.CollectInfo.COLLECT_IN_EXPRESSION_COLUMN + | ParserUtils.CollectInfo.COLLECT_FILTER_COLUMN + ) if node.where: self.process(node.where, context) if node.group_by: @@ -233,6 +243,10 @@ def visit_query_specification(self, node, context): def visit_update(self, node, context): table_list = node.table + context = ( + ParserUtils.CollectInfo.COLLECT_TABLE + | ParserUtils.CollectInfo.COLLECT_FILTER_COLUMN + ) if table_list: for _table in table_list: self.process(_table, context) @@ -242,6 +256,10 @@ def visit_update(self, node, context): def visit_delete(self, node, context): table_list = node.table + context = ( + ParserUtils.CollectInfo.COLLECT_TABLE + | ParserUtils.CollectInfo.COLLECT_FILTER_COLUMN + ) if table_list: for _table in table_list: self.process(_table, context) @@ -250,10 +268,13 @@ def visit_delete(self, node, context): return None def visit_between_predicate(self, node, context): - if isinstance(node.value, QualifiedNameReference): - self.add_filter_column_with_qualified_name_reference( - node.value, 'between' - ) + if ( + isinstance(node.value, QualifiedNameReference) + and context & ParserUtils.CollectInfo.COLLECT_FILTER_COLUMN + ): + parts = node.value.name.parts + table_name = parts[-2] if len(parts) > 2 else None + self.add_filter_column(parts[-1], "between", table_name) return None def add_filter_column_with_qualified_name_reference( @@ -285,7 +306,7 @@ def add_filter_column_with_qualified_name_reference( ) visitor = FormatVisitor() - visitor.process(statement, None) + visitor.process(statement, 0) return visitor @staticmethod @@ -345,7 +366,7 @@ def visit_query_specification(self, node, context): return None visitor = Visitor() - visitor.process(statement, None) + visitor.process(statement, 0) return statement diff --git a/sqlgpt_parser/parser/tree/visitor.py b/sqlgpt_parser/parser/tree/visitor.py index 89a0e89..0e71b7e 100644 --- a/sqlgpt_parser/parser/tree/visitor.py +++ b/sqlgpt_parser/parser/tree/visitor.py @@ -25,7 +25,8 @@ def visit_node(self, node, context): pass def visit_expression(self, node, context): - return self.visit_node(node, context) + for arg in node.arguments: + self.process(node, arg) def visit_reset_session(self, node, context): return self.visit_statement(node, context) @@ -58,7 +59,7 @@ def visit_assignment_expression(self, node, context): return self.visit_expression(node, context) def visit_literal(self, node, context): - return self.visit_expression(node, context) + return None def visit_double_literal(self, node, context): return self.visit_literal(node, context) @@ -184,7 +185,7 @@ def visit_list_expression(self, node, context): return self.visit_expression(node, context) def visit_qualified_name_reference(self, node, context): - return self.visit_expression(node, context) + return None def visit_dereference_expression(self, node, context): return self.visit_expression(node, context) @@ -202,7 +203,7 @@ def visit_arithmetic_unary(self, node, context): return self.visit_expression(node, context) def visit_not_expression(self, node, context): - return self.visit_expression(node, context) + return self.process(node.value, context) def visit_select_item(self, node, context): return self.visit_node(node, context) @@ -217,16 +218,16 @@ def visit_searched_case_expression(self, node, context): return self.visit_expression(node, context) def visit_like_predicate(self, node, context): - return self.visit_expression(node, context) + return self.process(node.value, context) def visit_regexp_predicate(self, node, context): - return self.visit_expression(node, context) + return self.process(node.value, context) def visit_is_not_null_predicate(self, node, context): - return self.visit_expression(node, context) + return self.process(node.value, context) def visit_is_predicate(self, node, context): - return self.visit_expression(node, context) + return self.process(node.value, context) def visit_array_constructor(self, node, context): return self.visit_expression(node, context) @@ -235,7 +236,7 @@ def visit_subscript_expression(self, node, context): return self.visit_expression(node, context) def visit_long_literal(self, node, context): - return self.visit_literal(node, context) + return self.visit_literal(node.value, context) def visit_logical_binary_expression(self, node, context): return self.visit_expression(node, context) @@ -300,8 +301,14 @@ def visit_group_concat(self, node, context): def visit_input_reference(self, node, context): return self.visit_expression(node, context) + def visit_window_func(self, node, context): + for arg in node.func_args: + self.process(arg, context) + self.process(node.window_spec, context) + return None + def visit_window_spec(self, node, context): - return self.visit_node(node, context) + return None def visit_window_frame(self, node, context): return self.visit_node(node, context) diff --git a/test/format/test_sql_formatter.py b/test/format/test_sql_formatter.py index a6ae1ce..9d21a97 100644 --- a/test/format/test_sql_formatter.py +++ b/test/format/test_sql_formatter.py @@ -1,5 +1,7 @@ import unittest +from sqlgpt_parser.parser.parser_utils import ParserUtils + from sqlgpt_parser.format.formatter import format_sql from sqlgpt_parser.parser.mysql_parser import parser @@ -214,6 +216,42 @@ def test_windows_function(self): after_sql_rewrite_format = format_sql(statement, 0) assert after_sql_rewrite_format == except_sql + def test_sql(self): + test_sqls = [ + "update sqless_base set nick=1231 where a = 1 and b = 2 ", + "select 1,t2.a from t1 left join t2 where t1.d > 2 and t2.a =1", + ] + except_table_list = [ + [ + { + 'alias': '', + 'filter_column_list': [ + {'column_name': 'a', 'opt': '='}, + {'column_name': 'b', 'opt': '='}, + ], + 'table_name': 'sqless_base', + } + ], + [ + {'alias': '', 'filter_column_list': [], 'table_name': 't1'}, + { + 'alias': '', + 'filter_column_list': [ + {'column_name': 'd', 'opt': '>'}, + {'column_name': 'a', 'opt': '='}, + ], + 'table_name': 't2', + }, + ], + ] + except_projection_list = [[], ['a']] + for index, sql in enumerate(test_sqls): + statement = parser.parse(sql) + visitor = ParserUtils.format_statement(statement) + print(visitor) + visitor.table_list = except_table_list[index] + visitor.projection_column_list = except_projection_list[index] + if __name__ == '__main__': unittest.main()