-
Notifications
You must be signed in to change notification settings - Fork 298
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix some types when downstream enabled this package in mypy (#815)
* fix: select arg type * fix some type error when mypy enable * revert * revert * fix
- Loading branch information
Showing
3 changed files
with
88 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -112,9 +112,63 @@ | |
FunctionException, | ||
) | ||
|
||
|
||
__author__ = "Timothy Heys" | ||
__email__ = "[email protected]" | ||
__version__ = "0.48.9" | ||
|
||
NULL = NullValue() | ||
SYSTEM_TIME = SystemTimeValue() | ||
|
||
__all__ = ( | ||
'ClickHouseQuery', | ||
'Dialects', | ||
'MSSQLQuery', | ||
'MySQLQuery', | ||
'OracleQuery', | ||
'PostgreSQLQuery', | ||
'RedshiftQuery', | ||
'SQLLiteQuery', | ||
'VerticaQuery', | ||
'DatePart', | ||
'JoinType', | ||
'Order', | ||
'AliasedQuery', | ||
'Query', | ||
'Schema', | ||
'Table', | ||
'Column', | ||
'Database', | ||
'Tables', | ||
'Columns', | ||
'Array', | ||
'Bracket', | ||
'Case', | ||
'Criterion', | ||
'EmptyCriterion', | ||
'Field', | ||
'Index', | ||
'Interval', | ||
'JSON', | ||
'Not', | ||
'NullValue', | ||
'SystemTimeValue', | ||
'Parameter', | ||
'QmarkParameter', | ||
'NumericParameter', | ||
'NamedParameter', | ||
'FormatParameter', | ||
'PyformatParameter', | ||
'Rollup', | ||
'Tuple', | ||
'CustomFunction', | ||
'CaseException', | ||
'GroupingException', | ||
'JoinException', | ||
'QueryException', | ||
'RollupException', | ||
'SetOperationException', | ||
'FunctionException', | ||
'NULL', | ||
'SYSTEM_TIME', | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,11 @@ | ||
""" | ||
Package for SQL functions wrappers | ||
""" | ||
from __future__ import annotations | ||
|
||
from typing import Optional | ||
|
||
from pypika import Field | ||
from pypika.enums import SqlTypes | ||
from pypika.terms import ( | ||
AggregateFunction, | ||
|
@@ -10,6 +15,7 @@ | |
) | ||
from pypika.utils import builder | ||
|
||
|
||
__author__ = "Timothy Heys" | ||
__email__ = "[email protected]" | ||
|
||
|
@@ -34,64 +40,64 @@ def distinct(self): | |
|
||
|
||
class Count(DistinctOptionFunction): | ||
def __init__(self, param, alias=None): | ||
def __init__(self, param: str | Field, alias: Optional[str] = None) -> None: | ||
is_star = isinstance(param, str) and "*" == param | ||
super(Count, self).__init__("COUNT", Star() if is_star else param, alias=alias) | ||
|
||
|
||
# Arithmetic Functions | ||
class Sum(DistinctOptionFunction): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Sum, self).__init__("SUM", term, alias=alias) | ||
|
||
|
||
class Avg(AggregateFunction): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Avg, self).__init__("AVG", term, alias=alias) | ||
|
||
|
||
class Min(AggregateFunction): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Min, self).__init__("MIN", term, alias=alias) | ||
|
||
|
||
class Max(AggregateFunction): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Max, self).__init__("MAX", term, alias=alias) | ||
|
||
|
||
class Std(AggregateFunction): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Std, self).__init__("STD", term, alias=alias) | ||
|
||
|
||
class StdDev(AggregateFunction): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(StdDev, self).__init__("STDDEV", term, alias=alias) | ||
|
||
|
||
class Abs(AggregateFunction): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Abs, self).__init__("ABS", term, alias=alias) | ||
|
||
|
||
class First(AggregateFunction): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(First, self).__init__("FIRST", term, alias=alias) | ||
|
||
|
||
class Last(AggregateFunction): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Last, self).__init__("LAST", term, alias=alias) | ||
|
||
|
||
class Sqrt(Function): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Sqrt, self).__init__("SQRT", term, alias=alias) | ||
|
||
|
||
class Floor(Function): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Floor, self).__init__("FLOOR", term, alias=alias) | ||
|
||
|
||
|
@@ -131,17 +137,17 @@ def __init__(self, term, as_type, alias=None): | |
|
||
|
||
class Signed(Cast): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Signed, self).__init__(term, SqlTypes.SIGNED, alias=alias) | ||
|
||
|
||
class Unsigned(Cast): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Unsigned, self).__init__(term, SqlTypes.UNSIGNED, alias=alias) | ||
|
||
|
||
class Date(Function): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Date, self).__init__("DATE", term, alias=alias) | ||
|
||
|
||
|
@@ -156,7 +162,7 @@ def __init__(self, start_time, end_time, alias=None): | |
|
||
|
||
class DateAdd(Function): | ||
def __init__(self, date_part, interval, term, alias=None): | ||
def __init__(self, date_part, interval, term: str, alias: Optional[str] = None): | ||
date_part = getattr(date_part, "value", date_part) | ||
super(DateAdd, self).__init__("DATE_ADD", LiteralValue(date_part), interval, term, alias=alias) | ||
|
||
|
@@ -167,19 +173,19 @@ def __init__(self, value, format_mask, alias=None): | |
|
||
|
||
class Timestamp(Function): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Timestamp, self).__init__("TIMESTAMP", term, alias=alias) | ||
|
||
|
||
class TimestampAdd(Function): | ||
def __init__(self, date_part, interval, term, alias=None): | ||
def __init__(self, date_part, interval, term: str, alias: Optional[str] = None): | ||
date_part = getattr(date_part, 'value', date_part) | ||
super(TimestampAdd, self).__init__("TIMESTAMPADD", LiteralValue(date_part), interval, term, alias=alias) | ||
|
||
|
||
# String Functions | ||
class Ascii(Function): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Ascii, self).__init__("ASCII", term, alias=alias) | ||
|
||
|
||
|
@@ -189,7 +195,7 @@ def __init__(self, term, condition, **kwargs): | |
|
||
|
||
class Bin(Function): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Bin, self).__init__("BIN", term, alias=alias) | ||
|
||
|
||
|
@@ -205,17 +211,17 @@ def __init__(self, term, start, stop, subterm, alias=None): | |
|
||
|
||
class Length(Function): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Length, self).__init__("LENGTH", term, alias=alias) | ||
|
||
|
||
class Upper(Function): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Upper, self).__init__("UPPER", term, alias=alias) | ||
|
||
|
||
class Lower(Function): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Lower, self).__init__("LOWER", term, alias=alias) | ||
|
||
|
||
|
@@ -225,12 +231,12 @@ def __init__(self, term, start, stop, alias=None): | |
|
||
|
||
class Reverse(Function): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Reverse, self).__init__("REVERSE", term, alias=alias) | ||
|
||
|
||
class Trim(Function): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(Trim, self).__init__("TRIM", term, alias=alias) | ||
|
||
|
||
|
@@ -297,7 +303,7 @@ def get_special_params_sql(self, **kwargs): | |
|
||
# Null Functions | ||
class IsNull(Function): | ||
def __init__(self, term, alias=None): | ||
def __init__(self, term: str | Field, alias: Optional[str] = None): | ||
super(IsNull, self).__init__("ISNULL", term, alias=alias) | ||
|
||
|
||
|
@@ -312,5 +318,5 @@ def __init__(self, condition, term, **kwargs): | |
|
||
|
||
class NVL(Function): | ||
def __init__(self, condition, term, alias=None): | ||
def __init__(self, condition, term: str, alias: Optional[str] = None): | ||
super(NVL, self).__init__("NVL", condition, term, alias=alias) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters