From a50b3676c3741ba1f6fac0a288b9498531bd0930 Mon Sep 17 00:00:00 2001 From: Trim21 Date: Tue, 19 Nov 2024 01:00:52 +0800 Subject: [PATCH] fix some types when downstream enabled this package in mypy (#815) * fix: select arg type * fix some type error when mypy enable * revert * revert * fix --- pypika/__init__.py | 54 ++++++++++++++++++++++++++++++++++++++++ pypika/functions.py | 60 +++++++++++++++++++++++++-------------------- pypika/queries.py | 2 +- 3 files changed, 88 insertions(+), 28 deletions(-) diff --git a/pypika/__init__.py b/pypika/__init__.py index a875ae00..66f564f0 100644 --- a/pypika/__init__.py +++ b/pypika/__init__.py @@ -112,9 +112,63 @@ FunctionException, ) + __author__ = "Timothy Heys" __email__ = "theys@kayak.com" __version__ = "0.48.9" NULL = NullValue() SYSTEM_TIME = SystemTimeValue() + +__all__ = ( + 'ClickHouseQuery', + 'Dialects', + 'MSSQLQuery', + 'MySQLQuery', + 'OracleQuery', + 'PostgreSQLQuery', + 'RedshiftQuery', + 'SQLLiteQuery', + 'VerticaQuery', + 'DatePart', + 'JoinType', + 'Order', + 'AliasedQuery', + 'Query', + 'Schema', + 'Table', + 'Column', + 'Database', + 'Tables', + 'Columns', + 'Array', + 'Bracket', + 'Case', + 'Criterion', + 'EmptyCriterion', + 'Field', + 'Index', + 'Interval', + 'JSON', + 'Not', + 'NullValue', + 'SystemTimeValue', + 'Parameter', + 'QmarkParameter', + 'NumericParameter', + 'NamedParameter', + 'FormatParameter', + 'PyformatParameter', + 'Rollup', + 'Tuple', + 'CustomFunction', + 'CaseException', + 'GroupingException', + 'JoinException', + 'QueryException', + 'RollupException', + 'SetOperationException', + 'FunctionException', + 'NULL', + 'SYSTEM_TIME', +) diff --git a/pypika/functions.py b/pypika/functions.py index c8200408..5e693f0d 100644 --- a/pypika/functions.py +++ b/pypika/functions.py @@ -1,6 +1,11 @@ """ Package for SQL functions wrappers """ +from __future__ import annotations + +from typing import Optional + +from pypika import Field from pypika.enums import SqlTypes from pypika.terms import ( AggregateFunction, @@ -10,6 +15,7 @@ ) from pypika.utils import builder + __author__ = "Timothy Heys" __email__ = "theys@kayak.com" @@ -34,64 +40,64 @@ def distinct(self): class Count(DistinctOptionFunction): - def __init__(self, param, alias=None): + def __init__(self, param: str | Field, alias: Optional[str] = None) -> None: is_star = isinstance(param, str) and "*" == param super(Count, self).__init__("COUNT", Star() if is_star else param, alias=alias) # Arithmetic Functions class Sum(DistinctOptionFunction): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Sum, self).__init__("SUM", term, alias=alias) class Avg(AggregateFunction): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Avg, self).__init__("AVG", term, alias=alias) class Min(AggregateFunction): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Min, self).__init__("MIN", term, alias=alias) class Max(AggregateFunction): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Max, self).__init__("MAX", term, alias=alias) class Std(AggregateFunction): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Std, self).__init__("STD", term, alias=alias) class StdDev(AggregateFunction): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(StdDev, self).__init__("STDDEV", term, alias=alias) class Abs(AggregateFunction): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Abs, self).__init__("ABS", term, alias=alias) class First(AggregateFunction): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(First, self).__init__("FIRST", term, alias=alias) class Last(AggregateFunction): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Last, self).__init__("LAST", term, alias=alias) class Sqrt(Function): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Sqrt, self).__init__("SQRT", term, alias=alias) class Floor(Function): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Floor, self).__init__("FLOOR", term, alias=alias) @@ -131,17 +137,17 @@ def __init__(self, term, as_type, alias=None): class Signed(Cast): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Signed, self).__init__(term, SqlTypes.SIGNED, alias=alias) class Unsigned(Cast): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Unsigned, self).__init__(term, SqlTypes.UNSIGNED, alias=alias) class Date(Function): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Date, self).__init__("DATE", term, alias=alias) @@ -156,7 +162,7 @@ def __init__(self, start_time, end_time, alias=None): class DateAdd(Function): - def __init__(self, date_part, interval, term, alias=None): + def __init__(self, date_part, interval, term: str, alias: Optional[str] = None): date_part = getattr(date_part, "value", date_part) super(DateAdd, self).__init__("DATE_ADD", LiteralValue(date_part), interval, term, alias=alias) @@ -167,19 +173,19 @@ def __init__(self, value, format_mask, alias=None): class Timestamp(Function): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Timestamp, self).__init__("TIMESTAMP", term, alias=alias) class TimestampAdd(Function): - def __init__(self, date_part, interval, term, alias=None): + def __init__(self, date_part, interval, term: str, alias: Optional[str] = None): date_part = getattr(date_part, 'value', date_part) super(TimestampAdd, self).__init__("TIMESTAMPADD", LiteralValue(date_part), interval, term, alias=alias) # String Functions class Ascii(Function): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Ascii, self).__init__("ASCII", term, alias=alias) @@ -189,7 +195,7 @@ def __init__(self, term, condition, **kwargs): class Bin(Function): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Bin, self).__init__("BIN", term, alias=alias) @@ -205,17 +211,17 @@ def __init__(self, term, start, stop, subterm, alias=None): class Length(Function): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Length, self).__init__("LENGTH", term, alias=alias) class Upper(Function): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Upper, self).__init__("UPPER", term, alias=alias) class Lower(Function): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Lower, self).__init__("LOWER", term, alias=alias) @@ -225,12 +231,12 @@ def __init__(self, term, start, stop, alias=None): class Reverse(Function): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Reverse, self).__init__("REVERSE", term, alias=alias) class Trim(Function): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(Trim, self).__init__("TRIM", term, alias=alias) @@ -297,7 +303,7 @@ def get_special_params_sql(self, **kwargs): # Null Functions class IsNull(Function): - def __init__(self, term, alias=None): + def __init__(self, term: str | Field, alias: Optional[str] = None): super(IsNull, self).__init__("ISNULL", term, alias=alias) @@ -312,5 +318,5 @@ def __init__(self, condition, term, **kwargs): class NVL(Function): - def __init__(self, condition, term, alias=None): + def __init__(self, condition, term: str, alias: Optional[str] = None): super(NVL, self).__init__("NVL", condition, term, alias=alias) diff --git a/pypika/queries.py b/pypika/queries.py index c51c6b2b..d7861200 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -206,7 +206,7 @@ def __ne__(self, other: Any) -> bool: def __hash__(self) -> int: return hash(str(self)) - def select(self, *terms: Sequence[Union[int, float, str, bool, Term, Field]]) -> "QueryBuilder": + def select(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilder": """ Perform a SELECT operation on the current table