Skip to content

Commit

Permalink
Merge pull request #8 from qidi1/format_bug_fix
Browse files: browse the repository at this point in the history
Add formatting for sound_like and window_function, and fix parserutils bug
  • Loading branch information
qidi1 authored Sep 20, 2023
2 parents c930736 + b0d5783 commit 939e17f
Show file tree
Hide file tree
Showing 13 changed files with 5,552 additions and 5,442 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "sqlgpt-parser"
version = "0.0.1a4"
version = "0.0.1a5"
authors = [
{ name="luliwjc", email="[email protected]" },
{ name="Ifffff", email="[email protected]" },
Expand Down
63 changes: 49 additions & 14 deletions sqlgpt_parser/format/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,24 +421,33 @@ def visit_list_expression(self, node, unmangle_names):
else:
return "(%s)" % self._join_expressions(node.values, unmangle_names)

def visit_window_func(self, node, unmangle_names):
args = ", ".join([self.process(arg, unmangle_names) for arg in node.func_args])
ignore_null = f" {node.ignore_null} NULLS" if node.ignore_null else ""
window_spec = " OVER " + self.process(node.window_spec, unmangle_names)
return f"{node.func_name.upper()}({args}){ignore_null}{window_spec}"

def visit_window_spec(self, node, unmangle_names):
parts = []
if node.window_name is not None:
return node.window_name

parts = []
if node.partition_by:
parts.append(
"PARTITION BY "
+ self._join_expressions(node.partition_by, unmangle_names)
)
self.process(node.partition_by, unmangle_names)
if node.order_by:
parts.append("ORDER BY " + format_sort_items(node.order_by, unmangle_names))
if node.frame:
parts.append(self.process(node.frame, unmangle_names))

if node.frame_clause:
parts.append(self.process(node.frame_clause, unmangle_names))
return '(' + ' '.join(parts) + ')'

def visit_window_frame(self, node, unmangle_names):
ret = node.type + " "
def visit_partition_by_clause(self, node, unmangle_names):
return "PARTITION BY " + self._join_expressions(node.items, unmangle_names)

def visit_frame_clause(self, node, unmangle_names):
return f"{node.type} {self.process(node.frame_range, unmangle_names)}"

def visit_window_frame(self, node, unmangle_names):
ret = ""
if node.end:
ret += "BETWEEN %s AND %s" % (
self.process(node.start, unmangle_names),
Expand All @@ -449,6 +458,19 @@ def visit_window_frame(self, node, unmangle_names):

return ret

def visit_frame_bound(self, node, unmangle_names):
if node.type.upper() == "ROW":
return "CURRENT ROW"
expr = (
self.process(node.expr, unmangle_names)
if node.expr is not None
else "UNBOUNDED "
)
return f"{expr} {node.type.upper()}"

def visit_frame_expr(self, node, unmangle_names):
return self.process(node.value, unmangle_names)

def visit_single_column(self, node, indent):
format_expression(node.expression)

Expand All @@ -468,6 +490,9 @@ def visit_match_against_expression(self, node, unmangle_names):
full_text_search_modifier = full_text_search_modifier.upper()
return f"MATCH({columns}) AGAINST ({self.process(node.expr, unmangle_names)}{full_text_search_modifier})"

def visit_sound_like(self, node, unmangle_names):
return f"{self.process(node.arguments[0])} SOUNDS LIKE {self.process(node.arguments[1])}"

def _format_binary_expression(self, operator, left, right, unmangle_names):
return "%s %s %s" % (
self.process(left, unmangle_names),
Expand Down Expand Up @@ -689,13 +714,14 @@ def visit_table_subquery(self, node, indent):
return None

def visit_union(self, node, indent):
all = node.all
for i, relation in enumerate(node.relations):
self._process_relation(relation, indent)
self.builder.append("\n")
if i != len(node.relations) - 1:
if all:
if node.all:
self._append(indent, "UNION ALL")
elif node.distinct:
self._append(indent, "UNION DISTINCT")
else:
self._append(indent, "UNION")
self.builder.append("\n")
Expand All @@ -704,7 +730,12 @@ def visit_union(self, node, indent):

def visit_except(self, node, indent):
self._process_relation(node.left, indent)
self.builder.append("EXCEPT " + "ALL " if not node.distinct else "")
if node.all is not None:
self._append(indent, "EXCEPT ALL")
elif node.distinct is not None:
self._append(indent, "EXCEPT DISTINCT")
else:
self._append(indent, "EXCEPT")
self._process_relation(node.right, indent)

return None
Expand Down Expand Up @@ -756,7 +787,11 @@ def visit_intersect(self, node, indent):
relations = [
self._process_relation(relation, indent) for relation in node.relations
]
intersect = "INTERSECT " + "ALL " if not node.distinct else ""
intersect = "INTERSECT"
if node.all is not None:
intersect += " ALL"
elif node.distinct is not None:
intersect += " DISTINCT"
self.builder.append(intersect.join(relations))
return None

Expand Down
22 changes: 10 additions & 12 deletions sqlgpt_parser/parser/mysql_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,7 @@ def p_alias_opt(p):
if p.slice[1].type == "alias":
p[0] = p[1]
else:
p[0] = ()
p[0] = []


def p_alias(p):
Expand All @@ -1215,9 +1215,9 @@ def p_alias(p):
| AS string_lit
| string_lit"""
if len(p) == 3:
p[0] = (p[1], p[2])
p[0] = [p[1], p[2]]
else:
p[0] = p[1]
p[0] = [p[1]]


def p_expression(p):
Expand Down Expand Up @@ -1570,7 +1570,7 @@ def p_window_func_call(p):
| ROW_NUMBER LPAREN RPAREN over_clause
"""
length = len(p)
window_spec = p[-1]
window_spec = p[length-1]
args = []
ignore_null = None

Expand Down Expand Up @@ -1711,7 +1711,10 @@ def p_frame_start(p):
| frame_expr PRECEDING
| frame_expr FOLLOWING
"""
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
if p.slice[1].type == 'frame_expr':
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
else:
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=None)


def p_frame_end(p):
Expand All @@ -1726,13 +1729,8 @@ def p_frame_between(p):

def p_frame_expr(p):
r"""frame_expr : figure
| QM
| INTERVAL expression time_unit
|"""
if len(p) == 4:
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[2], unit=p[3])
else:
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[1])
| time_interval"""
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[1])


def p_lead_lag_info_opt(p):
Expand Down
2,968 changes: 1,483 additions & 1,485 deletions sqlgpt_parser/parser/mysql_parser/parser_table.py

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions sqlgpt_parser/parser/oceanbase_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,7 +1203,7 @@ def p_alias_opt(p):
if p.slice[1].type == "alias":
p[0] = p[1]
else:
p[0] = ()
p[0] = []


def p_alias(p):
Expand All @@ -1212,9 +1212,9 @@ def p_alias(p):
| AS string_lit
| string_lit"""
if len(p) == 3:
p[0] = (p[1], p[2])
p[0] = [p[1], p[2]]
else:
p[0] = p[1]
p[0] = [p[1]]


def p_expression(p):
Expand Down Expand Up @@ -1649,7 +1649,7 @@ def p_window_func_call(p):
| ROW_NUMBER LPAREN RPAREN over_clause
"""
length = len(p)
window_spec = p[-1]
window_spec = p[length-1]
args = []
ignore_null = None

Expand Down Expand Up @@ -1790,7 +1790,10 @@ def p_frame_start(p):
| frame_expr PRECEDING
| frame_expr FOLLOWING
"""
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
if p.slice[1].type == 'frame_expr':
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
else:
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=None)


def p_frame_end(p):
Expand All @@ -1802,12 +1805,9 @@ def p_frame_between(p):
r"""frame_between : BETWEEN frame_start AND frame_end"""
p[0] = WindowFrame(p.lineno(1), p.lexpos(1), start=p[2], end=p[4])


def p_frame_expr(p):
r"""frame_expr : figure
| QM
| time_interval
|"""
| time_interval"""
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[1])


Expand Down
3,790 changes: 1,894 additions & 1,896 deletions sqlgpt_parser/parser/oceanbase_parser/parser_table.py

Large diffs are not rendered by default.

17 changes: 8 additions & 9 deletions sqlgpt_parser/parser/odps_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,7 @@ def p_alias_opt(p):
if p.slice[1].type == "alias":
p[0] = p[1]
else:
p[0] = ()
p[0] = []


def p_alias(p):
Expand All @@ -1214,7 +1214,7 @@ def p_alias(p):
| AS string_lit
| string_lit"""
if len(p) == 3:
p[0] = (p[1], p[2])
p[0] = [p[1], p[2]]
else:
p[0] = p[1]

Expand Down Expand Up @@ -1657,7 +1657,7 @@ def p_window_func_call(p):
| ROW_NUMBER LPAREN RPAREN over_clause
"""
length = len(p)
window_spec = p[-1]
window_spec = p[length-1]
args = []
ignore_null = None

Expand Down Expand Up @@ -1798,8 +1798,10 @@ def p_frame_start(p):
| frame_expr PRECEDING
| frame_expr FOLLOWING
"""
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])

if p.slice[1].type == 'frame_expr':
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
else:
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=None)

def p_frame_end(p):
r"""frame_end : frame_start"""
Expand All @@ -1810,12 +1812,9 @@ def p_frame_between(p):
r"""frame_between : BETWEEN frame_start AND frame_end"""
p[0] = WindowFrame(p.lineno(1), p.lexpos(1), start=p[2], end=p[4])


def p_frame_expr(p):
r"""frame_expr : figure
| QM
| time_interval
|"""
| time_interval"""
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[1])


Expand Down
3,798 changes: 1,898 additions & 1,900 deletions sqlgpt_parser/parser/odps_parser/parser_table.py

Large diffs are not rendered by default.

Loading

0 comments on commit 939e17f

Please sign in to comment.