Skip to content

Commit

Permalink
Typing (#1056)
Browse files Browse the repository at this point in the history
* typing fixes

* Set -> AbstractSet

* typing fixes

* typing fixes

* typing fixes

* typing fixes

* typing fixes

* Make column_states properties non-optional

* add else clause

* revert adding else clause to fix dependent_column_names being an empty set

* Make Attribute.datatype non-optional, remove casts

* Add back else clause

* remove assertion

* session keyword argument

* session keyword argument

* Add pyright in merge gates

* Change job name

* Fix CopyIntoTableNode.transformations type hints

* Remove comment

* Small fixes

* Address comments

* Use logical_plan.source_data directly

---------

Co-authored-by: Sophie Tan <[email protected]>
  • Loading branch information
alexmojaki and sfc-gh-stan authored Oct 18, 2023
1 parent 5a01033 commit a78e3ec
Show file tree
Hide file tree
Showing 20 changed files with 228 additions and 144 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/precommit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,24 @@ jobs:
- name: Run fix_lint
run: python -m tox -e fix_lint

type_checking:
name: Type Checking
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.8'
- name: Display Python version
run: python -c "import sys; import os; print(\"\n\".join(os.environ[\"PATH\"].split(os.pathsep))); print(sys.version); print(sys.executable);"
- name: Upgrade setuptools and pip
run: python -m pip install -U setuptools pip
- name: Install tox
run: python -m pip install tox
- name: Run pyright on Selected Files
run: python -m tox -e pyright

build:
needs: lint
name: Build Wheel File
Expand Down
32 changes: 21 additions & 11 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import uuid
from collections import Counter, defaultdict
from typing import DefaultDict, Dict, Union
from typing import TYPE_CHECKING, DefaultDict, Dict, Optional, Union

import snowflake.snowpark
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
Expand Down Expand Up @@ -146,14 +146,17 @@

ARRAY_BIND_THRESHOLD = 512

if TYPE_CHECKING:
import snowflake.snowpark.session


class Analyzer:
def __init__(self, session: "snowflake.snowpark.session.Session") -> None:
self.session = session
self.plan_builder = SnowflakePlanBuilder(self.session)
self.generated_alias_maps = {}
self.subquery_plans = []
self.alias_maps_to_use = None
self.alias_maps_to_use: Optional[Dict[uuid.UUID, str]] = None

def analyze(
self,
Expand Down Expand Up @@ -331,6 +334,7 @@ def analyze(
return sql

if isinstance(expr, Attribute):
assert self.alias_maps_to_use is not None
name = self.alias_maps_to_use.get(expr.expr_id, expr.name)
return quote_name(name)

Expand Down Expand Up @@ -528,7 +532,7 @@ def analyze(
def table_function_expression_extractor(
self,
expr: TableFunctionExpression,
df_aliased_col_name_to_real_col_name: Dict[str, str],
df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]],
parse_local_name=False,
) -> str:
if isinstance(expr, FlattenFunction):
Expand Down Expand Up @@ -577,13 +581,14 @@ def table_function_expression_extractor(
def unary_expression_extractor(
self,
expr: UnaryExpression,
df_aliased_col_name_to_real_col_name: Dict[str, str],
df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]],
parse_local_name=False,
) -> str:
if isinstance(expr, Alias):
quoted_name = quote_name(expr.name)
if isinstance(expr.child, Attribute):
self.generated_alias_maps[expr.child.expr_id] = quoted_name
assert self.alias_maps_to_use is not None
for k, v in self.alias_maps_to_use.items():
if v == expr.child.name:
self.generated_alias_maps[k] = quoted_name
Expand Down Expand Up @@ -678,7 +683,7 @@ def window_frame_boundary(self, offset: str) -> str:
def to_sql_avoid_offset(
self,
expr: Expression,
df_aliased_col_name_to_real_col_name: Dict[str, str],
df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]],
parse_local_name: bool = False,
) -> str:
# if expression is a numeric literal, return the number without casting,
Expand Down Expand Up @@ -749,7 +754,7 @@ def do_resolve_with_resolved_children(
self,
logical_plan: LogicalPlan,
resolved_children: Dict[LogicalPlan, SnowflakePlan],
df_aliased_col_name_to_real_col_name: Dict[str, str],
df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]],
) -> SnowflakePlan:
if isinstance(logical_plan, SnowflakePlan):
return logical_plan
Expand Down Expand Up @@ -988,9 +993,10 @@ def do_resolve_with_resolved_children(
if logical_plan.format_type_options
else {}
)
format_name = logical_plan.cur_options.get("FORMAT_NAME")
format_name = (logical_plan.cur_options or {}).get("FORMAT_NAME")
if format_name is not None:
format_type_options["FORMAT_NAME"] = format_name
assert logical_plan.file_format is not None
return self.plan_builder.copy_into_table(
path=logical_plan.file_path,
table_name=logical_plan.table_name,
Expand Down Expand Up @@ -1041,7 +1047,7 @@ def do_resolve_with_resolved_children(
)
if logical_plan.condition
else None,
resolved_children.get(logical_plan.source_data, None),
logical_plan.source_data,
logical_plan,
)

Expand All @@ -1053,14 +1059,14 @@ def do_resolve_with_resolved_children(
)
if logical_plan.condition
else None,
resolved_children.get(logical_plan.source_data, None),
logical_plan.source_data,
logical_plan,
)

if isinstance(logical_plan, TableMerge):
return self.plan_builder.merge(
logical_plan.table_name,
resolved_children.get(logical_plan.source),
logical_plan.source,
self.analyze(
logical_plan.join_expr, df_aliased_col_name_to_real_col_name
),
Expand All @@ -1073,3 +1079,7 @@ def do_resolve_with_resolved_children(

if isinstance(logical_plan, Selectable):
return self.plan_builder.select_statement(logical_plan)

raise TypeError(
f"Cannot resolve type logical_plan of {type(logical_plan).__name__} to a SnowflakePlan"
)
12 changes: 6 additions & 6 deletions src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import math
import sys
import typing
from typing import Any, Dict, List, Optional, Tuple, Union

from snowflake.snowpark._internal.analyzer.binary_plan_node import (
Expand Down Expand Up @@ -450,7 +450,7 @@ def range_statement(start: int, end: int, step: int, column_name: str) -> str:
if range * step < 0:
count = 0
else:
count = range / step + (1 if range % step != 0 and range * step > 0 else 0)
count = math.ceil(range / step)

return project_statement(
[
Expand Down Expand Up @@ -854,7 +854,7 @@ def drop_file_format_if_exists_statement(format_name: str) -> str:


def select_from_path_with_format_statement(
project: List[str], path: str, format_name: str, pattern: str
project: List[str], path: str, format_name: str, pattern: Optional[str]
) -> str:
select_statement = (
SELECT + (STAR if not project else COMMA.join(project)) + FROM + path
Expand Down Expand Up @@ -909,7 +909,7 @@ def window_frame_boundary_expression(offset: str, is_following: bool) -> str:


def rank_related_function_expression(
func_name: str, expr: str, offset: int, default: str, ignore_nulls: bool
func_name: str, expr: str, offset: int, default: Optional[str], ignore_nulls: bool
) -> str:
return (
func_name
Expand Down Expand Up @@ -1034,7 +1034,7 @@ def copy_into_table(
file_format_type: str,
format_type_options: Dict[str, Any],
copy_options: Dict[str, Any],
pattern: str,
pattern: Optional[str],
*,
files: Optional[str] = None,
validation_mode: Optional[str] = None,
Expand Down Expand Up @@ -1363,7 +1363,7 @@ def string(length: Optional[int] = None) -> str:


def get_file_format_spec(
file_format_type: str, format_type_options: typing.Dict[str, Any]
file_format_type: str, format_type_options: Dict[str, Any]
) -> str:
file_format_name = format_type_options.get("FORMAT_NAME")
file_format_str = FILE_FORMAT + EQUALS + LEFT_PARENTHESIS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from typing import Optional, Set
from typing import AbstractSet, Optional

from snowflake.snowpark._internal.analyzer.expression import (
Expression,
Expand All @@ -23,7 +23,7 @@ def __init__(self, left: Expression, right: Expression) -> None:
def __str__(self):
return f"{self.left} {self.sql_operator} {self.right}"

def dependent_column_names(self) -> Optional[Set[str]]:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.left, self.right)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def to_sql(value: Any, datatype: DataType, from_values_statement: bool = False)
return f"TIME('{trimmed_ms}')"

if isinstance(value, (list, bytes, bytearray)) and isinstance(datatype, BinaryType):
return f"'{binascii.hexlify(value).decode()}' :: BINARY"
return f"'{binascii.hexlify(bytes(value)).decode()}' :: BINARY"

if isinstance(value, (list, tuple, array)) and isinstance(datatype, ArrayType):
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: ARRAY"
Expand Down
Loading

0 comments on commit a78e3ec

Please sign in to comment.