Skip to content

Commit

Permalink
Merge pull request #7 from qidi1/fix-time-interval-format-error
Browse files Browse the repository at this point in the history
fix format time interval error
  • Loading branch information
qidi1 authored Sep 8, 2023
2 parents 619835e + b0a84c4 commit c930736
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 267 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Query(query_body=QuerySpecification(select=Select(distinct=False, select_items=[
### Format SQL
```python
>>> from sqlgpt_parser.format.formatter import format_sql
>>> from sqlgpt_parser.mysql_parser import parser
>>> from sqlgpt_parser.parser.mysql_parser import parser
>>> result=parser.parse("select * from t")
>>> format_sql(result)
'SELECT\n *\nFROM\n t'
Expand Down
257 changes: 0 additions & 257 deletions docs/docs-ch/Parser Module Development Guide.md

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "sqlgpt-parser"
version = "0.0.1a3"
version = "0.0.1a4"
authors = [
{ name="luliwjc", email="[email protected]" },
{ name="Ifffff", email="[email protected]" },
Expand Down
5 changes: 4 additions & 1 deletion sqlgpt_parser/format/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def visit_at_time_zone(self, node, context):
self.process(node.time_zone, context),
)

def visit_time_interval(self, node, context):
return f"INTERVAL {self.process(node.value, context)} {node.unit.upper()}"

def visit_current_time(self, node, unmangle_names):
return "%s%s" % (node.type, "(%s)" % node.precision if node.precision else "")

Expand Down Expand Up @@ -106,7 +109,7 @@ def visit_dereference_expression(self, node, unmangle_names):

def visit_function_call(self, node, unmangle_names):
ret = ""
arguments = self._join_expressions(node.args, unmangle_names)
arguments = self._join_expressions(node.arguments, unmangle_names)
if "count" == node.name.lower() and len(arguments) != 0 and arguments[0] == '*':
arguments = "*"
if node.distinct:
Expand Down
11 changes: 9 additions & 2 deletions sqlgpt_parser/parser/mysql_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
TrimFunc,
WhenClause,
JsonTableColumn,
TimeInterval,
)
from sqlgpt_parser.parser.tree.grouping import SimpleGroupBy
from sqlgpt_parser.parser.tree.join_criteria import JoinOn, JoinUsing, NaturalJoin
Expand All @@ -72,6 +73,7 @@
TimeLiteral,
DefaultLiteral,
ErrorLiteral,
TimestampLiteral,
)
from sqlgpt_parser.parser.tree.node import Node
from sqlgpt_parser.parser.tree.qualified_name import QualifiedName
Expand Down Expand Up @@ -689,7 +691,12 @@ def p_date_lit(p):
r"""date_lit : DATE string_lit
| TIME string_lit
| TIMESTAMP string_lit"""
p[0] = DateLiteral(p.lineno(1), p.lexpos(1), value=p[2], unit=p[1])
if p.slice[1].type.upper() == "DATE":
p[0] = DateLiteral(p.lineno(1), p.lexpos(1), value=p[2])
elif p.slice[1].type.upper() == "TIME":
p[0] = TimeLiteral(p.lineno(1), p.lexpos(1), value=p[2])
elif p.slice[1].type.upper() == "TIMESTAMP":
p[0] = TimestampLiteral(p.lineno(1), p.lexpos(1), value=p[2])


def p_order(p):
Expand Down Expand Up @@ -3437,7 +3444,7 @@ def p_time_interval(p):
r"""time_interval : INTERVAL expression time_unit
| QM"""
if len(p) == 4:
p[0] = TimeLiteral(p.lineno(1), p.lexpos(1), value=p[2], unit=p[3])
p[0] = TimeInterval(p.lineno(1), p.lexpos(1), value=p[2], unit=p[3])
else:
p[0] = p[1]

Expand Down
11 changes: 9 additions & 2 deletions sqlgpt_parser/parser/oceanbase_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
TrimFunc,
WhenClause,
JsonTableColumn,
TimeInterval,
)
from sqlgpt_parser.parser.tree.grouping import SimpleGroupBy
from sqlgpt_parser.parser.tree.join_criteria import JoinOn, JoinUsing, NaturalJoin
Expand All @@ -72,6 +73,7 @@
TimeLiteral,
DefaultLiteral,
ErrorLiteral,
TimestampLiteral,
)
from sqlgpt_parser.parser.tree.node import Node
from sqlgpt_parser.parser.tree.qualified_name import QualifiedName
Expand Down Expand Up @@ -686,7 +688,12 @@ def p_date_lit(p):
r"""date_lit : DATE string_lit
| TIME string_lit
| TIMESTAMP string_lit"""
p[0] = DateLiteral(p.lineno(1), p.lexpos(1), value=p[2], unit=p[1])
if p.slice[1].type.upper() == "DATE":
p[0] = DateLiteral(p.lineno(1), p.lexpos(1), value=p[2])
elif p.slice[1].type.upper() == "TIME":
p[0] = TimeLiteral(p.lineno(1), p.lexpos(1), value=p[2])
elif p.slice[1].type.upper() == "TIMESTAMP":
p[0] = TimestampLiteral(p.lineno(1), p.lexpos(1), value=p[2])


def p_order(p):
Expand Down Expand Up @@ -3859,7 +3866,7 @@ def p_time_interval(p):
r"""time_interval : INTERVAL expression time_unit
| QM"""
if len(p) == 4:
p[0] = TimeLiteral(p.lineno(1), p.lexpos(1), value=p[2], unit=p[3])
p[0] = TimeInterval(p.lineno(1), p.lexpos(1), value=p[2], unit=p[3])
else:
p[0] = p[1]

Expand Down
Loading

0 comments on commit c930736

Please sign in to comment.