From e0ff34d33d2f603d59b8cc4e7538485491c11eb5 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 21 May 2024 17:21:22 -0700 Subject: [PATCH 01/37] Calculate query complexity --- .../snowpark/_internal/analyzer/expression.py | 14 ++++++ .../_internal/analyzer/select_statement.py | 50 +++++++++++++++++++ .../_internal/analyzer/snowflake_plan.py | 11 ++++ 3 files changed, 75 insertions(+) diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 6084f358312..915e62d4f34 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -3,6 +3,7 @@ # import copy +from functools import cached_property import uuid from typing import TYPE_CHECKING, AbstractSet, Any, List, Optional, Tuple @@ -80,6 +81,19 @@ def sql(self) -> str: ) return f"{self.pretty_name}({children_sql})" + @cached_property + def total_children_count(self) -> int: + count = 0 + current_layer = [self] + while current_layer: + next_layer = [] + for expression in current_layer: + count += 1 + if expression.children: + next_layer.extend(expression.children) + current_layer = next_layer + return count + def __str__(self) -> str: return self.pretty_name diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 833009648b6..cc27773df90 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -6,6 +6,7 @@ from collections import UserDict, defaultdict from copy import copy, deepcopy from enum import Enum +from functools import cached_property from typing import ( TYPE_CHECKING, AbstractSet, @@ -290,6 +291,25 @@ def plan_height(self) -> int: def num_duplicate_nodes(self) -> int: return self.snowflake_plan.num_duplicate_nodes + @property + def individual_query_complexity(self) -> int: + """This is the query complexity estimate added by this Selectable node + to the overall query plan. For default case, it is the number of active + columns. Specific cases are handled in child classes with additional + explanation. + """ + return len(self.column_states.active_columns) + + @cached_property + def subtree_query_complexity(self) -> int: + """This is sum of individual query complexity estimates for all nodes + within a query plan subtree. 
+ """ + estimate = self.individual_query_complexity + for child in self.children_plan_nodes: + estimate += child.subtree_query_complexity + return estimate + @property def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: """ @@ -348,6 +368,11 @@ def sql_in_subquery(self) -> str: def schema_query(self) -> str: return self.sql_query + @property + def individual_query_complexity(self) -> int: + # select * from entity has 1 char '*' to represent columns + return 1 + @property def query_params(self) -> Optional[Sequence[Any]]: return None @@ -403,6 +428,16 @@ def query_params(self) -> Optional[Sequence[Any]]: def schema_query(self) -> str: return self._schema_query + @property + def individual_query_complexity(self): + if self.pre_actions: + # having pre-actions implies we have a non-select query followed by a + # select * from table(result_scan) statement + return 1 + + # no pre-action implies the best estimate we have is of # active columns + return len(self.column_states.active_columns) + def to_subqueryable(self) -> "SelectSQL": """Convert this SelectSQL to a new one that can be used as a subquery. Refer to __init__.""" if self.convert_to_select or is_sql_select_statement(self._sql_query): @@ -914,6 +949,16 @@ def limit(self, n: int, *, offset: int = 0) -> "SelectStatement": new.post_actions = new.from_.post_actions return new + @property + def individual_query_complexity(self) -> int: + # projection component + estimate = 1 if self.projection is None else len(self.projection) + # order by component + estimate += 0 if self.order_by is None else len(self.order_by) + # filter component + estimate += 0 if self.where is None else self.where.total_children_count + return estimate + class SelectTableFunction(Selectable): """Wrap table function related plan to a subclass of Selectable.""" @@ -1039,6 +1084,11 @@ def query_params(self) -> Optional[Sequence[Any]]: def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return self._nodes + @property + def individual_query_complexity(self) -> int: + # we add #set_operands - 1 additional operators in sql query + return len(self.set_operands) - 1 + class DeriveColumnDependencyError(Exception): """When deriving column dependencies from the subquery.""" diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 9ea2d64dd61..8134a247d11 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -350,6 +350,17 @@ def plan_height(self) -> int: def num_duplicate_nodes(self) -> int: return len(find_duplicate_subtrees(self)) + @cached_property + def individual_query_complexity(self) -> int: + return len(self.output) + + @cached_property + def subtree_query_complexity(self) -> int: + estimate = self.individual_query_complexity + for child in self.children_plan_nodes: + estimate += child.subtree_query_complexity + return estimate + def __copy__(self) -> "SnowflakePlan": if self.session._cte_optimization_enabled: return SnowflakePlan( From 7f714e3581d8a08fa49f8dd34ce3afa2867f9371 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 22 May 2024 15:46:20 -0700 Subject: [PATCH 02/37] make subtree computation iterative; telemetry --- .../snowpark/_internal/analyzer/cte_utils.py | 12 +++++++++++ .../_internal/analyzer/select_statement.py | 20 ++++++++++++------- .../_internal/analyzer/snowflake_plan.py | 15 +++++++++----- src/snowflake/snowpark/_internal/telemetry.py 
| 4 ++++
 4 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py
index 4c4c50e899f..cd502dbef15 100644
--- a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py
+++ b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py
@@ -182,3 +182,15 @@ def encode_id(
     except Exception as ex:
         logging.warning(f"Encode SnowflakePlan ID failed: {ex}")
         return None
+
+
+def compute_subtree_query_complexity(node: "TreeNode") -> int:
+    current_level = [node]
+    estimate = 0
+    while current_level:
+        next_level = []
+        for current_node in current_level:
+            estimate += current_node.individual_query_complexity
+            next_level.extend(current_node.children_plan_nodes)
+        current_level = next_level
+    return estimate
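The helper above replaces the recursive cached_property walk from the previous commit with a single iterative level-order traversal, so deep plan trees no longer pay Python recursion overhead (or risk the recursion limit) per node. Below is a minimal self-contained sketch of the same accumulation pattern; ToyNode is hypothetical and only stands in for any plan node that exposes the two attributes the helper reads.

from typing import List, Optional


class ToyNode:
    def __init__(self, cost: int, children: Optional[List["ToyNode"]] = None) -> None:
        self.individual_query_complexity = cost
        self.children_plan_nodes = children or []


def subtree_cost(root: ToyNode) -> int:
    total, current_level = 0, [root]
    while current_level:
        next_level = []
        for current_node in current_level:
            total += current_node.individual_query_complexity  # count this node once
            next_level.extend(current_node.children_plan_nodes)  # queue its children
        current_level = next_level
    return total


# root(2) over children costing 1 and 3 yields a subtree estimate of 6
assert subtree_cost(ToyNode(2, [ToyNode(1), ToyNode(3)])) == 6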
""" - estimate = self.individual_query_complexity - for child in self.children_plan_nodes: - estimate += child.subtree_query_complexity - return estimate + if self._subtree_query_complexity is None: + self._subtree_query_complexity = compute_subtree_query_complexity(self) + return self._subtree_query_complexity + + @subtree_query_complexity.setter + def subtree_query_complexity(self, value: int): + self._subtree_query_complexity = value @property def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 8134a247d11..7f34ae4fa83 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -77,6 +77,7 @@ SetOperation, ) from snowflake.snowpark._internal.analyzer.cte_utils import ( + compute_subtree_query_complexity, create_cte_query, encode_id, find_duplicate_subtrees, @@ -232,6 +233,7 @@ def __init__( self.placeholder_query = placeholder_query # encode an id for CTE optimization self._id = encode_id(queries[-1].sql, queries[-1].params) + self._subtree_query_complexity = None def __eq__(self, other: "SnowflakePlan") -> bool: if self._id is not None and other._id is not None: @@ -354,12 +356,15 @@ def num_duplicate_nodes(self) -> int: def individual_query_complexity(self) -> int: return len(self.output) - @cached_property + @property def subtree_query_complexity(self) -> int: - estimate = self.individual_query_complexity - for child in self.children_plan_nodes: - estimate += child.subtree_query_complexity - return estimate + if self._subtree_query_complexity is None: + self._subtree_query_complexity = compute_subtree_query_complexity(self) + return self._subtree_query_complexity + + @subtree_query_complexity.setter + def subtree_query_complexity(self, value: int): + self._subtree_query_complexity = value def __copy__(self) -> "SnowflakePlan": if self.session._cte_optimization_enabled: diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 5ce349c82eb..c5a313d2707 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -68,6 +68,7 @@ class TelemetryField(Enum): # dataframe query stats QUERY_PLAN_HEIGHT = "query_plan_height" QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes" + QUERY_PLAN_COMPLEXITY_ESTIMATE = "query_plan_complexity_estimate" # These DataFrame APIs call other DataFrame APIs @@ -160,6 +161,9 @@ def wrap(*args, **kwargs): api_calls[0][ TelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value ] = plan.num_duplicate_nodes + api_calls[0][ + TelemetryField.QUERY_PLAN_COMPLEXITY_ESTIMATE + ] = plan.subtree_query_complexity except Exception: pass args[0]._session._conn._telemetry_client.send_function_usage_telemetry( From b279769e8837715ed4bfc3de61278c38faf243d5 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 28 May 2024 21:28:05 -0700 Subject: [PATCH 03/37] compute complexity for expressions --- .../_internal/analyzer/binary_plan_node.py | 14 +++ .../snowpark/_internal/analyzer/expression.py | 99 ++++++++++++++++--- .../_internal/analyzer/grouping_set.py | 5 + .../_internal/analyzer/select_statement.py | 35 ++++--- .../_internal/analyzer/snowflake_plan.py | 4 - .../_internal/analyzer/snowflake_plan_node.py | 71 +++++++++++++ .../_internal/analyzer/sort_expression.py | 8 ++ .../_internal/analyzer/table_function.py | 68 
+++++++++++++ .../analyzer/table_merge_expression.py | 56 +++++++++++ .../_internal/analyzer/unary_plan_node.py | 77 ++++++++++++++- .../_internal/analyzer/window_expression.py | 39 ++++++++ 11 files changed, 444 insertions(+), 32 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index 46212fc6113..4c56f53fcf7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -187,3 +187,17 @@ def __init__( @property def sql(self) -> str: return self.join_type.sql + + @property + def individual_query_complexity(self) -> int: + # SELECT * FROM (left) AS left_alias join_type_sql JOIN (right) AS right_alias match_cond, using_cond, join_cond + estimate = 3 + if isinstance(self.join_type, UsingJoin): + estimate += 1 + len(self.join_type.using_columns) + estimate += ( + self.join_condition.expression_complexity if self.join_condition else 0 + ) + estimate += ( + self.match_condition.expression_complexity if self.match_condition else 0 + ) + return estimate diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 915e62d4f34..e00e47de1c1 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -3,8 +3,8 @@ # import copy -from functools import cached_property import uuid +from functools import cached_property from typing import TYPE_CHECKING, AbstractSet, Any, List, Optional, Tuple import snowflake.snowpark._internal.utils @@ -82,17 +82,8 @@ def sql(self) -> str: return f"{self.pretty_name}({children_sql})" @cached_property - def total_children_count(self) -> int: - count = 0 - current_layer = [self] - while current_layer: - next_layer = [] - for expression in current_layer: - count += 1 - if expression.children: - next_layer.extend(expression.children) - current_layer = next_layer - return count + def expression_complexity(self) -> int: + return 1 + sum(expr.expression_complexity for expr in (self.children or [])) def __str__(self) -> str: return self.pretty_name @@ -113,6 +104,10 @@ def __copy__(self): new._expr_id = None # type: ignore return new + @property + def expression_complexity(self) -> int: + return 1 + class ScalarSubquery(Expression): def __init__(self, plan: "SnowflakePlan") -> None: @@ -122,6 +117,11 @@ def __init__(self, plan: "SnowflakePlan") -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR + @cached_property + def expression_complexity(self) -> int: + # get plan complexity + return self.plan.subtree_query_complexity + class MultipleExpression(Expression): def __init__(self, expressions: List[Expression]) -> None: @@ -131,6 +131,10 @@ def __init__(self, expressions: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) + @cached_property + def expression_complexity(self) -> int: + return sum(expr.expression_complexity for expr in self.expressions) + class InExpression(Expression): def __init__(self, columns: Expression, values: List[Expression]) -> None: @@ -141,6 +145,12 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.columns, *self.values) + @cached_property + def 
expression_complexity(self) -> int: + return self.columns.expression_complexity + sum( + expr.expression_complexity for expr in self.values + ) + class Attribute(Expression, NamedExpression): def __init__(self, name: str, datatype: DataType, nullable: bool = True) -> None: @@ -169,6 +179,10 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return {self.name} + @property + def expression_complexity(self) -> int: + return 1 + class Star(Expression): def __init__( @@ -181,6 +195,10 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) + @cached_property + def expression_complexity(self) -> int: + return max(1, sum(expr.expression_complexity for expr in self.expressions)) + class UnresolvedAttribute(Expression, NamedExpression): def __init__( @@ -286,6 +304,10 @@ def sql(self) -> str: def __str__(self) -> str: return self.sql + @cached_property + def expression_complexity(self) -> int: + return len(self.values_dict) + class Like(Expression): def __init__(self, expr: Expression, pattern: Expression) -> None: @@ -296,6 +318,10 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + @cached_property + def expression_complexity(self) -> int: + return self.expr.expression_complexity + self.pattern.expression_complexity + class RegExp(Expression): def __init__(self, expr: Expression, pattern: Expression) -> None: @@ -306,6 +332,10 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + @cached_property + def expression_complexity(self) -> int: + return self.expr.expression_complexity + self.pattern.expression_complexity + class Collate(Expression): def __init__(self, expr: Expression, collation_spec: str) -> None: @@ -316,6 +346,10 @@ def __init__(self, expr: Expression, collation_spec: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + @cached_property + def expression_complexity(self) -> int: + return self.expr.expression_complexity + 1 + class SubfieldString(Expression): def __init__(self, expr: Expression, field: str) -> None: @@ -326,6 +360,10 @@ def __init__(self, expr: Expression, field: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + @cached_property + def expression_complexity(self) -> int: + return self.expr.expression_complexity + 1 + class SubfieldInt(Expression): def __init__(self, expr: Expression, field: int) -> None: @@ -336,6 +374,10 @@ def __init__(self, expr: Expression, field: int) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + @cached_property + def expression_complexity(self) -> int: + return self.expr.expression_complexity + 1 + class FunctionExpression(Expression): def __init__( @@ -368,6 +410,12 @@ def sql(self) -> str: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + @cached_property + def expression_complexity(self) -> int: + estimate = sum(expr.expression_complexity for expr in self.children) + estimate += 1 if self.is_distinct else 0 + return estimate + class WithinGroup(Expression): def __init__(self, expr: 
Expression, order_by_cols: List[Expression]) -> None:
@@ -379,6 +427,12 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None:
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return derive_dependent_columns(self.expr, *self.order_by_cols)
 
+    @cached_property
+    def expression_complexity(self) -> int:
+        return self.expr.expression_complexity + sum(
+            expr.expression_complexity for expr in self.order_by_cols
+        )
+
 
 class CaseWhen(Expression):
     def __init__(
@@ -398,6 +452,17 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
             exps.append(self.else_value)
         return derive_dependent_columns(*exps)
 
+    @cached_property
+    def expression_complexity(self) -> int:
+        estimate = sum(
+            (
+                condition.expression_complexity + value.expression_complexity
+                for condition, value in self.branches
+            )
+        )
+        estimate += self.else_value.expression_complexity if self.else_value else 0
+        return estimate
+
 
 class SnowflakeUDF(Expression):
     def __init__(
@@ -418,6 +483,10 @@ def __init__(
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return derive_dependent_columns(*self.children)
 
+    @cached_property
+    def expression_complexity(self) -> int:
+        return 1 + sum(expr.expression_complexity for expr in self.children)
+
 
 class ListAgg(Expression):
     def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None:
@@ -428,3 +497,9 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None:
 
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return derive_dependent_columns(self.col)
+
+    @cached_property
+    def expression_complexity(self) -> int:
+        estimate = self.col.expression_complexity + 1
+        estimate += 1 if self.is_distinct else 0
+        return estimate
diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py
index 653fbde1ca3..1f323563fec 100644
--- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py
+++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py
@@ -2,6 +2,7 @@
 # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
 #
 
+from functools import cached_property
 from typing import AbstractSet, List, Optional
 
 from snowflake.snowpark._internal.analyzer.expression import (
@@ -36,3 +37,7 @@ def __init__(self, args: List[List[Expression]]) -> None:
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         flattened_args = [exp for sublist in self.args for exp in sublist]
         return derive_dependent_columns(*flattened_args)
+
+    @cached_property
+    def expression_complexity(self) -> int:
+        return sum(sum(expr.expression_complexity for expr in arg) for arg in self.args)
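All of the overrides above lean on the default rule defined on Expression earlier in this patch: a node counts itself plus the sum of its children, with small constants added for extras such as DISTINCT or a delimiter. A standalone sketch of that composition on a hypothetical tuple-encoded tree, with the arithmetic worked out by hand:

def complexity(node) -> int:
    # mirror of the default rule: 1 for this node plus its children's totals
    _name, children = node
    return 1 + sum(complexity(child) for child in children)


col_a, one = ("a", []), ("1", [])
col_b, two = ("b", []), ("2", [])
condition = ("AND", [(">", [col_a, one]), (">", [col_b, two])])
# 4 leaves + 2 comparisons + 1 AND
assert complexity(condition) == 7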
""" - return len(self.column_states.active_columns) + return ( + self.snowflake_plan.source_plan.individual_query_complexity + if self.snowflake_plan.source_plan + else len(self.column_states.active_columns) + ) @property def subtree_query_complexity(self) -> int: @@ -376,7 +380,7 @@ def schema_query(self) -> str: @property def individual_query_complexity(self) -> int: - # select * from entity has 1 char '*' to represent columns + # SELECT * FROM entity return 1 @property @@ -438,7 +442,7 @@ def schema_query(self) -> str: def individual_query_complexity(self): if self.pre_actions: # having pre-actions implies we have a non-select query followed by a - # select * from table(result_scan) statement + # SELECT * FROM table(result_scan(query_id)) statement return 1 # no pre-action implies the best estimate we have is of # active columns @@ -699,6 +703,19 @@ def schema_query(self) -> str: def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return [self.from_] + @property + def individual_query_complexity(self) -> int: + # projection component + estimate = sum(expr.expression_complexity for expr in self.projection) if self.projection else 0 + # order by component - add complexity for each sort expression but remove len(order_by) - 1 since we only + # include "ORDER BY" once in sql test + estimate += sum(expr.expression_complexity for expr in self.order_by) - (len(self.order_by) - 1) if self.order_by else 0 + # filter component - add +1 for WHERE clause and sum of expression complexity for where expression + estimate += (1 + self.where.expression_complexity) if self.where else 0 + # limit component + estimate += 1 if self.limit_ else 0 + return estimate + def to_subqueryable(self) -> "Selectable": """When this SelectStatement's subquery is not subqueryable (can't be used in `from` clause of the sql), convert it to subqueryable and create a new SelectStatement with from_ being the new subqueryable怂 @@ -955,16 +972,6 @@ def limit(self, n: int, *, offset: int = 0) -> "SelectStatement": new.post_actions = new.from_.post_actions return new - @property - def individual_query_complexity(self) -> int: - # projection component - estimate = 1 if self.projection is None else len(self.projection) - # order by component - estimate += 0 if self.order_by is None else len(self.order_by) - # filter component - estimate += 0 if self.where is None else self.where.total_children_count - return estimate - class SelectTableFunction(Selectable): """Wrap table function related plan to a subclass of Selectable.""" diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 7f34ae4fa83..ba810e8024b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -352,10 +352,6 @@ def plan_height(self) -> int: def num_duplicate_nodes(self) -> int: return len(find_duplicate_subtrees(self)) - @cached_property - def individual_query_complexity(self) -> int: - return len(self.output) - @property def subtree_query_complexity(self) -> int: if self._subtree_query_complexity is None: diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index ee23f391af9..1bc5c056b5a 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -24,6 +24,10 @@ class LogicalPlan: def 
diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
index ee23f391af9..1bc5c056b5a 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
@@ -24,6 +24,10 @@ class LogicalPlan:
     def __init__(self) -> None:
         self.children = []
 
+    @property
+    def individual_query_complexity(self) -> int:
+        return 1
+
 
 class LeafNode(LogicalPlan):
     pass
@@ -39,6 +43,11 @@ def __init__(self, start: int, end: int, step: int, num_slices: int = 1) -> None:
         self.step = step
         self.num_slices = num_slices
 
+    @property
+    def individual_query_complexity(self) -> int:
+        # SELECT ( ROW_NUMBER() OVER ( ORDER BY SEQ8() ) - 1 ) * (step) + (start) AS id FROM ( TABLE (GENERATOR(ROWCOUNT => count)))
+        return 6
+
 
 class UnresolvedRelation(LeafNode):
     def __init__(self, name: str) -> None:
@@ -58,6 +67,13 @@ def __init__(
         self.data = data
         self.schema_query = schema_query
 
+    @property
+    def individual_query_complexity(self) -> int:
+        # select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm)
+        # (n+1) * m
+        # TODO: use ARRAY_BIND_THRESHOLD
+        return (len(self.data) + 1) * len(self.output)
+
 
 class SaveMode(Enum):
     APPEND = "append"
@@ -88,6 +104,17 @@ def __init__(
         self.clustering_exprs = clustering_exprs or []
         self.comment = comment
 
+    @property
+    def individual_query_complexity(self) -> int:
+        estimate = 1  # mode is always present
+        # column estimate
+        estimate += len(self.column_names) if self.column_names else 0
+        # clustering exprs
+        estimate += sum(expr.expression_complexity for expr in self.clustering_exprs)
+        # comment estimate
+        estimate += 1 if self.comment else 0
+        return estimate
+
 
 class Limit(LogicalPlan):
     def __init__(
@@ -99,6 +126,14 @@ def __init__(
         self.child = child
         self.children.append(child)
 
+    @property
+    def individual_query_complexity(self) -> int:
+        # for limit and offset
+        return (
+            self.limit_expr.expression_complexity
+            + self.offset_expr.expression_complexity
+        )
+
 
 class CopyIntoTableNode(LeafNode):
     def __init__(
@@ -133,6 +168,24 @@ def __init__(
         self.cur_options = cur_options
         self.create_table_from_infer_schema = create_table_from_infer_schema
 
+    @property
+    def individual_query_complexity(self) -> int:
+        # for columns
+        estimate = len(self.column_names) if self.column_names else 0
+        # for transformations
+        estimate += (
+            sum(expr.expression_complexity for expr in self.transformations)
+            if self.transformations
+            else 0
+        )
+        # for pattern
+        estimate += 1 if self.pattern else 0
+        # for files
+        estimate += len(self.files) if self.files else 0
+        # for copy options
+        estimate += len(self.copy_options) if self.copy_options else 0
+        return estimate
+
 
 class CopyIntoLocationNode(LogicalPlan):
     def __init__(
@@ -157,3 +210,21 @@ def __init__(
         self.file_format_name = file_format_name
         self.file_format_type = file_format_type
         self.copy_options = copy_options
+
+    @property
+    def individual_query_complexity(self) -> int:
+        # for stage location
+        estimate = 1
+        # for partition
+        estimate += self.partition_by.expression_complexity if self.partition_by else 0
+        # for file format name
+        estimate += 1 if self.file_format_name else 0
+        # for file format type
+        estimate += 1 if self.file_format_type else 0
+        # for file format options
+        estimate += len(self.format_type_options) if self.format_type_options else 0
+        # for copy options
+        estimate += len(self.copy_options)
+        # for header
+        estimate += 1 if self.header else 0
+        return estimate
diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py
index 1d06f7290a0..e0c39f6e626 100644
--- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py
@@ -2,6 +2,7 @@
 # Copyright (c) 2012-2024 Snowflake 
Computing Inc. All rights reserved. # +from functools import cached_property from typing import AbstractSet, Optional, Type from snowflake.snowpark._internal.analyzer.expression import ( @@ -55,3 +56,10 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + + @cached_property + def expression_complexity(self) -> int: + # ORDER BY child [null ordering] + estimate = self.child.expression_complexity + 1 + estimate += 1 if self.null_ordering else 0 + return estimate diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index 2c6381ed345..b65644e8ed1 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -3,6 +3,7 @@ # import sys +from functools import cached_property from typing import Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import Expression @@ -30,6 +31,23 @@ def __init__( self.partition_spec = partition_spec self.order_spec = order_spec + @cached_property + def expression_complexity(self) -> int: + if not self.over: + return 0 + + estimate = ( + sum(expr.expression_complexity for expr in self.partition_spec) + if self.partition_spec + else 0 + ) + estimate += ( + sum(expr.expression_complexity for expr in self.order_spec) + if self.order_spec + else 0 + ) + return estimate + class TableFunctionExpression(Expression): def __init__( @@ -45,6 +63,12 @@ def __init__( self.aliases = aliases self.api_call_source = api_call_source + @cached_property + def expression_complexity(self) -> int: + return ( + 1 + self.partition_spec.expression_complexity if self.partition_spec else 0 + ) + class FlattenFunction(TableFunctionExpression): def __init__( @@ -57,6 +81,10 @@ def __init__( self.recursive = recursive self.mode = mode + @cached_property + def expression_complexity(self) -> int: + return self.input.expression_complexity + 4 + class PosArgumentsTableFunction(TableFunctionExpression): def __init__( @@ -68,6 +96,14 @@ def __init__( super().__init__(func_name, partition_spec) self.args = args + @cached_property + def expression_complexity(self) -> int: + estimate = 1 + sum(expr.expression_complexity for expr in self.args) + estimate += ( + self.partition_spec.expression_complexity if self.partition_spec else 0 + ) + return estimate + class NamedArgumentsTableFunction(TableFunctionExpression): def __init__( @@ -79,6 +115,14 @@ def __init__( super().__init__(func_name, partition_spec) self.args = args + @cached_property + def expression_complexity(self) -> int: + estimate = 1 + sum((1 + arg.expression_complexity) for arg in self.args.values()) + estimate += ( + self.partition_spec.expression_complexity if self.partition_spec else 0 + ) + return estimate + class GeneratorTableFunction(TableFunctionExpression): def __init__(self, args: Dict[str, Expression], operators: List[str]) -> None: @@ -86,12 +130,24 @@ def __init__(self, args: Dict[str, Expression], operators: List[str]) -> None: self.args = args self.operators = operators + @cached_property + def expression_complexity(self) -> int: + return ( + 1 + + sum(1 + arg.expression_complexity for arg in self.args.values()) + + len(self.operators) + ) + class TableFunctionRelation(LogicalPlan): def __init__(self, table_function: TableFunctionExpression) -> None: super().__init__() self.table_function = table_function + @property + def individual_query_complexity(self) -> int: + return 
self.table_function.expression_complexity + class TableFunctionJoin(LogicalPlan): def __init__( @@ -107,6 +163,14 @@ def __init__( self.left_cols = left_cols if left_cols is not None else ["*"] self.right_cols = right_cols if right_cols is not None else ["*"] + @property + def individual_query_complexity(self) -> int: + return ( + self.table_function.expression_complexity + + len(self.left_cols) + + len(self.right_cols) + ) + class Lateral(LogicalPlan): def __init__( @@ -115,3 +179,7 @@ def __init__( super().__init__() self.children = [child] self.table_function = table_function + + @property + def individual_query_complexity(self) -> int: + return 1 + self.table_function.expression_complexity diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index 2d2554e43a1..baf3b24905b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +from functools import cached_property from typing import Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import Expression @@ -16,6 +17,13 @@ def __init__(self, condition: Optional[Expression]) -> None: super().__init__() self.condition = condition + @cached_property + def expression_complexity(self) -> int: + # WHEN MATCHED [AND condition] THEN DEL + estimate = 4 + estimate += self.condition.expression_complexity if self.condition else 0 + return estimate + class UpdateMergeExpression(MergeExpression): def __init__( @@ -24,6 +32,17 @@ def __init__( super().__init__(condition) self.assignments = assignments + @cached_property + def expression_complexity(self) -> int: + # WHEN MATCHED [AND condition] THEN UPDATE SET COMMA.join(k=v for k,v in assignments) + estimate = 4 + estimate += 1 if self.condition else 0 + estimate += sum( + key_expr.expression_complexity + val_expr.expression_complexity + for key_expr, val_expr in self.assignments.items() + ) + return estimate + class DeleteMergeExpression(MergeExpression): pass @@ -40,6 +59,15 @@ def __init__( self.keys = keys self.values = values + @cached_property + def expression_complexity(self) -> int: + # WHEN NOT MATCHED [AND cond] THEN INSERT [(COMMA.join(key))] VALUES (COMMA.join(values)) + estimate = 5 + estimate += sum(expr.expression_complexity for expr in self.keys) + estimate += sum(expr.expression_complexity for expr in self.values) + estimate += self.condition.expression_complexity if self.condition else 0 + return estimate + class TableUpdate(LogicalPlan): def __init__( @@ -56,6 +84,18 @@ def __init__( self.source_data = source_data self.children = [source_data] if source_data else [] + @cached_property + def individual_query_complexity(self) -> int: + # UPDATE table_name SET COMMA.join(k, v in assignments) [source_data] [WHERE condition] + estimate = 2 + estimate += sum( + key_expr.expression_complexity + val_expr.expression_complexity + for key_expr, val_expr in self.assignments.items() + ) + estimate += self.condition.expression_complexity if self.condition else 0 + # note that source data will be handled by subtree aggregator since it is added as a child + return estimate + class TableDelete(LogicalPlan): def __init__( @@ -70,6 +110,13 @@ def __init__( self.source_data = source_data self.children = [source_data] if source_data else [] + @cached_property + def individual_query_complexity(self) 
-> int: + # DELETE FROM table_name [USING source_data] [WHERE condition] + estimate = 2 + estimate += self.condition.expression_complexity if self.condition else 0 + return estimate + class TableMerge(LogicalPlan): def __init__( @@ -85,3 +132,12 @@ def __init__( self.join_expr = join_expr self.clauses = clauses self.children = [source] if source else [] + + @cached_property + def individual_query_complexity(self) -> int: + # MERGE INTO table_name USING (source) ON join_expr clauses + return ( + 4 + + self.join_expr.expression_complexity + + sum(expr.expression_complexity for expr in self.clauses) + ) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index fddf3caa8b3..29977dff97f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -4,7 +4,11 @@ from typing import Dict, List, Optional, Union -from snowflake.snowpark._internal.analyzer.expression import Expression, NamedExpression +from snowflake.snowpark._internal.analyzer.expression import ( + Expression, + NamedExpression, + ScalarSubquery, +) from snowflake.snowpark._internal.analyzer.snowflake_plan import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -29,12 +33,23 @@ def __init__( self.row_count = row_count self.seed = seed + @property + def individual_query_complexity(self) -> int: + # child SAMPLE (probability) -- if probability is provided + # child SAMPLE (row_count ROWS) -- if not probability but row count is provided + return 2 + 1 if self.row_count else 0 + class Sort(UnaryNode): def __init__(self, order: List[SortOrder], child: LogicalPlan) -> None: super().__init__(child) self.order = order + @property + def individual_query_complexity(self) -> int: + # child ORDER BY COMMA.join(order) + return 1 + sum(expr.expression_complexity for expr in self.order) + class Aggregate(UnaryNode): def __init__( @@ -47,13 +62,19 @@ def __init__( self.grouping_expressions = grouping_expressions self.aggregate_expressions = aggregate_expressions + @property + def individual_query_complexity(self) -> int: + return sum( + expr.expression_complexity for expr in self.grouping_expressions + ) + sum(expr.expression_complexity for expr in self.aggregate_expressions) + class Pivot(UnaryNode): def __init__( self, grouping_columns: List[Expression], pivot_column: Expression, - pivot_values: Optional[Union[List[Expression], LogicalPlan]], + pivot_values: Optional[Union[List[Expression], ScalarSubquery]], aggregates: List[Expression], default_on_null: Optional[Expression], child: LogicalPlan, @@ -65,6 +86,28 @@ def __init__( self.aggregates = aggregates self.default_on_null = default_on_null + @property + def individual_query_complexity(self) -> int: + estimate = sum(expr.expression_complexity for expr in self.grouping_columns) + estimate += self.pivot_column.expression_complexity + if isinstance(self.pivot_values, ScalarSubquery): + estimate += self.pivot_values.expression_complexity + elif isinstance(self.pivot_values, List): + estimate += sum(expr.expression_complexity for expr in self.pivot_values) + else: + # when pivot values is None + estimate += 1 + + if len(self.aggregates) > 0: + estimate += self.aggregates[0].expression_complexity + + estimate += ( + self.default_on_null.expression_complexity + 1 + if self.default_on_null + else 0 + ) + return estimate + class Unpivot(UnaryNode): def __init__( @@ -79,6 +122,10 @@ def __init__( 
self.name_column = name_column self.column_list = column_list + @property + def individual_query_complexity(self) -> int: + return 2 + sum(expr.expression_complexity for expr in self.column_list) + class Rename(UnaryNode): def __init__( @@ -89,18 +136,30 @@ def __init__( super().__init__(child) self.column_map = column_map + @property + def individual_query_complexity(self) -> int: + return 2 * len(self.column_map) + class Filter(UnaryNode): def __init__(self, condition: Expression, child: LogicalPlan) -> None: super().__init__(child) self.condition = condition + @property + def individual_query_complexity(self) -> int: + return self.condition.expression_complexity + class Project(UnaryNode): def __init__(self, project_list: List[NamedExpression], child: LogicalPlan) -> None: super().__init__(child) self.project_list = project_list + @property + def individual_query_complexity(self) -> int: + return sum(expr.expression_complexity for expr in self.project_list) + class ViewType: def __str__(self): @@ -128,6 +187,13 @@ def __init__( self.view_type = view_type self.comment = comment + @property + def individual_query_complexity(self) -> int: + estimate = 3 + estimate += 1 if isinstance(self.view_type, LocalTempView) else 0 + estimate += 1 if self.comment else 0 + return estimate + class CreateDynamicTableCommand(UnaryNode): def __init__( @@ -143,3 +209,10 @@ def __init__( self.warehouse = warehouse self.lag = lag self.comment = comment + + @property + def individual_query_complexity(self) -> int: + # CREATE OR REPLACE DYNAMIC TABLE name LAG = lag WAREHOUSE = wh [comment] AS child + estimate = 7 + estimate += 1 if self.comment else 0 + return estimate diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 452919f3313..056272c08af 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
#
 
+from functools import cached_property
 from typing import AbstractSet, List, Optional
 
 from snowflake.snowpark._internal.analyzer.expression import (
@@ -63,6 +64,11 @@ def __init__(
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return derive_dependent_columns(self.lower, self.upper)
 
+    @cached_property
+    def expression_complexity(self) -> int:
+        # frame_type BETWEEN lower AND upper
+        return 2 + self.lower.expression_complexity + self.upper.expression_complexity
+
 
 class WindowSpecDefinition(Expression):
     def __init__(
@@ -81,6 +87,22 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
             *self.partition_spec, *self.order_spec, self.frame_spec
         )
 
+    @cached_property
+    def expression_complexity(self) -> int:
+        # PARTITION BY COMMA.join(exprs) ORDER BY frame_spec
+        estimate = self.frame_spec.expression_complexity
+        estimate += (
+            (1 + sum(expr.expression_complexity for expr in self.partition_spec))
+            if self.partition_spec
+            else 0
+        )
+        estimate += (
+            (1 + sum(expr.expression_complexity for expr in self.order_spec))
+            if self.order_spec
+            else 0
+        )
+        return estimate
+
 
 class WindowExpression(Expression):
     def __init__(
@@ -93,6 +115,14 @@ def __init__(
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return derive_dependent_columns(self.window_function, self.window_spec)
 
+    @cached_property
+    def expression_complexity(self) -> int:
+        # window_function OVER ( window_spec )
+        return (
+            self.window_function.expression_complexity
+            + self.window_spec.expression_complexity
+        )
+
 
 class RankRelatedFunctionExpression(Expression):
     sql: str
@@ -113,6 +143,15 @@ def __init__(
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return derive_dependent_columns(self.expr, self.default)
 
+    @cached_property
+    def expression_complexity(self) -> int:
+        # func_name (expr [, offset] [, default]) [IGNORE NULLS]
+        estimate = 1 + self.expr.expression_complexity
+        estimate += 1 if self.offset else 0
+        estimate += self.default.expression_complexity if self.default else 0
+        estimate += 1 if self.ignore_nulls else 0
+        return estimate
+
 
 class Lag(RankRelatedFunctionExpression):
     sql = "LAG"

From 4582bb2c1aeb910d1d9849023c19f45ef92edae3 Mon Sep 17 00:00:00 2001
From: Afroz Alam
Date: Wed, 29 May 2024 09:13:38 -0700
Subject: [PATCH 04/37] add tests

---
 .../_internal/analyzer/select_statement.py    |  15 +-
 .../_internal/analyzer/table_function.py      |   8 +-
 .../_internal/analyzer/unary_plan_node.py     |   9 +-
 tests/integ/test_materialization_suite.py     | 285 ++++++++++++++++++
 4 files changed, 309 insertions(+), 8 deletions(-)
 create mode 100644 tests/integ/test_materialization_suite.py

diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
index 0daa5bd4c8f..e92ebe6644c 100644
--- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py
+++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -706,10 +706,21 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]:
     @property
     def individual_query_complexity(self) -> int:
         # projection component
-        estimate = sum(expr.expression_complexity for expr in self.projection) if self.projection else 0
+        estimate = (
+            sum(expr.expression_complexity for expr in self.projection)
+            if self.projection
+            else 0
+        )
         # order by component - add complexity for each sort expression but remove len(order_by) - 1 since we only
         # include "ORDER BY" once in the SQL text
-        estimate += sum(expr.expression_complexity for expr in 
self.order_by) - (len(self.order_by) - 1) if self.order_by else 0 + estimate += ( + ( + sum(expr.expression_complexity for expr in self.order_by) + - (len(self.order_by) - 1) + ) + if self.order_by + else 0 + ) # filter component - add +1 for WHERE clause and sum of expression complexity for where expression estimate += (1 + self.where.expression_complexity) if self.where else 0 # limit component diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index b65644e8ed1..0cac7e62a87 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -65,8 +65,8 @@ def __init__( @cached_property def expression_complexity(self) -> int: - return ( - 1 + self.partition_spec.expression_complexity if self.partition_spec else 0 + return 1 + ( + self.partition_spec.expression_complexity if self.partition_spec else 0 ) @@ -117,7 +117,9 @@ def __init__( @cached_property def expression_complexity(self) -> int: - estimate = 1 + sum((1 + arg.expression_complexity) for arg in self.args.values()) + estimate = 1 + sum( + (1 + arg.expression_complexity) for arg in self.args.values() + ) estimate += ( self.partition_spec.expression_complexity if self.partition_spec else 0 ) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 29977dff97f..c4681afd695 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -37,7 +37,7 @@ def __init__( def individual_query_complexity(self) -> int: # child SAMPLE (probability) -- if probability is provided # child SAMPLE (row_count ROWS) -- if not probability but row count is provided - return 2 + 1 if self.row_count else 0 + return 2 + (1 if self.row_count else 0) class Sort(UnaryNode): @@ -88,7 +88,9 @@ def __init__( @property def individual_query_complexity(self) -> int: - estimate = sum(expr.expression_complexity for expr in self.grouping_columns) + # SELECT * FROM (child) PIVOT (aggregate FOR pivot_col in values) + estimate = 3 + estimate += sum(expr.expression_complexity for expr in self.grouping_columns) estimate += self.pivot_column.expression_complexity if isinstance(self.pivot_values, ScalarSubquery): estimate += self.pivot_values.expression_complexity @@ -124,7 +126,8 @@ def __init__( @property def individual_query_complexity(self) -> int: - return 2 + sum(expr.expression_complexity for expr in self.column_list) + # SELECT * FROM (child) UNPIVOT (value_column FOR name_column IN (COMMA.join(column_list))) + return 4 + sum(expr.expression_complexity for expr in self.column_list) class Rename(UnaryNode): diff --git a/tests/integ/test_materialization_suite.py b/tests/integ/test_materialization_suite.py new file mode 100644 index 00000000000..f53f0d51e28 --- /dev/null +++ b/tests/integ/test_materialization_suite.py @@ -0,0 +1,285 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
+# + +import pytest + +from snowflake.snowpark._internal.analyzer.select_statement import ( + SET_EXCEPT, + SET_INTERSECT, + SET_UNION, + SET_UNION_ALL, +) +from snowflake.snowpark.dataframe import DataFrame +from snowflake.snowpark.functions import avg, col, lit, seq1, table_function, uniform +from snowflake.snowpark.session import Session +from snowflake.snowpark.window import Window +from tests.utils import Utils + + +@pytest.fixture(autouse=True) +def setup(session): + is_simplifier_enabled = session._sql_simplifier_enabled + session._sql_simplifier_enabled = True + yield + session._sql_simplifier_enabled = is_simplifier_enabled + + +@pytest.fixture(scope="module") +def sample_table(session): + table_name = Utils.random_table_name() + Utils.create_table( + session, table_name, "a int, b int, c int, d int", is_temporary=True + ) + session._run_query( + f"insert into {table_name}(a, b, c, d) values " "(1, 2, 3, 4), (5, 6, 7, 8)" + ) + yield table_name + Utils.drop_table(session, table_name) + + +def get_subtree_query_complexity(df: DataFrame) -> int: + return df._select_statement.subtree_query_complexity + + +def assert_df_subtree_query_complexity(df: DataFrame, estimate: int): + assert ( + get_subtree_query_complexity(df) == estimate + ), f"query = {df.queries['queries'][-1]}" + + +def test_create_dataframe_from_values(session: Session): + df1 = session.create_dataframe([[1], [2], [3]], schema=["a"]) + # SELECT "A" FROM ( SELECT $1 AS "A" FROM VALUES (1 :: INT), (2 :: INT), (3 :: INT)) + assert_df_subtree_query_complexity(df1, 5) + + df2 = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", "b"]) + # SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT), (5 :: INT, 6 :: INT)) + assert_df_subtree_query_complexity(df2, 10) + + +def test_session_table(session: Session, sample_table: str): + df = session.table(sample_table) + # select * from sample_table + assert_df_subtree_query_complexity(df, 1) + + +def test_range_statement(session: Session): + df = session.range(1, 5, 2) + # SELECT ( ROW_NUMBER() OVER ( ORDER BY SEQ8() ) - 1 ) * (2) + (1) AS id FROM ( TABLE (GENERATOR(ROWCOUNT => 2))) + assert_df_subtree_query_complexity(df, 6) + + +def test_generator_table_function(session: Session): + df1 = session.generator( + seq1(1).as_("seq"), uniform(1, 10, 2).as_("uniform"), rowcount=150 + ) + assert_df_subtree_query_complexity(df1, 5) + + df2 = df1.order_by("seq") + # adds SELECT * from () ORDER BY seq ASC NULLS FIRST + assert_df_subtree_query_complexity( + df2, df1._select_statement.subtree_query_complexity + 3 + ) + + +def test_join_table_function(session: Session): + df1 = session.sql( + "select 'James' as name, 'address1 address2 address3' as addresses" + ) + # SelectSQL chooses num active columns as the best estimate + # assert_df_subtree_query_complexity(df1, 2) + + split_to_table = table_function("split_to_table") + df2 = df1.select(split_to_table(col("addresses"), lit(" "))) + # +3 SELECT "SEQ", "INDEX", "VALUE" FROM ( + # +3 SELECT T_RIGHT."SEQ", T_RIGHT."INDEX", T_RIGHT."VALUE" FROM + # +2 (select 'James' as name, 'address1 address2 address3' as addresses) AS T_LEFT + # +3 JOIN TABLE (split_to_table("ADDRESSES", ' ') ) AS T_RIGHT) + assert_df_subtree_query_complexity(df2, 11) + + +@pytest.mark.parametrize( + "set_operator", [SET_UNION, SET_UNION_ALL, SET_EXCEPT, SET_INTERSECT] +) +def test_set_operators(session: Session, sample_table: str, set_operator: str): + df1 = session.table(sample_table) + df2 = 
session.table(sample_table) + if set_operator == SET_UNION: + df = df1.union(df2) + elif set_operator == SET_UNION_ALL: + df = df1.union_all(df2) + elif set_operator == SET_EXCEPT: + df = df1.except_(df2) + else: + df = df1.intersect(df2) + + # ( SELECT * FROM SNOWPARK_TEMP_TABLE_9DJO2Y35IT) set_operator ( SELECT * FROM SNOWPARK_TEMP_TABLE_9DJO2Y35IT) + assert_df_subtree_query_complexity(df, 3) + + +def test_agg(session: Session, sample_table: str): + df = session.table(sample_table) + df1 = df.agg(avg("a")) + df2 = df.agg(avg("a") + 1) + df3 = df.agg(avg("a"), avg("b" + lit(1)).as_("avg_b")) + + # SELECT avg("A") AS "AVG(A)" FROM ( SELECT * FROM sample_table) LIMIT 1 + assert_df_subtree_query_complexity(df1, 3) + # SELECT (avg("A") + 1 :: INT) AS "ADD(AVG(A), LITERAL())" FROM ( SELECT * FROM sample_table) LIMIT 1 + assert_df_subtree_query_complexity(df2, 5) + # SELECT avg("A") AS "AVG(A)", avg(('b' + 1 :: INT)) AS "AVG_B" FROM ( SELECT * FROM sample_table) LIMIT 1 + assert_df_subtree_query_complexity(df3, 7) + + +def test_window_function(session: Session): + window1 = ( + Window.partition_by("value").order_by("key").rows_between(Window.CURRENT_ROW, 2) + ) + window2 = Window.order_by(col("key").desc()).range_between( + Window.UNBOUNDED_PRECEDING, Window.UNBOUNDED_FOLLOWING + ) + df = session.create_dataframe( + [(1, "1"), (2, "2"), (1, "3"), (2, "4")], schema=["key", "value"] + ) + + df1 = df.select(avg("value").over(window1).as_("window1")) + # SELECT avg("VALUE") OVER (PARTITION BY "VALUE" ORDER BY "KEY" ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING ) AS "WINDOW1" FROM ( base_df) + assert_df_subtree_query_complexity(df1, get_subtree_query_complexity(df) + 10) + + # SELECT avg("VALUE") OVER ( ORDER BY "KEY" DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS "WINDOW2" FROM ( base df) + df2 = df1.select(avg("value").over(window2).as_("window2")) + assert_df_subtree_query_complexity(df2, get_subtree_query_complexity(df1) + 10) + + +def test_join_statement(session: Session, sample_table: str): + df1 = session.table(sample_table) + df2 = session.create_dataframe([[1, 2, 5], [3, 4, 9]], schema=["a", "b", "e"]) + base_complexity_sum = ( + df1._plan.subtree_query_complexity + df2._plan.subtree_query_complexity + ) + + df3 = df1.join(df2) + # SELECT * FROM (df1 AS SNOWPARK_LEFT INNER JOIN df2 AS SNOWPARK_RIGHT) + assert_df_subtree_query_complexity(df3, base_complexity_sum + 5) + + df4 = df1.join(df2, on=((df1["a"] == df2["a"]) & (df1["b"] == df2["b"]))) + # SELECT * FROM (df1 AS SNOWPARK_LEFT INNER JOIN df2 AS SNOWPARK_RIGHT ON (("l_xxhj_A" = "r_d4df_A") AND ("l_xxhj_B" = "r_d4df_B"))) + assert_df_subtree_query_complexity(df4, base_complexity_sum + 11) + + df5 = df1.join(df2, using_columns=["a", "b"]) + # SELECT * FROM (df1 AS SNOWPARK_LEFT INNER JOIN df2 AS SNOWPARK_RIGHT USING (a, b)) + assert_df_subtree_query_complexity(df5, base_complexity_sum + 7) + + +def test_pivot_and_unpivot(session: Session): + try: + session.sql( + """create or replace temp table monthly_sales(empid int, amount int, month text) + as select * from values + (1, 10000, 'JAN'), + (1, 400, 'JAN'), + (2, 4500, 'JAN'), + (2, 35000, 'JAN'), + (1, 5000, 'FEB'), + (1, 3000, 'FEB'), + (2, 200, 'FEB')""" + ).collect() + + df_pivot1 = ( + session.table("monthly_sales").pivot("month", ["JAN", "FEB"]).sum("amount") + ) + # SELECT * FROM ( SELECT * FROM monthly_sales) PIVOT (sum("AMOUNT") FOR "MONTH" IN ('JAN', 'FEB')) + assert_df_subtree_query_complexity(df_pivot1, 8) + + df_pivot2 = ( + 
session.table("monthly_sales") + .pivot("month", ["JAN", "FEB", "MARCH"]) + .sum("amount") + ) + # SELECT * FROM ( SELECT * FROM monthly_sales) PIVOT (sum("AMOUNT") FOR "MONTH" IN ('JAN', 'FEB', 'MARCH')) + assert_df_subtree_query_complexity(df_pivot2, 9) + + session.sql( + """create or replace temp table sales_for_month(empid int, dept varchar, jan int, feb int) + as select * from values + (1, 'electronics', 100, 200), + (2, 'clothes', 100, 300)""" + ).collect() + df_unpivot1 = session.table("sales_for_month").unpivot( + "sales", "month", ["jan", "feb"] + ) + # SELECT * FROM ( SELECT * FROM (sales_for_month)) UNPIVOT (sales FOR month IN ("JAN", "FEB")) + assert_df_subtree_query_complexity(df_unpivot1, 7) + finally: + Utils.drop_table(session, "monthly_sales") + Utils.drop_table(session, "sales_for_month") + + +def test_sample(session: Session, sample_table): + df = session.table(sample_table) + df_sample_frac = df.sample(0.5) + # SELECT * FROM ( SELECT * FROM (sample_table)) SAMPLE (50.0) + assert_df_subtree_query_complexity(df_sample_frac, 3) + + df_sample_rows = df.sample(n=1) + # SELECT * FROM ( SELECT * FROM (sample_table)) SAMPLE (1 ROWS) + assert_df_subtree_query_complexity(df_sample_rows, 4) + + +@pytest.mark.parametrize("source_from_table", [True, False]) +def test_select_statement_subtree_complexity_estimate( + session: Session, sample_table: str, source_from_table: bool +): + if source_from_table: + df1 = session.table(sample_table) + else: + df1 = session.create_dataframe( + [[1, 2, 3, 4], [5, 6, 7, 8]], schema=["a", "b", "c", "d"] + ) + + assert_df_subtree_query_complexity(df1, 1 if source_from_table else 16) + + # add select + # +3 for column + df2 = df1.select("a", "b", "c") + assert_df_subtree_query_complexity(df2, 4 if source_from_table else 15) + + # +2 for column (1 less active column) + df3 = df2.select("b", "c") + assert_df_subtree_query_complexity(df3, 3 if source_from_table else 14) + + # add sort + # +3 for additional ORDER BY "B" ASC NULLS FIRST + df4 = df3.sort(col("b").asc()) + assert_df_subtree_query_complexity(df4, 3 + get_subtree_query_complexity(df3)) + + # +3 for additional ,"C" ASC NULLS FIRST + df5 = df4.sort(col("c").desc()) + assert_df_subtree_query_complexity(df5, 2 + get_subtree_query_complexity(df4)) + + # add filter + # +4 for WHERE ("B" > 2) + df6 = df5.filter(col("b") > 2) + assert_df_subtree_query_complexity(df6, 4 + get_subtree_query_complexity(df5)) + + # +4 for filter - AND ("C" > 3) + df7 = df6.filter(col("c") > 3) + assert_df_subtree_query_complexity(df7, 4 + get_subtree_query_complexity(df6)) + + # add set operations + # +2 for 2 unions, 12 for sum of individual df complexity + df8 = df3.union_all(df4).union_all(df5) + assert_df_subtree_query_complexity( + df8, 2 + sum(get_subtree_query_complexity(df) for df in [df3, df4, df5]) + ) + + # + 2 for 2 unions, 30 for sum ob individual df complexity + df9 = df8.union_all(df6).union_all(df7) + assert_df_subtree_query_complexity( + df9, 2 + sum(get_subtree_query_complexity(df) for df in [df6, df7, df8]) + ) + + # +1 for limit + df10 = df9.limit(2) + assert_df_subtree_query_complexity(df10, 1 + get_subtree_query_complexity(df9)) From 7abd06c81790c192d7f26cf6f9bc195b056edfd6 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 29 May 2024 17:19:15 -0700 Subject: [PATCH 05/37] tests passing --- .../snowpark/_internal/analyzer/analyzer.py | 2 +- .../_internal/analyzer/binary_plan_node.py | 5 +++- .../_internal/analyzer/snowflake_plan.py | 10 +++++-- .../_internal/analyzer/snowflake_plan_node.py | 7 
++++- .../_internal/analyzer/unary_expression.py | 12 ++++++++ .../_internal/analyzer/unary_plan_node.py | 12 ++++++-- tests/integ/test_materialization_suite.py | 30 +++++++++++-------- 7 files changed, 57 insertions(+), 21 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 0ec1b7b1f97..d3a376fa501 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -927,7 +927,7 @@ def do_resolve_with_resolved_children( ) if isinstance(logical_plan, UnresolvedRelation): - return self.plan_builder.table(logical_plan.name) + return self.plan_builder.table(logical_plan.name, logical_plan) if isinstance(logical_plan, SnowflakeCreateTable): return self.plan_builder.save_as_table( diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index 4c56f53fcf7..f1380871728 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -192,7 +192,10 @@ def sql(self) -> str: def individual_query_complexity(self) -> int: # SELECT * FROM (left) AS left_alias join_type_sql JOIN (right) AS right_alias match_cond, using_cond, join_cond estimate = 3 - if isinstance(self.join_type, UsingJoin): + if ( + isinstance(self.join_type, UsingJoin) + and len(self.join_type.using_columns) > 0 + ): estimate += 1 + len(self.join_type.using_columns) estimate += ( self.join_condition.expression_complexity if self.join_condition else 0 diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index ba810e8024b..d9685db4f65 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -352,6 +352,12 @@ def plan_height(self) -> int: def num_duplicate_nodes(self) -> int: return len(find_duplicate_subtrees(self)) + @property + def individual_query_complexity(self) -> int: + if self.source_plan: + return self.source_plan.individual_query_complexity + return 0 + @property def subtree_query_complexity(self) -> int: if self._subtree_query_complexity is None: @@ -593,8 +599,8 @@ def large_local_relation_plan( source_plan=source_plan, ) - def table(self, table_name: str) -> SnowflakePlan: - return self.query(project_statement([], table_name), None) + def table(self, table_name: str, source_plan: LogicalPlan) -> SnowflakePlan: + return self.query(project_statement([], table_name), source_plan) def file_operation_plan( self, command: str, file_name: str, stage_location: str, options: Dict[str, str] diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 1bc5c056b5a..9fdaded6c08 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -26,7 +26,7 @@ def __init__(self) -> None: @property def individual_query_complexity(self) -> int: - return 1 + return 0 class LeafNode(LogicalPlan): @@ -54,6 +54,11 @@ def __init__(self, name: str) -> None: super().__init__() self.name = name + @property + def individual_query_complexity(self) -> int: + # SELECT * FROM name + return 1 + class SnowflakeValues(LeafNode): def __init__( diff --git 
a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index e8f5ebcd2c1..db8117c1ed9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -33,6 +33,10 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + @property + def expression_complexity(self) -> int: + return sum(expr.expression_complexity for expr in self.children) + class Cast(UnaryExpression): sql_operator = "CAST" @@ -80,6 +84,10 @@ def __init__(self, child: Expression, name: str) -> None: def __str__(self): return f"{self.child} {self.sql_operator} {self.name}" + @property + def expression_complexity(self) -> int: + return sum(expr.expression_complexity for expr in self.children) + class UnresolvedAlias(UnaryExpression, NamedExpression): sql_operator = "AS" @@ -88,3 +96,7 @@ class UnresolvedAlias(UnaryExpression, NamedExpression): def __init__(self, child: Expression) -> None: super().__init__(child) self.name = child.sql + + @property + def expression_complexity(self) -> int: + return sum(expr.expression_complexity for expr in self.children) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index c4681afd695..5034e7a55ab 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -64,9 +64,15 @@ def __init__( @property def individual_query_complexity(self) -> int: - return sum( - expr.expression_complexity for expr in self.grouping_expressions - ) + sum(expr.expression_complexity for expr in self.aggregate_expressions) + # grouping estimate + estimate = max( + 1, sum(expr.expression_complexity for expr in self.grouping_expressions) + ) + # aggregate estimate + estimate += sum( + expr.expression_complexity for expr in self.aggregate_expressions + ) + return estimate class Pivot(UnaryNode): diff --git a/tests/integ/test_materialization_suite.py b/tests/integ/test_materialization_suite.py index f53f0d51e28..ba6dbd97811 100644 --- a/tests/integ/test_materialization_suite.py +++ b/tests/integ/test_materialization_suite.py @@ -39,7 +39,7 @@ def sample_table(session): def get_subtree_query_complexity(df: DataFrame) -> int: - return df._select_statement.subtree_query_complexity + return df._plan.subtree_query_complexity def assert_df_subtree_query_complexity(df: DataFrame, estimate: int): @@ -129,7 +129,7 @@ def test_agg(session: Session, sample_table: str): # SELECT (avg("A") + 1 :: INT) AS "ADD(AVG(A), LITERAL())" FROM ( SELECT * FROM sample_table) LIMIT 1 assert_df_subtree_query_complexity(df2, 5) # SELECT avg("A") AS "AVG(A)", avg(('b' + 1 :: INT)) AS "AVG_B" FROM ( SELECT * FROM sample_table) LIMIT 1 - assert_df_subtree_query_complexity(df3, 7) + assert_df_subtree_query_complexity(df3, 6) def test_window_function(session: Session): @@ -145,31 +145,35 @@ def test_window_function(session: Session): df1 = df.select(avg("value").over(window1).as_("window1")) # SELECT avg("VALUE") OVER (PARTITION BY "VALUE" ORDER BY "KEY" ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING ) AS "WINDOW1" FROM ( base_df) - assert_df_subtree_query_complexity(df1, get_subtree_query_complexity(df) + 10) + assert_df_subtree_query_complexity(df1, get_subtree_query_complexity(df) + 9) # SELECT avg("VALUE") OVER ( ORDER BY "KEY" DESC NULLS LAST 
RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS "WINDOW2" FROM ( base df) df2 = df1.select(avg("value").over(window2).as_("window2")) - assert_df_subtree_query_complexity(df2, get_subtree_query_complexity(df1) + 10) + assert_df_subtree_query_complexity(df2, get_subtree_query_complexity(df1) + 9) def test_join_statement(session: Session, sample_table: str): + # SELECT * FROM table df1 = session.table(sample_table) + assert_df_subtree_query_complexity(df1, 1) + # SELECT A, B, E FROM (SELECT $1 AS "A", $2 AS "B", $3 AS "E" FROM VALUES (1 :: INT, 2 :: INT, 5 :: INT), (3 :: INT, 4 :: INT, 9 :: INT)) df2 = session.create_dataframe([[1, 2, 5], [3, 4, 9]], schema=["a", "b", "e"]) - base_complexity_sum = ( - df1._plan.subtree_query_complexity + df2._plan.subtree_query_complexity - ) + assert_df_subtree_query_complexity(df2, 12) df3 = df1.join(df2) - # SELECT * FROM (df1 AS SNOWPARK_LEFT INNER JOIN df2 AS SNOWPARK_RIGHT) - assert_df_subtree_query_complexity(df3, base_complexity_sum + 5) + # +3 SELECT * FROM ((ch1) AS SNOWPARK_LEFT INNER JOIN (ch2) AS SNOWPARK_RIGHT) + # +4 ch1 = SELECT "A" AS "l_p8bm_A", "B" AS "l_p8bm_B", "C" AS "C", "D" AS "D" FROM (df1) + # +0 ch2 = SELECT "A" AS "r_2og4_A", "B" AS "r_2og4_B", "E" AS "E" FROM (SELECT $1 AS "A", $2 AS "B", $3 AS "E" FROM VALUES ()) + # ch2 is a re-write flattened version of df2 with aliases + assert_df_subtree_query_complexity(df3, 13 + 7) df4 = df1.join(df2, on=((df1["a"] == df2["a"]) & (df1["b"] == df2["b"]))) - # SELECT * FROM (df1 AS SNOWPARK_LEFT INNER JOIN df2 AS SNOWPARK_RIGHT ON (("l_xxhj_A" = "r_d4df_A") AND ("l_xxhj_B" = "r_d4df_B"))) - assert_df_subtree_query_complexity(df4, base_complexity_sum + 11) + # SELECT * FROM ((ch1) AS SNOWPARK_LEFT INNER JOIN ( ch2) AS SNOWPARK_RIGHT ON (("l_k7b8_A" = "r_e09m_A") AND ("l_k7b8_B" = "r_e09m_B"))) + assert_df_subtree_query_complexity(df4, get_subtree_query_complexity(df3) + 7) df5 = df1.join(df2, using_columns=["a", "b"]) - # SELECT * FROM (df1 AS SNOWPARK_LEFT INNER JOIN df2 AS SNOWPARK_RIGHT USING (a, b)) - assert_df_subtree_query_complexity(df5, base_complexity_sum + 7) + # SELECT * FROM ( (ch1) AS SNOWPARK_LEFT INNER JOIN (ch2) AS SNOWPARK_RIGHT USING (a, b)) + assert_df_subtree_query_complexity(df5, get_subtree_query_complexity(df3) + 3) def test_pivot_and_unpivot(session: Session): From 1e3a34c20f12e2d0a7a86dbef585666345245a82 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 30 May 2024 07:50:31 -0700 Subject: [PATCH 06/37] fix test --- src/snowflake/snowpark/_internal/telemetry.py | 2 +- tests/integ/scala/test_snowflake_plan_suite.py | 2 +- tests/integ/test_materialization_suite.py | 8 ++++++++ tests/integ/test_telemetry.py | 7 +++++++ 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index c5a313d2707..65ddb19b244 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -162,7 +162,7 @@ def wrap(*args, **kwargs): TelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value ] = plan.num_duplicate_nodes api_calls[0][ - TelemetryField.QUERY_PLAN_COMPLEXITY_ESTIMATE + TelemetryField.QUERY_PLAN_COMPLEXITY_ESTIMATE.value ] = plan.subtree_query_complexity except Exception: pass diff --git a/tests/integ/scala/test_snowflake_plan_suite.py b/tests/integ/scala/test_snowflake_plan_suite.py index e9d9e845fb3..35610cb66cb 100644 --- a/tests/integ/scala/test_snowflake_plan_suite.py +++ 
b/tests/integ/scala/test_snowflake_plan_suite.py @@ -53,7 +53,7 @@ def test_single_query(session): # build plan plans = session._plan_builder - table_plan = plans.table(table_name) + table_plan = plans.table(table_name, df._plan.source_plan) project = plans.project(["num"], table_plan, None) assert len(project.queries) == 1 diff --git a/tests/integ/test_materialization_suite.py b/tests/integ/test_materialization_suite.py index ba6dbd97811..762d83fba2d 100644 --- a/tests/integ/test_materialization_suite.py +++ b/tests/integ/test_materialization_suite.py @@ -16,6 +16,14 @@ from snowflake.snowpark.window import Window from tests.utils import Utils +pytestmark = [ + pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="Breaking down queries is done for SQL translation", + run=False, + ) +] + @pytest.fixture(autouse=True) def setup(session): diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py index 559ae98b15f..4c5737ae042 100644 --- a/tests/integ/test_telemetry.py +++ b/tests/integ/test_telemetry.py @@ -593,6 +593,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, + "query_plan_complexity_estimate": 14, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -606,6 +607,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, + "query_plan_complexity_estimate": 14, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -619,6 +621,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, + "query_plan_complexity_estimate": 14, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -632,6 +635,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, + "query_plan_complexity_estimate": 14, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -645,6 +649,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, + "query_plan_complexity_estimate": 14, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -781,6 +786,7 @@ def test_dataframe_stat_functions_api_calls(session): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": 4, "query_plan_num_duplicate_nodes": 0, + "query_plan_complexity_estimate": 54, }, { "name": "DataFrameStatFunctions.crosstab", @@ -798,6 +804,7 @@ def test_dataframe_stat_functions_api_calls(session): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": 4, "query_plan_num_duplicate_nodes": 0, + "query_plan_complexity_estimate": 54, } ] From 46d8b9f4a305adbe26adc81707b7cbfb7bc426aa Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 30 May 2024 09:15:21 -0700 Subject: [PATCH 07/37] fix typing issues --- src/snowflake/snowpark/_internal/analyzer/expression.py | 6 ++---- src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py | 2 +- 
.../snowpark/_internal/analyzer/snowflake_plan_node.py | 4 ++-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index e00e47de1c1..5ed60b103e4 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -454,13 +454,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @cached_property def expression_complexity(self) -> int: - estimate = sum[ - ( + estimate = sum( condition.expression_complexity + value.expression_complexity for condition, value in self.branches ) - ] - estimate += self.else_value if self.else_value else 0 + estimate += self.else_value.expression_complexity if self.else_value else 0 return estimate diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index d9685db4f65..2e3cd01e4ed 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -233,7 +233,7 @@ def __init__( self.placeholder_query = placeholder_query # encode an id for CTE optimization self._id = encode_id(queries[-1].sql, queries[-1].params) - self._subtree_query_complexity = None + self._subtree_query_complexity: Optional[int] = None def __eq__(self, other: "SnowflakePlan") -> bool: if self._id is not None and other._id is not None: diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 9fdaded6c08..a773523e2b8 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -113,7 +113,7 @@ def __init__( def individual_query_complexity(self) -> int: estimate = 1 # mode is always present # column estimate - estimate += 0 if self.column_names else len(self.column_names) + estimate += sum(1 for _ in self.column_names) if self.column_names else 0 # clustering exprs estimate += sum(expr.expression_complexity for expr in self.clustering_exprs) # comment estimate @@ -179,7 +179,7 @@ def individual_query_complexity(self) -> int: estimate = len(self.column_names) if self.column_names else 0 # for transformations estimate += ( - len(expr.expression_complexity for expr in self.transformations) + sum(expr.expression_complexity for expr in self.transformations) if self.transformations else 0 ) From 88c7e7c2bc22a984d564acdce6bab17a4ae1e5d7 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 4 Jun 2024 10:48:06 -0700 Subject: [PATCH 08/37] use new approach --- .../_internal/analyzer/binary_expression.py | 8 +- .../_internal/analyzer/binary_plan_node.py | 32 +- .../_internal/analyzer/complexity_stat.py | 24 ++ .../snowpark/_internal/analyzer/cte_utils.py | 12 - .../snowpark/_internal/analyzer/expression.py | 208 ++++++++--- .../_internal/analyzer/grouping_set.py | 22 +- .../_internal/analyzer/select_statement.py | 98 +++-- .../_internal/analyzer/snowflake_plan.py | 30 +- .../_internal/analyzer/snowflake_plan_node.py | 110 +++--- .../_internal/analyzer/sort_expression.py | 14 +- .../_internal/analyzer/table_function.py | 106 ++++-- .../analyzer/table_merge_expression.py | 85 +++-- .../_internal/analyzer/unary_expression.py | 17 +- .../_internal/analyzer/unary_plan_node.py | 141 +++++--- .../_internal/analyzer/window_expression.py | 89 ++++- 
.../snowpark/_internal/server_connection.py | 2 +- src/snowflake/snowpark/_internal/telemetry.py | 8 +- tests/integ/test_materialization_suite.py | 342 ++++++++++++++---- tests/integ/test_telemetry.py | 49 ++- 19 files changed, 976 insertions(+), 421 deletions(-) create mode 100644 src/snowflake/snowpark/_internal/analyzer/complexity_stat.py diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 1e197400e83..41c2fa32fcf 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -2,8 +2,10 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional +from collections import Counter +from typing import AbstractSet, Dict, Optional +from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, @@ -26,6 +28,10 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.left, self.right) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + class BinaryArithmeticExpression(BinaryExpression): pass diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index f1380871728..10d0e75502b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -2,8 +2,10 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
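The key design change in this patch is that complexity is no longer a single integer but a per-category tally, and collections.Counter is what makes those tallies composable. A minimal standalone sketch (not code from this patch) of the merging behavior the diffs below rely on:

from collections import Counter

filter_stat = Counter({"filter": 1, "column": 2})
join_stat = Counter({"join": 1, "column": 4})

# Counter addition merges key-wise, so subtree stats accumulate with plain `+`
assert filter_stat + join_stat == Counter({"column": 6, "filter": 1, "join": 1})

# sum() over Counters needs an explicit Counter() start value; the default
# start of 0 raises TypeError (int + Counter has no meaning). This is why the
# diffs pass a Counter, often the node's individual stat, as sum()'s second
# argument.
assert sum([filter_stat, join_stat], Counter()) == filter_stat + join_stat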
# -from typing import List, Optional +from collections import Counter +from typing import Dict, List, Optional +from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages @@ -69,7 +71,10 @@ def __init__(self, left: LogicalPlan, right: LogicalPlan) -> None: class SetOperation(BinaryNode): - pass + @property + def individual_complexity_stat(self) -> Dict[str, int]: + # (left) operator (right) + return Counter({ComplexityStat.SET_OPERATION.value: 1}) class Except(SetOperation): @@ -189,18 +194,21 @@ def sql(self) -> str: return self.join_type.sql @property - def individual_query_complexity(self) -> int: - # SELECT * FROM (left) AS left_alias join_type_sql JOIN (right) AS right_alias match_cond, using_cond, join_cond - estimate = 3 - if ( - isinstance(self.join_type, UsingJoin) - and len(self.join_type.using_columns) > 0 - ): - estimate += 1 + len(self.join_type.using_columns) + def individual_complexity_stat(self) -> Dict[str, int]: + # SELECT * FROM (left) AS left_alias join_type_sql JOIN (right) AS right_alias match_cond, using_cond, join_cond + estimate = Counter({ComplexityStat.JOIN.value: 1}) + if isinstance(self.join_type, UsingJoin) and self.join_type.using_columns: + estimate += Counter( + {ComplexityStat.COLUMN.value: len(self.join_type.using_columns)} + ) estimate += ( - self.join_condition.expression_complexity if self.join_condition else 0 + self.join_condition.cumulative_complexity_stat + if self.join_condition + else Counter() ) estimate += ( - self.match_condition.expression_complexity if self.match_condition else 0 + self.match_condition.cumulative_complexity_stat + if self.match_condition + else Counter() ) return estimate diff --git a/src/snowflake/snowpark/_internal/analyzer/complexity_stat.py b/src/snowflake/snowpark/_internal/analyzer/complexity_stat.py new file mode 100644 index 00000000000..975506e7c2a --- /dev/null +++ b/src/snowflake/snowpark/_internal/analyzer/complexity_stat.py @@ -0,0 +1,24 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
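The new complexity_stat module that begins here defines the category names used as Counter keys throughout the rest of the patch. A hedged sketch of the intended usage, where the two-member enum is an abridged stand-in for the real one:

from collections import Counter
from enum import Enum


class ComplexityStat(Enum):  # abridged: the real enum has many more members
    FILTER = "filter"
    COLUMN = "column"


# stats are keyed by the enum's string value, which keeps the dict made of
# plain strings and ints; presumably that matters for the telemetry payloads
# seen in the earlier patches
stat = Counter({ComplexityStat.FILTER.value: 1, ComplexityStat.COLUMN.value: 3})
assert stat["column"] == 3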
+# + +from enum import Enum + + +class ComplexityStat(Enum): + FILTER = "filter" + ORDER_BY = "order_by" + JOIN = "join" + SET_OPERATION = "set_operation" # UNION, EXCEPT, INTERSECT, UNION ALL + SAMPLE = "sample" + PIVOT = "pivot" + UNPIVOT = "unpivot" + WINDOW = "window" + GROUP_BY = "group_by" + PARTITION_BY = "partition_by" + CASE_WHEN = "case_when" + LITERAL = "literal" + COLUMN = "column" + FUNCTION = "function" + IN = "in" + LOW_IMPACT = "low_impact" diff --git a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py index cd502dbef15..4c4c50e899f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py @@ -182,15 +182,3 @@ def encode_id( except Exception as ex: logging.warning(f"Encode SnowflakePlan ID failed: {ex}") return None - - -def compute_subtree_query_complexity(node: "TreeNode") -> int: - current_level = [node] - estimate = 0 - while current_level: - next_level = [] - for node in current_level: - estimate += node.individual_query_complexity - next_level.extend(node.children_plan_nodes) - current_level = next_level - return estimate diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 5ed60b103e4..86ab229c6e1 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -4,10 +4,12 @@ import copy import uuid +from collections import Counter from functools import cached_property -from typing import TYPE_CHECKING, AbstractSet, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, AbstractSet, Any, Dict, List, Optional, Tuple import snowflake.snowpark._internal.utils +from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat if TYPE_CHECKING: from snowflake.snowpark._internal.analyzer.snowflake_plan import ( @@ -81,9 +83,17 @@ def sql(self) -> str: ) return f"{self.pretty_name}({children_sql})" + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({}) + @cached_property - def expression_complexity(self) -> int: - return 1 + sum(expr.expression_complexity for expr in (self.children or [])) + def cumulative_complexity_stat(self) -> Dict[str, int]: + children = self.children or [] + return sum( + (child.cumulative_complexity_stat for child in children), + self.individual_complexity_stat, + ) def __str__(self) -> str: return self.pretty_name @@ -104,10 +114,6 @@ def __copy__(self): new._expr_id = None # type: ignore return new - @property - def expression_complexity(self) -> int: - return 1 - class ScalarSubquery(Expression): def __init__(self, plan: "SnowflakePlan") -> None: @@ -118,9 +124,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR @cached_property - def expression_complexity(self) -> int: - # get plan complexity - return self.plan.subtree_query_complexity + def cumulative_complexity_stat(self) -> Dict[str, int]: + return self.plan.cumulative_complexity_stat + self.individual_complexity_stat class MultipleExpression(Expression): @@ -132,8 +137,14 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) @cached_property - def expression_complexity(self) -> int: - return sum(expr.expression_complexity for expr in self.expressions) + def cumulative_complexity_stat(self) -> Dict[str, int]: + return ( + sum( + 
(expr.cumulative_complexity_stat for expr in self.expressions), + Counter({}), + ) + + self.individual_complexity_stat + ) class InExpression(Expression): @@ -145,10 +156,19 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.columns, *self.values) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.IN.value: 1}) + @cached_property - def expression_complexity(self) -> int: - return self.columns.expression_complexity + sum( - expr.expression_complexity for expr in self.values + def cumulative_complexity_stat(self) -> Dict[str, int]: + return ( + self.columns.cumulative_complexity_stat + + self.individual_complexity_stat + + sum( + (expr.cumulative_complexity_stat for expr in self.values), + Counter({}), + ) ) @@ -180,8 +200,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return {self.name} @property - def expression_complexity(self) -> int: - return 1 + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.COLUMN.value: 1}) class Star(Expression): @@ -195,9 +215,19 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + if self.expressions: + return Counter({}) + # if there are no expressions, we assign column value = 1 to Star + return Counter({ComplexityStat.COLUMN.value: 1}) + @cached_property - def expression_complexity(self) -> int: - return max(1, sum(expr.expression_complexity for expr in self.expressions)) + def cumulative_complexity_stat(self) -> Dict[str, int]: + return self.individual_complexity_stat + sum( + (child.individual_complexity_stat for child in self.expressions), + Counter({}), + ) class UnresolvedAttribute(Expression, NamedExpression): @@ -233,6 +263,10 @@ def __hash__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return self._dependent_column_names + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.COLUMN.value: 1}) + class Literal(Expression): def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None: @@ -256,6 +290,10 @@ def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None: else: self.datatype = infer_type(value) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.LITERAL.value: 1}) + class Interval(Expression): def __init__( @@ -304,9 +342,14 @@ def sql(self) -> str: def __str__(self) -> str: return self.sql - @cached_property - def expression_complexity(self) -> int: - return len(self.values_dict) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter( + { + ComplexityStat.LITERAL.value: 2 * len(self.values_dict), + ComplexityStat.LOW_IMPACT.value: 1, + } + ) class Like(Expression): @@ -318,9 +361,18 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + # expr LIKE pattern + return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + @cached_property - def expression_complexity(self) -> int: - return self.expr.expression_complexity + self.pattern.expression_complexity + def 
cumulative_complexity_stat(self) -> Dict[str, int]: + return ( + self.expr.cumulative_complexity_stat + + self.pattern.cumulative_complexity_stat + + self.individual_complexity_stat + ) class RegExp(Expression): @@ -332,9 +384,18 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + # expr REG_EXP pattern + return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + @cached_property - def expression_complexity(self) -> int: - return self.expr.expression_complexity + self.pattern.expression_complexity + def cumulative_complexity_stat(self) -> Dict[str, int]: + return ( + self.expr.cumulative_complexity_stat + + self.pattern.cumulative_complexity_stat + + self.individual_complexity_stat + ) class Collate(Expression): @@ -346,9 +407,14 @@ def __init__(self, expr: Expression, collation_spec: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + # expr COLLATE collate_spec + return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + @cached_property - def expression_complexity(self) -> int: - return self.expr.expression_complexity + 1 + def cumulative_complexity_stat(self) -> Dict[str, int]: + return self.expr.cumulative_complexity_stat + self.individual_complexity_stat class SubfieldString(Expression): @@ -360,9 +426,15 @@ def __init__(self, expr: Expression, field: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + # the literal corresponds to the contribution from self.field + return Counter({ComplexityStat.LITERAL.value: 1}) + @cached_property - def expression_complexity(self) -> int: - return self.expr.expression_complexity + 1 + def cumulative_complexity_stat(self) -> Dict[str, int]: + # self.expr ( self.field ) + return self.expr.cumulative_complexity_stat + self.individual_complexity_stat class SubfieldInt(Expression): @@ -374,9 +446,15 @@ def __init__(self, expr: Expression, field: int) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + # the literal corresponds to the contribution from self.field + return Counter({ComplexityStat.LITERAL.value: 1}) + @cached_property - def expression_complexity(self) -> int: - return self.expr.expression_complexity + 1 + def cumulative_complexity_stat(self) -> Dict[str, int]: + # self.expr ( self.field ) + return self.expr.cumulative_complexity_stat + self.individual_complexity_stat class FunctionExpression(Expression): @@ -410,11 +488,9 @@ def sql(self) -> str: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) - @cached_property - def expression_complexity(self) -> int: - estimate = sum(expr.expression_complexity for expr in self.children) - estimate += 1 if self.is_distinct else 0 - return estimate + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.FUNCTION.value: 1}) class WithinGroup(Expression): @@ -427,10 +503,20 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None: def dependent_column_names(self) -> 
Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, *self.order_by_cols) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + # expr WITHIN GROUP (ORDER BY cols) + return Counter({ComplexityStat.ORDER_BY.value: 1}) + @cached_property - def expression_complexity(self) -> int: - return self.expr.expression_complexity + sum( - expr.expression_complexity for expr in self.order_by_cols + def cumulative_complexity_stat(self) -> Dict[str, int]: + return ( + sum( + (col.cumulative_complexity_stat for col in self.order_by_cols), + Counter({}), + ) + + self.individual_complexity_stat + + self.expr.cumulative_complexity_stat ) @@ -452,13 +538,24 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: exps.append(self.else_value) return derive_dependent_columns(*exps) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.CASE_WHEN.value: 1}) + @cached_property - def expression_complexity(self) -> int: - estimate = sum( - condition.expression_complexity + value.expression_complexity + def cumulative_complexity_stat(self) -> Dict[str, int]: + estimate = self.individual_complexity_stat + sum( + ( + condition.cumulative_complexity_stat + value.cumulative_complexity_stat for condition, value in self.branches - ) - estimate += self.else_value.expression_complexity if self.else_value else 0 + ), + Counter({}), + ) + estimate += ( + self.else_value.cumulative_complexity_stat + if self.else_value + else Counter({}) + ) return estimate @@ -481,9 +578,16 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.FUNCTION.value: 1}) + @cached_property - def expression_complexity(self) -> int: - return 1 + sum(expr.expression_complexity for expr in self.children) + def cumulative_complexity_stat(self) -> Dict[str, int]: + return sum( + (expr.cumulative_complexity_stat for expr in self.children), + self.individual_complexity_stat, + ) class ListAgg(Expression): @@ -496,8 +600,10 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.col) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.FUNCTION.value: 1}) + @cached_property - def expression_complexity(self) -> int: - estimate = self.col.expression_complexity + 1 - estimate += 1 if self.is_distinct else 0 - return estimate + def cumulative_complexity_stat(self) -> Dict[str, int]: + return self.col.cumulative_complexity_stat + self.individual_complexity_stat diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 1f323563fec..380e1ce9509 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -2,9 +2,11 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
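The expression diffs above all follow one contract: individual_complexity_stat reports a node's own contribution, and cumulative_complexity_stat folds in the children bottom-up. A self-contained toy version of that contract (class and category names here are illustrative, not the patch's):

from collections import Counter
from typing import Dict, List, Optional


class Node:
    def __init__(self, kind: str, children: Optional[List["Node"]] = None) -> None:
        self.kind = kind
        self.children = children or []

    @property
    def individual_complexity_stat(self) -> Dict[str, int]:
        return Counter({self.kind: 1})

    @property
    def cumulative_complexity_stat(self) -> Dict[str, int]:
        # the start value carries this node's own share; children add theirs
        return sum(
            (child.cumulative_complexity_stat for child in self.children),
            self.individual_complexity_stat,
        )


tree = Node("case_when", [Node("column"), Node("literal"), Node("column")])
assert tree.cumulative_complexity_stat == Counter(
    {"column": 2, "case_when": 1, "literal": 1}
)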
# +from collections import Counter from functools import cached_property -from typing import AbstractSet, List, Optional +from typing import AbstractSet, Dict, List, Optional +from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, @@ -20,6 +22,10 @@ def __init__(self, group_by_exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.group_by_exprs) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + class Cube(GroupingSet): pass @@ -39,5 +45,15 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*flattened_args) @cached_property - def expression_complexity(self) -> int: - return sum(sum(expr.expression_complexity for expr in arg) for arg in self.args) + def cumulative_complexity_stat(self) -> Dict[str, int]: + return sum( + ( + sum((expr.cumulative_complexity_stat for expr in arg), Counter()) + for arg in self.args + ), + self.individual_complexity_stat, + ) + + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.LOW_IMPACT.value: 1}) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index e92ebe6644c..5707369c4e2 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -3,7 +3,7 @@ # from abc import ABC, abstractmethod -from collections import UserDict, defaultdict +from collections import Counter, UserDict, defaultdict from copy import copy, deepcopy from enum import Enum from typing import ( @@ -20,10 +20,8 @@ ) import snowflake.snowpark._internal.utils -from snowflake.snowpark._internal.analyzer.cte_utils import ( - compute_subtree_query_complexity, - encode_id, -) +from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat +from snowflake.snowpark._internal.analyzer.cte_utils import encode_id from snowflake.snowpark._internal.analyzer.table_function import ( TableFunctionExpression, TableFunctionJoin, @@ -203,7 +201,7 @@ def __init__( str, Dict[str, str] ] = defaultdict(dict) self._api_calls = api_calls.copy() if api_calls is not None else None - self._subtree_query_complexity: Optional[int] = None + self._cumulative_complexity_stat: Optional[Dict[str, int]] = None def __eq__(self, other: "Selectable") -> bool: if self._id is not None and other._id is not None: @@ -295,30 +293,33 @@ def num_duplicate_nodes(self) -> int: return self.snowflake_plan.num_duplicate_nodes @property - def individual_query_complexity(self) -> int: + def individual_complexity_stat(self) -> Dict[str, int]: """This is the query complexity estimate added by this Selectable node to the overall query plan. For default case, it is the number of active columns. Specific cases are handled in child classes with additional explanation. 
""" - return ( - self.snowflake_plan.source_plan.individual_query_complexity - if self.snowflake_plan.source_plan - else len(self.column_states.active_columns) - ) + if isinstance(self.snowflake_plan.source_plan, Selectable): + return Counter( + {ComplexityStat.COLUMN.value: len(self.column_states.active_columns)} + ) + return self.snowflake_plan.source_plan.individual_complexity_stat @property - def subtree_query_complexity(self) -> int: + def cumulative_complexity_stat(self) -> Dict[str, int]: """This is sum of individual query complexity estimates for all nodes within a query plan subtree. """ - if self._subtree_query_complexity is None: - self._subtree_query_complexity = compute_subtree_query_complexity(self) - return self._subtree_query_complexity + if self._cumulative_complexity_stat is None: + estimate = self.individual_complexity_stat + for node in self.children_plan_nodes: + estimate += node.cumulative_complexity_stat + self._cumulative_complexity_stat = estimate + return self._cumulative_complexity_stat - @subtree_query_complexity.setter - def subtree_query_complexity(self, value: int): - self._subtree_query_complexity = value + @cumulative_complexity_stat.setter + def cumulative_complexity_stat(self, value: Dict[str, int]): + self._cumulative_complexity_stat = value @property def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: @@ -379,9 +380,9 @@ def schema_query(self) -> str: return self.sql_query @property - def individual_query_complexity(self) -> int: + def individual_complexity_stat(self) -> Dict[str, int]: # SELECT * FROM entity - return 1 + return Counter({ComplexityStat.COLUMN.value: 1}) @property def query_params(self) -> Optional[Sequence[Any]]: @@ -439,14 +440,16 @@ def schema_query(self) -> str: return self._schema_query @property - def individual_query_complexity(self): + def individual_complexity_stat(self) -> Dict[str, int]: if self.pre_actions: # having pre-actions implies we have a non-select query followed by a # SELECT * FROM table(result_scan(query_id)) statement - return 1 + return Counter({ComplexityStat.COLUMN.value: 1}) # no pre-action implies the best estimate we have is of # active columns - return len(self.column_states.active_columns) + return Counter( + {ComplexityStat.COLUMN.value: len(self.column_states.active_columns)} + ) def to_subqueryable(self) -> "SelectSQL": """Convert this SelectSQL to a new one that can be used as a subquery. 
Refer to __init__.""" @@ -704,27 +707,42 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return [self.from_] @property - def individual_query_complexity(self) -> int: + def individual_complexity_stat(self) -> Dict[str, int]: + estimate = Counter() # projection component - estimate = ( - sum(expr.expression_complexity for expr in self.projection) + estimate += ( + sum( + (expr.cumulative_complexity_stat for expr in self.projection), Counter() + ) if self.projection - else 0 + else Counter() + ) + + # filter component - add +1 for WHERE clause and sum of expression complexity for where expression + estimate += ( + Counter({ComplexityStat.FILTER.value: 1}) + + self.where.cumulative_complexity_stat + if self.where + else Counter() ) - # order by component - add complexity for each sort expression but remove len(order_by) - 1 since we only - # include "ORDER BY" once in sql test + + # order by component - add complexity for each sort expression estimate += ( - ( - sum(expr.expression_complexity for expr in self.order_by) - - (len(self.order_by) - 1) + sum( + (expr.cumulative_complexity_stat for expr in self.order_by), + Counter({ComplexityStat.ORDER_BY.value: 1}), ) if self.order_by - else 0 + else Counter() + ) + + # limit/offset component + estimate += ( + Counter({ComplexityStat.LOW_IMPACT.value: 1}) if self.limit_ else Counter() + ) + estimate += ( + Counter({ComplexityStat.LOW_IMPACT.value: 1}) if self.offset else Counter() ) - # filter component - add +1 for WHERE clause and sum of expression complexity for where expression - estimate += (1 + self.where.expression_complexity) if self.where else 0 - # limit component - estimate += 1 if self.limit_ else 0 return estimate def to_subqueryable(self) -> "Selectable": @@ -1109,9 +1127,9 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return self._nodes @property - def individual_query_complexity(self) -> int: + def individual_complexity_stat(self) -> Dict[str, int]: # we add #set_operands - 1 additional operators in sql query - return len(self.set_operands) - 1 + return Counter({ComplexityStat.SET_OPERATION.value: len(self.set_operands) - 1}) class DeriveColumnDependencyError(Exception): diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 2e36b22a4f8..ba5ada58054 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -6,7 +6,7 @@ import re import sys import uuid -from collections import defaultdict +from collections import Counter, defaultdict from functools import cached_property from typing import ( TYPE_CHECKING, @@ -77,7 +77,6 @@ SetOperation, ) from snowflake.snowpark._internal.analyzer.cte_utils import ( - compute_subtree_query_complexity, create_cte_query, encode_id, find_duplicate_subtrees, @@ -233,7 +232,7 @@ def __init__( self.placeholder_query = placeholder_query # encode an id for CTE optimization self._id = encode_id(queries[-1].sql, queries[-1].params) - self._subtree_query_complexity: Optional[int] = None + self._cumulative_complexity_stat: Optional[Dict[str, int]] = None def __eq__(self, other: "SnowflakePlan") -> bool: if self._id is not None and other._id is not None: @@ -353,20 +352,23 @@ def num_duplicate_nodes(self) -> int: return len(find_duplicate_subtrees(self)) @property - def individual_query_complexity(self) -> int: + def individual_complexity_stat(self) -> Dict[str, int]: if 
self.source_plan: - return self.source_plan.individual_query_complexity - return 0 + return self.source_plan.individual_complexity_stat + return Counter() @property - def subtree_query_complexity(self) -> int: - if self._subtree_query_complexity is None: - self._subtree_query_complexity = compute_subtree_query_complexity(self) - return self._subtree_query_complexity - - @subtree_query_complexity.setter - def subtree_query_complexity(self, value: int): - self._subtree_query_complexity = value + def cumulative_complexity_stat(self) -> Dict[str, int]: + if self._cumulative_complexity_stat is None: + estimate = self.individual_complexity_stat + for node in self.children_plan_nodes: + estimate += node.cumulative_complexity_stat + self._cumulative_complexity_stat = estimate + return self._cumulative_complexity_stat + + @cumulative_complexity_stat.setter + def cumulative_complexity_stat(self, value: Dict[str, int]): + self._cumulative_complexity_stat = value def __copy__(self) -> "SnowflakePlan": if self.session._cte_optimization_enabled: diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index a773523e2b8..2038f12a169 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -4,9 +4,11 @@ # import sys +from collections import Counter from enum import Enum from typing import Any, Dict, List, Optional +from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import Attribute, Expression from snowflake.snowpark.row import Row from snowflake.snowpark.types import StructType @@ -23,10 +25,25 @@ class LogicalPlan: def __init__(self) -> None: self.children = [] + self._cumulative_complexity_stat: Optional[Dict[str, int]] = None @property - def individual_query_complexity(self) -> int: - return 0 + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter() + + @property + def cumulative_complexity_stat(self) -> Dict[str, int]: + if self._cumulative_complexity_stat is None: + estimate = self.individual_complexity_stat + for node in self.children: + estimate += node.cumulative_complexity_stat + + self._cumulative_complexity_stat = estimate + return self._cumulative_complexity_stat + + @cumulative_complexity_stat.setter + def cumulative_complexity_stat(self, value: Dict[str, int]): + self._cumulative_complexity_stat = value class LeafNode(LogicalPlan): @@ -44,9 +61,17 @@ def __init__(self, start: int, end: int, step: int, num_slices: int = 1) -> None self.num_slices = num_slices @property - def individual_query_complexity(self) -> int: + def individual_complexity_stat(self) -> Dict[str, int]: # SELECT ( ROW_NUMBER() OVER ( ORDER BY SEQ8() ) - 1 ) * (step) + (start) AS id FROM ( TABLE (GENERATOR(ROWCOUNT => count))) - return 6 + return Counter( + { + ComplexityStat.WINDOW.value: 1, + ComplexityStat.ORDER_BY.value: 1, + ComplexityStat.LITERAL.value: 3, # step, start, count + ComplexityStat.COLUMN.value: 1, # id column + ComplexityStat.LOW_IMPACT.value: 2, # ROW_NUMBER, GENERATOR + } + ) class UnresolvedRelation(LeafNode): @@ -55,9 +80,9 @@ def __init__(self, name: str) -> None: self.name = name @property - def individual_query_complexity(self) -> int: + def individual_complexity_stat(self) -> Dict[str, int]: # SELECT * FROM name - return 1 + return Counter({ComplexityStat.COLUMN.value: 1}) class SnowflakeValues(LeafNode): @@ -73,11 
+98,15 @@ def __init__( self.schema_query = schema_query @property - def individual_query_complexity(self) -> int: + def individual_complexity_stat(self) -> Dict[str, int]: # select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm) - # (n+1) * m # TODO: use ARRAY_BIND_THRESHOLD - return (len(self.data) + 1) * len(self.output) + return Counter( + { + ComplexityStat.COLUMN.value: len(self.output), + ComplexityStat.LITERAL.value: len(self.data) * len(self.output), + } + ) class SaveMode(Enum): @@ -92,7 +121,7 @@ class SnowflakeCreateTable(LogicalPlan): def __init__( self, table_name: Iterable[str], - column_names: Optional[Iterable[str]], + column_names: Optional[List[str]], mode: SaveMode, query: Optional[LogicalPlan], table_type: str = "", @@ -110,14 +139,16 @@ def __init__( self.comment = comment @property - def individual_query_complexity(self) -> int: - estimate = 1 # mode is always present - # column estimate - estimate += sum(1 for _ in self.column_names) if self.column_names else 0 - # clustering exprs - estimate += sum(expr.expression_complexity for expr in self.clustering_exprs) - # comment estimate - estimate += 0 if self.comment else 1 + def individual_complexity_stat(self) -> Dict[str, int]: + # CREATE OR REPLACE table_type TABLE table_name (col definition) clustering_expr AS SELECT * FROM (child) + estimate = Counter( + {ComplexityStat.LOW_IMPACT.value: 1, ComplexityStat.COLUMN.value: 1} + ) + estimate += ( + sum(expr.cumulative_complexity_stat for expr in self.clustering_exprs) + if self.clustering_exprs + else Counter() + ) return estimate @@ -132,11 +163,12 @@ def __init__( self.children.append(child) @property - def individual_query_complexity(self) -> int: + def individual_complexity_stat(self) -> Dict[str, int]: # for limit and offset return ( - self.limit_expr.expression_complexity - + self.offset_expr.expression_complexity + Counter({ComplexityStat.LOW_IMPACT.value: 2}) + + self.limit_expr.cumulative_complexity_stat + + self.offset_expr.cumulative_complexity_stat ) @@ -173,24 +205,6 @@ def __init__( self.cur_options = cur_options self.create_table_from_infer_schema = create_table_from_infer_schema - @property - def individual_query_complexity(self) -> int: - # for columns - estimate = len(self.column_names) if self.column_names else 0 - # for transformations - estimate += ( - sum(expr.expression_complexity for expr in self.transformations) - if self.transformations - else 0 - ) - # for pattern - estimate += 1 if self.pattern else 0 - # for files - estimate += len(self.files) if self.files else 0 - # for copy options - estimate += len(self.copy_options) if self.copy_options else 0 - return estimate - class CopyIntoLocationNode(LogicalPlan): def __init__( @@ -215,21 +229,3 @@ def __init__( self.file_format_name = file_format_name self.file_format_type = file_format_type self.copy_options = copy_options - - @property - def individual_query_complexity(self) -> int: - # for stage location - estimate = 1 - # for partition - estimate += self.partition_by.expression_complexity if self.partition_by else 0 - # for file format name - estimate += 1 if self.file_format_name else 0 - # for file format type - estimate += 1 if self.file_format_type else 0 - # for file format options - estimate += len(self.format_type_options) if self.format_type_options else 0 - # for copy options - estimate += len(self.copy_options) - # for header - estimate += 1 if self.header else 0 - return estimate diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py 
b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index e0c39f6e626..ef34a845e86 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -2,8 +2,9 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +from collections import Counter from functools import cached_property -from typing import AbstractSet, Optional, Type +from typing import AbstractSet, Dict, Optional, Type from snowflake.snowpark._internal.analyzer.expression import ( Expression, @@ -57,9 +58,10 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter() + @cached_property - def expression_complexity(self) -> int: - # ORDER BY child [null ordering] - estimate = self.child.expression_complexity + 1 - estimate += 1 if self.null_ordering else 0 - return estimate + def cumulative_complexity_stat(self) -> Dict[str, int]: + return self.child.cumulative_complexity_stat + self.individual_complexity_stat diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index 0cac7e62a87..719912c2dfd 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -3,9 +3,11 @@ # import sys +from collections import Counter from functools import cached_property from typing import Dict, List, Optional +from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -32,19 +34,25 @@ def __init__( self.order_spec = order_spec @cached_property - def expression_complexity(self) -> int: + def cumulative_complexity_stat(self) -> Dict[str, int]: if not self.over: - return 0 - - estimate = ( - sum(expr.expression_complexity for expr in self.partition_spec) + return Counter() + estimate = Counter({ComplexityStat.WINDOW.value: 1}) + estimate += ( + sum( + (expr.cumulative_complexity_stat for expr in self.partition_spec), + Counter({ComplexityStat.PARTITION_BY.value: 1}), + ) if self.partition_spec - else 0 + else Counter() ) estimate += ( - sum(expr.expression_complexity for expr in self.order_spec) + sum( + (expr.cumulative_complexity_stat for expr in self.order_spec), + Counter({ComplexityStat.ORDER_BY.value: 1}), + ) if self.order_spec - else 0 + else Counter() ) return estimate @@ -63,10 +71,17 @@ def __init__( self.aliases = aliases self.api_call_source = api_call_source + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.FUNCTION.value: 1}) + @cached_property - def expression_complexity(self) -> int: - return 1 + ( - self.partition_spec.expression_complexity if self.partition_spec else 0 + def cumulative_complexity_stat(self) -> Dict[str, int]: + return ( + self.partition_spec.cumulative_complexity_stat + + self.individual_complexity_stat + if self.partition_spec + else self.individual_complexity_stat ) @@ -82,8 +97,8 @@ def __init__( self.mode = mode @cached_property - def expression_complexity(self) -> int: - return self.input.expression_complexity + 4 + def cumulative_complexity_stat(self) -> Dict[str, int]: + return 
self.individual_complexity_stat + self.input.cumulative_complexity_stat class PosArgumentsTableFunction(TableFunctionExpression): @@ -97,10 +112,15 @@ def __init__( self.args = args @cached_property - def expression_complexity(self) -> int: - estimate = 1 + sum(expr.expression_complexity for expr in self.args) + def cumulative_complexity_stat(self) -> Dict[str, int]: + estimate = sum( + (arg.cumulative_complexity_stat for arg in self.args), + self.individual_complexity_stat, + ) estimate += ( - self.partition_spec.expression_complexity if self.partition_spec else 0 + self.partition_spec.cumulative_complexity_stat + if self.partition_spec + else Counter() ) return estimate @@ -116,12 +136,15 @@ def __init__( self.args = args @cached_property - def expression_complexity(self) -> int: - estimate = 1 + sum( - (1 + arg.expression_complexity) for arg in self.args.values() + def cumulative_complexity_stat(self) -> Dict[str, int]: + estimate = sum( + (arg.cumulative_complexity_stat for arg in self.args.values()), + self.individual_complexity_stat, ) estimate += ( - self.partition_spec.expression_complexity if self.partition_spec else 0 + self.partition_spec.cumulative_complexity_stat + if self.partition_spec + else Counter() ) return estimate @@ -133,12 +156,18 @@ def __init__(self, args: Dict[str, Expression], operators: List[str]) -> None: self.operators = operators @cached_property - def expression_complexity(self) -> int: - return ( - 1 - + sum(1 + arg.expression_complexity for arg in self.args.values()) - + len(self.operators) + def cumulative_complexity_stat(self) -> Dict[str, int]: + estimate = sum( + (arg.cumulative_complexity_stat for arg in self.args.values()), + self.individual_complexity_stat, + ) + estimate += ( + self.partition_spec.cumulative_complexity_stat + if self.partition_spec + else Counter() ) + estimate += Counter({ComplexityStat.COLUMN.value: len(self.operators)}) + return estimate class TableFunctionRelation(LogicalPlan): @@ -147,8 +176,9 @@ def __init__(self, table_function: TableFunctionExpression) -> None: self.table_function = table_function @property - def individual_query_complexity(self) -> int: - return self.table_function.expression_complexity + def individual_complexity_stat(self) -> Dict[str, int]: + # SELECT * FROM table_function + return self.table_function.cumulative_complexity_stat class TableFunctionJoin(LogicalPlan): @@ -166,11 +196,17 @@ def __init__( self.right_cols = right_cols if right_cols is not None else ["*"] @property - def individual_query_complexity(self) -> int: + def individual_complexity_stat(self) -> Dict[str, int]: + # SELECT left_cols, right_cols FROM child as left_alias JOIN table(func(...)) as right_alias return ( - self.table_function.expression_complexity - + len(self.left_cols) - + len(self.right_cols) + Counter( + { + ComplexityStat.COLUMN.value: len(self.left_cols) + + len(self.right_cols), + ComplexityStat.JOIN.value: 1, + } + ) + + self.table_function.cumulative_complexity_stat ) @@ -183,5 +219,9 @@ def __init__( self.table_function = table_function @property - def individual_query_complexity(self) -> int: - return 1 + self.table_function.expression_complexity + def individual_complexity_stat(self) -> Dict[str, int]: + # SELECT * FROM (child), LATERAL table_func_expression + return ( + Counter({ComplexityStat.COLUMN.value: 1}) + + self.table_function.cumulative_complexity_stat + ) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py 
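One pattern worth calling out across the plan-node diffs: the cumulative stat is computed lazily, cached in _cumulative_complexity_stat, and the setter lets later plan rewrites overwrite the cached value. A sketch under those assumptions (simplified names, not the patch's classes):

from collections import Counter
from typing import Dict, List, Optional


class PlanNode:
    def __init__(self, children: Optional[List["PlanNode"]] = None) -> None:
        self.children = children or []
        self._cumulative: Optional[Dict[str, int]] = None

    @property
    def cumulative(self) -> Dict[str, int]:
        if self._cumulative is None:  # computed at most once per node
            total = Counter({"low_impact": 1})  # stand-in individual stat
            for child in self.children:
                total += child.cumulative
            self._cumulative = total
        return self._cumulative

    @cumulative.setter
    def cumulative(self, value: Dict[str, int]) -> None:
        self._cumulative = value  # e.g. corrected after a subtree rewrite


root = PlanNode([PlanNode(), PlanNode()])
assert root.cumulative == Counter({"low_impact": 3})
root.cumulative = Counter({"low_impact": 1})  # cached value can be overridden
assert root.cumulative["low_impact"] == 1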
index baf3b24905b..aafd07d5ac7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -2,9 +2,11 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +from collections import Counter from functools import cached_property from typing import Dict, List, Optional +from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.snowflake_plan import ( LogicalPlan, @@ -17,11 +19,17 @@ def __init__(self, condition: Optional[Expression]) -> None: super().__init__() self.condition = condition + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + @cached_property - def expression_complexity(self) -> int: + def cumulative_complexity_stat(self) -> Dict[str, int]: # WHEN MATCHED [AND condition] THEN DEL - estimate = 4 - estimate += self.condition.expression_complexity if self.condition else 0 + estimate = self.individual_complexity_stat + estimate += ( + self.condition.cumulative_complexity_stat if self.condition else Counter() + ) return estimate @@ -33,13 +41,19 @@ def __init__( self.assignments = assignments @cached_property - def expression_complexity(self) -> int: + def cumulative_complexity_stat(self) -> Dict[str, int]: # WHEN MATCHED [AND condition] THEN UPDATE SET COMMA.join(k=v for k,v in assignments) - estimate = 4 - estimate += 1 if self.condition else 0 + estimate = self.individual_complexity_stat + estimate += ( + self.condition.cumulative_complexity_stat if self.condition else Counter() + ) estimate += sum( - key_expr.expression_complexity + val_expr.expression_complexity - for key_expr, val_expr in self.assignments.items() + ( + key_expr.cumulative_complexity_stat + + val_expr.cumulative_complexity_stat + for key_expr, val_expr in self.assignments.items() + ), + Counter(), ) return estimate @@ -60,12 +74,18 @@ def __init__( self.values = values @cached_property - def expression_complexity(self) -> int: + def cumulative_complexity_stat(self) -> Dict[str, int]: # WHEN NOT MATCHED [AND cond] THEN INSERT [(COMMA.join(key))] VALUES (COMMA.join(values)) - estimate = 5 - estimate += sum(expr.expression_complexity for expr in self.keys) - estimate += sum(expr.expression_complexity for expr in self.values) - estimate += self.condition.expression_complexity if self.condition else 0 + estimate = self.individual_complexity_stat + estimate += ( + self.condition.cumulative_complexity_stat if self.condition else Counter() + ) + estimate += sum( + (key.cumulative_complexity_stat for key in self.keys), Counter() + ) + estimate += sum( + (val.cumulative_complexity_stat for val in self.values), Counter() + ) return estimate @@ -84,16 +104,19 @@ def __init__( self.source_data = source_data self.children = [source_data] if source_data else [] - @cached_property - def individual_query_complexity(self) -> int: + @property + def individual_complexity_stat(self) -> Dict[str, int]: # UPDATE table_name SET COMMA.join(k, v in assignments) [source_data] [WHERE condition] - estimate = 2 - estimate += sum( - key_expr.expression_complexity + val_expr.expression_complexity - for key_expr, val_expr in self.assignments.items() + estimate = sum( + ( + k.cumulative_complexity_stat + v.cumulative_complexity_stat + for k, v in self.assignments.items() + ), + Counter(), + ) + estimate += ( + 
self.condition.cumulative_complexity_stat if self.condition else Counter() ) - estimate += self.condition.expression_complexity if self.condition else 0 - # note that source data will be handled by subtree aggregator since it is added as a child return estimate @@ -110,12 +133,12 @@ def __init__( self.source_data = source_data self.children = [source_data] if source_data else [] - @cached_property - def individual_query_complexity(self) -> int: + @property + def individual_complexity_stat(self) -> Dict[str, int]: # DELETE FROM table_name [USING source_data] [WHERE condition] - estimate = 2 - estimate += self.condition.expression_complexity if self.condition else 0 - return estimate + return ( + self.condition.cumulative_complexity_stat if self.condition else Counter() + ) class TableMerge(LogicalPlan): @@ -133,11 +156,9 @@ def __init__( self.clauses = clauses self.children = [source] if source else [] - @cached_property - def individual_query_complexity(self) -> int: + @property + def individual_complexity_stat(self) -> Dict[str, int]: # MERGE INTO table_name USING (source) ON join_expr clauses - return ( - 4 - + self.join_expr.expression_complexity - + sum(expr.expression_complexity for expr in self.clauses) + return self.join_expr.cumulative_complexity_stat + sum( + (clause.cumulative_complexity_stat for clause in self.clauses), Counter() ) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index db8117c1ed9..e5082ca57ae 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,8 +2,10 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional +from collections import Counter +from typing import AbstractSet, Dict, Optional +from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, @@ -34,8 +36,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) @property - def expression_complexity(self) -> int: - return sum(expr.expression_complexity for expr in self.children) + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.LOW_IMPACT.value: 1}) class Cast(UnaryExpression): @@ -85,8 +87,9 @@ def __str__(self): return f"{self.child} {self.sql_operator} {self.name}" @property - def expression_complexity(self) -> int: - return sum(expr.expression_complexity for expr in self.children) + def individual_complexity_stat(self) -> Dict[str, int]: + # child AS name + return Counter({ComplexityStat.COLUMN.value: 1}) class UnresolvedAlias(UnaryExpression, NamedExpression): @@ -98,5 +101,5 @@ def __init__(self, child: Expression) -> None: self.name = child.sql @property - def expression_complexity(self) -> int: - return sum(expr.expression_complexity for expr in self.children) + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter() diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 5034e7a55ab..542f58801c9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -2,8 +2,10 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
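The classes above all follow the same contract: individual_complexity_stat reports only the node's own contribution, while cumulative_complexity_stat folds in the children and is memoized via cached_property. A hedged toy sketch of that contract, outside the Snowpark class hierarchy (ToyNode and its fields are hypothetical names, not from the patch):

    from collections import Counter
    from functools import cached_property
    from typing import List, Optional

    class ToyNode:
        """Hypothetical node mirroring the individual/cumulative split."""

        def __init__(self, own: Counter, children: Optional[List["ToyNode"]] = None) -> None:
            self.own = own
            self.children = children or []

        @property
        def individual_complexity_stat(self) -> Counter:
            # only this node's own contribution
            return self.own

        @cached_property
        def cumulative_complexity_stat(self) -> Counter:
            # own contribution plus everything below, computed once and cached
            return sum(
                (child.cumulative_complexity_stat for child in self.children),
                self.individual_complexity_stat,
            )

    leaf = ToyNode(Counter({"column": 1}))
    root = ToyNode(Counter({"filter": 1, "low_impact": 1}), [leaf])
    assert root.cumulative_complexity_stat == Counter(
        {"filter": 1, "low_impact": 1, "column": 1}
    )
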
# +from collections import Counter from typing import Dict, List, Optional, Union +from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, @@ -34,10 +36,16 @@ def __init__( self.seed = seed @property - def individual_query_complexity(self) -> int: - # child SAMPLE (probability) -- if probability is provided - # child SAMPLE (row_count ROWS) -- if not probability but row count is provided - return 2 + (1 if self.row_count else 0) + def individual_complexity_stat(self) -> Dict[str, int]: + # SELECT * FROM (child) SAMPLE (probability) -- if probability is provided + # SELECT * FROM (child) SAMPLE (row_count ROWS) -- if not probability but row count is provided + return Counter( + { + ComplexityStat.SAMPLE.value: 1, + ComplexityStat.LITERAL.value: 1, + ComplexityStat.COLUMN.value: 1, + } + ) class Sort(UnaryNode): @@ -46,9 +54,11 @@ def __init__(self, order: List[SortOrder], child: LogicalPlan) -> None: self.order = order @property - def individual_query_complexity(self) -> int: + def individual_complexity_stat(self) -> Dict[str, int]: # child ORDER BY COMMA.join(order) - return 1 + sum(expr.expression_complexity for expr in self.order) + return Counter({ComplexityStat.ORDER_BY.value: 1}) + sum( + (col.cumulative_complexity_stat for col in self.order), Counter() + ) class Aggregate(UnaryNode): @@ -63,14 +73,26 @@ def __init__( self.aggregate_expressions = aggregate_expressions @property - def individual_query_complexity(self) -> int: - # grouping estimate - estimate = max( - 1, sum(expr.expression_complexity for expr in self.grouping_expressions) + def individual_complexity_stat(self) -> Dict[str, int]: + estimate = Counter() + if self.grouping_expressions: + # GROUP BY grouping_exprs + estimate += Counter({ComplexityStat.GROUP_BY.value: 1}) + sum( + (expr.cumulative_complexity_stat for expr in self.grouping_expressions), + Counter(), + ) + else: + # LIMIT 1 + estimate += Counter({ComplexityStat.LOW_IMPACT.value: 1}) + + get_complexity_stat = ( + lambda expr: expr.cumulative_complexity_stat + if hasattr(expr, "cumulative_complexity_stat") + else Counter({ComplexityStat.COLUMN.value: 1}) ) - # aggregate estimate estimate += sum( - expr.expression_complexity for expr in self.aggregate_expressions + (get_complexity_stat(expr) for expr in self.aggregate_expressions), + Counter(), ) return estimate @@ -93,26 +115,37 @@ def __init__( self.default_on_null = default_on_null @property - def individual_query_complexity(self) -> int: - # SELECT * FROM (child) PIVOT (aggregate FOR pivot_col in values) - estimate = 3 - estimate += sum(expr.expression_complexity for expr in self.grouping_columns) - estimate += self.pivot_column.expression_complexity + def individual_complexity_stat(self) -> Dict[str, int]: + estimate = Counter() + # child estimate adjustment if grouping cols + if self.grouping_columns and self.aggregates and self.aggregates[0].children: + # for additional projecting cols when grouping cols is not empty + estimate += sum( + (col.cumulative_complexity_stat for col in self.grouping_columns), + Counter(), + ) + estimate += self.pivot_column.cumulative_complexity_stat + estimate += self.aggregates[0].children[0].cumulative_complexity_stat + + # pivot col if isinstance(self.pivot_values, ScalarSubquery): - estimate += self.pivot_values.expression_complexity + estimate += self.pivot_values.cumulative_complexity_stat elif isinstance(self.pivot_values, List): - estimate += 
sum(expr.expression_complexity for expr in self.pivot_values) + estimate += sum( + (val.cumulative_complexity_stat for val in self.pivot_values), Counter() + ) else: - # when pivot values is None - estimate += 1 + # if pivot values is None, then we add LOW_IMPACT for ANY + estimate += Counter({ComplexityStat.LOW_IMPACT.value: 1}) - if len(self.aggregates) > 0: - estimate += self.aggregates[0].expression_complexity + # aggregate estimate + estimate += sum( + (expr.cumulative_complexity_stat for expr in self.aggregates), Counter() + ) - estimate += ( - self.default_on_null.expression_complexity + 1 - if self.default_on_null - else 0 + # SELECT * FROM (child) PIVOT (aggregate FOR pivot_col in values) + estimate += Counter( + {ComplexityStat.COLUMN.value: 2, ComplexityStat.PIVOT.value: 1} ) return estimate @@ -131,9 +164,15 @@ def __init__( self.column_list = column_list @property - def individual_query_complexity(self) -> int: + def individual_complexity_stat(self) -> Dict[str, int]: # SELECT * FROM (child) UNPIVOT (value_column FOR name_column IN (COMMA.join(column_list))) - return 4 + sum(expr.expression_complexity for expr in self.column_list) + estimate = Counter( + {ComplexityStat.UNPIVOT.value: 1, ComplexityStat.COLUMN.value: 3} + ) + estimate += sum( + (expr.cumulative_complexity_stat for expr in self.column_list), Counter() + ) + return estimate class Rename(UnaryNode): @@ -146,8 +185,14 @@ def __init__( self.column_map = column_map @property - def individual_query_complexity(self) -> int: - return 2 * len(self.column_map) + def individual_complexity_stat(self) -> Dict[str, int]: + # SELECT * RENAME (before AS after, ...) FROM child + return Counter( + { + ComplexityStat.COLUMN.value: 1 + 2 * len(self.column_map), + ComplexityStat.LOW_IMPACT.value: 1 + len(self.column_map), + } + ) class Filter(UnaryNode): @@ -156,8 +201,12 @@ def __init__(self, condition: Expression, child: LogicalPlan) -> None: self.condition = condition @property - def individual_query_complexity(self) -> int: - return self.condition.expression_complexity + def individual_complexity_stat(self) -> Dict[str, int]: + # child WHERE condition + return ( + Counter({ComplexityStat.FILTER.value: 1}) + + self.condition.cumulative_complexity_stat + ) class Project(UnaryNode): @@ -166,8 +215,16 @@ def __init__(self, project_list: List[NamedExpression], child: LogicalPlan) -> N self.project_list = project_list @property - def individual_query_complexity(self) -> int: - return sum(expr.expression_complexity for expr in self.project_list) + def individual_complexity_stat(self) -> Dict[str, int]: + if not self.project_list: + return Counter({ComplexityStat.COLUMN.value: 1}) + + get_complexity_stat = ( + lambda col: col.cumulative_complexity_stat + if hasattr(col, "cumulative_complexity_stat") + else Counter({ComplexityStat.COLUMN.value: 1}) + ) + return sum((get_complexity_stat(col) for col in self.project_list), Counter()) class ViewType: @@ -196,13 +253,6 @@ def __init__( self.view_type = view_type self.comment = comment - @property - def individual_query_complexity(self) -> int: - estimate = 3 - estimate += 1 if isinstance(self.view_type, LocalTempView) else 0 - estimate += 1 if self.comment else 0 - return estimate - class CreateDynamicTableCommand(UnaryNode): def __init__( @@ -218,10 +268,3 @@ def __init__( self.warehouse = warehouse self.lag = lag self.comment = comment - - @property - def individual_query_complexity(self) -> int: - # CREATE OR REPLACE DYNAMIC TABLE name LAG = lag WAREHOUSE = wh [comment] AS child - 
estimate = 7 - estimate += 1 if self.comment else 0 - return estimate diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 056272c08af..5729b17ea9b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -2,9 +2,11 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +from collections import Counter from functools import cached_property -from typing import AbstractSet, List, Optional +from typing import AbstractSet, Dict, List, Optional +from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, @@ -18,6 +20,10 @@ class SpecialFrameBoundary(Expression): def __init__(self) -> None: super().__init__() + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + class UnboundedPreceding(SpecialFrameBoundary): sql = "UNBOUNDED PRECEDING" @@ -64,10 +70,18 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.lower, self.upper) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + @cached_property - def expression_complexity(self) -> int: + def cumulative_complexity_stat(self) -> Dict[str, int]: # frame_type BETWEEN lower AND upper - return 2 + self.lower.expression_complexity + self.upper.expression_complexity + return ( + self.individual_complexity_stat + + self.lower.cumulative_complexity_stat + + self.upper.cumulative_complexity_stat + ) class WindowSpecDefinition(Expression): @@ -87,22 +101,36 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: *self.partition_spec, *self.order_spec, self.frame_spec ) - @cached_property - def expression_complexity(self) -> int: - # PARTITION BY COMMA.join(exprs) ORDER BY frame_spec - estimate = self.frame_spec.expression_complexity + @property + def individual_complexity_stat(self) -> Dict[str, int]: + estimate = Counter() estimate += ( - (1 + sum(expr.expression_complexity for expr in self.partition_spec)) + Counter({ComplexityStat.PARTITION_BY.value: 1}) if self.partition_spec - else 0 + else Counter() ) estimate += ( - (1 + sum(expr.expression_complexity for expr in self.order_spec)) + Counter({ComplexityStat.ORDER_BY.value: 1}) if self.order_spec - else 0 + else Counter() ) return estimate + @cached_property + def cumulative_complexity_stat(self) -> Dict[str, int]: + # partition_spec order_by_spec frame_spec + return ( + self.individual_complexity_stat + + sum( + (expr.cumulative_complexity_stat for expr in self.partition_spec), + Counter(), + ) + + sum( + (expr.cumulative_complexity_stat for expr in self.order_spec), Counter() + ) + + self.frame_spec.cumulative_complexity_stat + ) + class WindowExpression(Expression): def __init__( @@ -115,12 +143,17 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.window_function, self.window_spec) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + return Counter({ComplexityStat.WINDOW.value: 1}) + @cached_property - def expression_complexity(self) -> int: + def cumulative_complexity_stat(self) -> Dict[str, int]: # window_function OVER ( window_spec ) return ( - 
self.window_function.expression_complexity - + self.window_spec.expression_complexity + self.window_function.cumulative_complexity_stat + + self.window_spec.cumulative_complexity_stat + + self.individual_complexity_stat ) @@ -143,13 +176,31 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) + @property + def individual_complexity_stat(self) -> Dict[str, int]: + # for func_name + estimate = Counter({ComplexityStat.FUNCTION.value: 1}) + # for offset + estimate += ( + Counter({ComplexityStat.LITERAL.value: 1}) if self.offset else Counter() + ) + # for ignore nulls + estimate += ( + Counter({ComplexityStat.LOW_IMPACT.value: 1}) + if self.ignore_nulls + else Counter() + ) + return estimate + @cached_property - def expression_complexity(self) -> int: + def cumulative_complexity_stat(self) -> Dict[str, int]: # func_name (expr [, offset] [, default]) [IGNORE NULLS] - estimate = 1 + self.expr.expression_complexity - estimate += 1 if self.offset else 0 - estimate += self.default.expression_complexity if self.default else 0 - estimate += 1 if self.ignore_nulls else 0 + estimate = ( + self.individual_complexity_stat + self.expr.cumulative_complexity_stat + ) + estimate += ( + self.default.cumulative_complexity_stat if self.default else Counter() + ) return estimate diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 6eb10d56365..88c286055b0 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -656,7 +656,7 @@ def get_result_and_metadata( def get_result_query_id(self, plan: SnowflakePlan, **kwargs) -> str: # get the iterator such that the data is not fetched - result_set, _ = self.get_result_set(plan, to_iter=True, **kwargs) + result_set, _ = self.get_result_set(plan, ignore_results=True, **kwargs) return result_set["sfqid"] @_Decorator.wrap_exception diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 65ddb19b244..801af7d5524 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -68,7 +68,7 @@ class TelemetryField(Enum): # dataframe query stats QUERY_PLAN_HEIGHT = "query_plan_height" QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes" - QUERY_PLAN_COMPLEXITY_ESTIMATE = "query_plan_complexity_estimate" + QUERY_PLAN_COMPLEXITY_STAT = "query_plan_complexity_stat" # These DataFrame APIs call other DataFrame APIs @@ -161,9 +161,9 @@ def wrap(*args, **kwargs): api_calls[0][ TelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value ] = plan.num_duplicate_nodes - api_calls[0][ - TelemetryField.QUERY_PLAN_COMPLEXITY_ESTIMATE.value - ] = plan.subtree_query_complexity + api_calls[0][TelemetryField.QUERY_PLAN_COMPLEXITY_STAT.value] = dict( + plan.cumulative_complexity_stat + ) except Exception: pass args[0]._session._conn._telemetry_client.send_function_usage_telemetry( diff --git a/tests/integ/test_materialization_suite.py b/tests/integ/test_materialization_suite.py index 762d83fba2d..1cc4679f281 100644 --- a/tests/integ/test_materialization_suite.py +++ b/tests/integ/test_materialization_suite.py @@ -2,8 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
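The telemetry change above sends dict(plan.cumulative_complexity_stat) rather than the Counter itself, which keeps the payload a plain mapping matching the expected values asserted in the tests below. A small illustration (the keys shown are examples of ComplexityStat values, and this snippet is not part of the patch):

    import json
    from collections import Counter

    stat = Counter({"filter": 1, "column": 3, "literal": 5})
    # Counter is already a dict subclass, so dict() is not strictly required
    # for serialization; converting just makes the serialized form explicit.
    payload = {"query_plan_complexity_stat": dict(stat)}
    print(json.dumps(payload))
    # {"query_plan_complexity_stat": {"filter": 1, "column": 3, "literal": 5}}
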
# +from collections import Counter +from typing import Dict + import pytest +from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.select_statement import ( SET_EXCEPT, SET_INTERSECT, @@ -46,48 +50,70 @@ def sample_table(session): Utils.drop_table(session, table_name) -def get_subtree_query_complexity(df: DataFrame) -> int: - return df._plan.subtree_query_complexity +def get_cumulative_complexity_stat(df: DataFrame) -> Dict[str, int]: + return df._plan.cumulative_complexity_stat -def assert_df_subtree_query_complexity(df: DataFrame, estimate: int): +def assert_df_subtree_query_complexity(df: DataFrame, estimate: Dict[str, int]): assert ( - get_subtree_query_complexity(df) == estimate + get_cumulative_complexity_stat(df) == estimate ), f"query = {df.queries['queries'][-1]}" def test_create_dataframe_from_values(session: Session): df1 = session.create_dataframe([[1], [2], [3]], schema=["a"]) # SELECT "A" FROM ( SELECT $1 AS "A" FROM VALUES (1 :: INT), (2 :: INT), (3 :: INT)) - assert_df_subtree_query_complexity(df1, 5) + assert_df_subtree_query_complexity( + df1, {ComplexityStat.LITERAL.value: 3, ComplexityStat.COLUMN.value: 2} + ) df2 = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", "b"]) # SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT), (5 :: INT, 6 :: INT)) - assert_df_subtree_query_complexity(df2, 10) + assert_df_subtree_query_complexity( + df2, {ComplexityStat.LITERAL.value: 6, ComplexityStat.COLUMN.value: 4} + ) def test_session_table(session: Session, sample_table: str): df = session.table(sample_table) # select * from sample_table - assert_df_subtree_query_complexity(df, 1) + assert_df_subtree_query_complexity(df, {ComplexityStat.COLUMN.value: 1}) def test_range_statement(session: Session): df = session.range(1, 5, 2) # SELECT ( ROW_NUMBER() OVER ( ORDER BY SEQ8() ) - 1 ) * (2) + (1) AS id FROM ( TABLE (GENERATOR(ROWCOUNT => 2))) - assert_df_subtree_query_complexity(df, 6) + assert_df_subtree_query_complexity( + df, + { + ComplexityStat.COLUMN.value: 1, + ComplexityStat.LITERAL.value: 3, + ComplexityStat.LOW_IMPACT.value: 2, + ComplexityStat.ORDER_BY.value: 1, + ComplexityStat.WINDOW.value: 1, + }, + ) def test_generator_table_function(session: Session): df1 = session.generator( seq1(1).as_("seq"), uniform(1, 10, 2).as_("uniform"), rowcount=150 ) - assert_df_subtree_query_complexity(df1, 5) + assert_df_subtree_query_complexity( + df1, + { + ComplexityStat.COLUMN.value: 2, + ComplexityStat.FUNCTION.value: 1, + ComplexityStat.LITERAL.value: 1, + }, + ) df2 = df1.order_by("seq") # adds SELECT * from () ORDER BY seq ASC NULLS FIRST assert_df_subtree_query_complexity( - df2, df1._select_statement.subtree_query_complexity + 3 + df2, + get_cumulative_complexity_stat(df1) + + Counter({ComplexityStat.ORDER_BY.value: 1, ComplexityStat.COLUMN.value: 1}), ) @@ -96,15 +122,23 @@ def test_join_table_function(session: Session): "select 'James' as name, 'address1 address2 address3' as addresses" ) # SelectSQL chooses num active columns as the best estimate - # assert_df_subtree_query_complexity(df1, 2) + assert_df_subtree_query_complexity(df1, {ComplexityStat.COLUMN.value: 2}) split_to_table = table_function("split_to_table") df2 = df1.select(split_to_table(col("addresses"), lit(" "))) - # +3 SELECT "SEQ", "INDEX", "VALUE" FROM ( - # +3 SELECT T_RIGHT."SEQ", T_RIGHT."INDEX", T_RIGHT."VALUE" FROM - # +2 (select 'James' as name, 'address1 address2 address3' 
as addresses) AS T_LEFT - # +3 JOIN TABLE (split_to_table("ADDRESSES", ' ') ) AS T_RIGHT) - assert_df_subtree_query_complexity(df2, 11) + # SELECT "SEQ", "INDEX", "VALUE" FROM ( + # SELECT T_RIGHT."SEQ", T_RIGHT."INDEX", T_RIGHT."VALUE" FROM + # (select 'James' as name, 'address1 address2 address3' as addresses) AS T_LEFT + # JOIN TABLE (split_to_table("ADDRESSES", ' ') ) AS T_RIGHT) + assert_df_subtree_query_complexity( + df2, + { + ComplexityStat.COLUMN.value: 9, + ComplexityStat.JOIN.value: 1, + ComplexityStat.FUNCTION.value: 1, + ComplexityStat.LITERAL.value: 1, + }, + ) @pytest.mark.parametrize( @@ -123,21 +157,56 @@ def test_set_operators(session: Session, sample_table: str, set_operator: str): df = df1.intersect(df2) # ( SELECT * FROM SNOWPARK_TEMP_TABLE_9DJO2Y35IT) set_operator ( SELECT * FROM SNOWPARK_TEMP_TABLE_9DJO2Y35IT) - assert_df_subtree_query_complexity(df, 3) + assert_df_subtree_query_complexity( + df, {ComplexityStat.COLUMN.value: 2, ComplexityStat.SET_OPERATION.value: 1} + ) def test_agg(session: Session, sample_table: str): df = session.table(sample_table) df1 = df.agg(avg("a")) df2 = df.agg(avg("a") + 1) - df3 = df.agg(avg("a"), avg("b" + lit(1)).as_("avg_b")) + df3 = df.agg(avg("a"), avg(col("b") + lit(1)).as_("avg_b")) + df4 = df.group_by(["a", "b"]).agg(avg("c")) # SELECT avg("A") AS "AVG(A)" FROM ( SELECT * FROM sample_table) LIMIT 1 - assert_df_subtree_query_complexity(df1, 3) + assert_df_subtree_query_complexity( + df1, + { + ComplexityStat.COLUMN.value: 3, + ComplexityStat.LOW_IMPACT.value: 1, + ComplexityStat.FUNCTION.value: 1, + }, + ) # SELECT (avg("A") + 1 :: INT) AS "ADD(AVG(A), LITERAL())" FROM ( SELECT * FROM sample_table) LIMIT 1 - assert_df_subtree_query_complexity(df2, 5) - # SELECT avg("A") AS "AVG(A)", avg(('b' + 1 :: INT)) AS "AVG_B" FROM ( SELECT * FROM sample_table) LIMIT 1 - assert_df_subtree_query_complexity(df3, 6) + assert_df_subtree_query_complexity( + df2, + { + ComplexityStat.COLUMN.value: 3, + ComplexityStat.LOW_IMPACT.value: 2, + ComplexityStat.FUNCTION.value: 1, + ComplexityStat.LITERAL.value: 1, + }, + ) + # SELECT avg("A") AS "AVG(A)", avg(("B" + 1 :: INT)) AS "AVG_B" FROM ( SELECT * FROM sample_table) LIMIT 1 + assert_df_subtree_query_complexity( + df3, + { + ComplexityStat.COLUMN.value: 5, + ComplexityStat.LOW_IMPACT.value: 2, + ComplexityStat.FUNCTION.value: 2, + ComplexityStat.LITERAL.value: 1, + }, + ) + # SELECT "A", "B", avg("C") AS "AVG(C)" FROM ( SELECT * FROM SNOWPARK_TEMP_TABLE_EV1NO4AID6) GROUP BY "A", "B" + assert_df_subtree_query_complexity( + df4, + { + ComplexityStat.COLUMN.value: 7, + ComplexityStat.GROUP_BY.value: 1, + ComplexityStat.FUNCTION.value: 1, + }, + ) def test_window_function(session: Session): @@ -150,41 +219,90 @@ def test_window_function(session: Session): df = session.create_dataframe( [(1, "1"), (2, "2"), (1, "3"), (2, "4")], schema=["key", "value"] ) + table_name = Utils.random_table_name() + try: + df.write.save_as_table(table_name, table_type="temp", mode="overwrite") - df1 = df.select(avg("value").over(window1).as_("window1")) - # SELECT avg("VALUE") OVER (PARTITION BY "VALUE" ORDER BY "KEY" ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING ) AS "WINDOW1" FROM ( base_df) - assert_df_subtree_query_complexity(df1, get_subtree_query_complexity(df) + 9) + df1 = session.table(table_name).select( + avg("value").over(window1).as_("window1") + ) + # SELECT avg("VALUE") OVER (PARTITION BY "VALUE" ORDER BY "KEY" ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING ) AS "WINDOW1" FROM table_name + 
assert_df_subtree_query_complexity( + df1, + Counter( + { + ComplexityStat.PARTITION_BY.value: 1, + ComplexityStat.ORDER_BY.value: 1, + ComplexityStat.WINDOW.value: 1, + ComplexityStat.FUNCTION.value: 1, + ComplexityStat.COLUMN.value: 5, + ComplexityStat.LITERAL.value: 1, + ComplexityStat.LOW_IMPACT.value: 2, + } + ), + ) - # SELECT avg("VALUE") OVER ( ORDER BY "KEY" DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS "WINDOW2" FROM ( base df) - df2 = df1.select(avg("value").over(window2).as_("window2")) - assert_df_subtree_query_complexity(df2, get_subtree_query_complexity(df1) + 9) + # SELECT avg("VALUE") OVER ( ORDER BY "KEY" DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS "WINDOW2" FROM ( + # SELECT avg("VALUE") OVER (PARTITION BY "VALUE" ORDER BY "KEY" ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING ) AS "WINDOW1" FROM table_name) + df2 = df1.select(avg("value").over(window2).as_("window2")) + assert_df_subtree_query_complexity( + df2, + get_cumulative_complexity_stat(df1) + + Counter( + { + ComplexityStat.ORDER_BY.value: 1, + ComplexityStat.WINDOW.value: 1, + ComplexityStat.FUNCTION.value: 1, + ComplexityStat.COLUMN.value: 3, + ComplexityStat.LOW_IMPACT.value: 3, + } + ), + ) + finally: + Utils.drop_table(session, table_name) def test_join_statement(session: Session, sample_table: str): # SELECT * FROM table df1 = session.table(sample_table) - assert_df_subtree_query_complexity(df1, 1) + assert_df_subtree_query_complexity(df1, {ComplexityStat.COLUMN.value: 1}) # SELECT A, B, E FROM (SELECT $1 AS "A", $2 AS "B", $3 AS "E" FROM VALUES (1 :: INT, 2 :: INT, 5 :: INT), (3 :: INT, 4 :: INT, 9 :: INT)) df2 = session.create_dataframe([[1, 2, 5], [3, 4, 9]], schema=["a", "b", "e"]) - assert_df_subtree_query_complexity(df2, 12) + assert_df_subtree_query_complexity( + df2, {ComplexityStat.COLUMN.value: 6, ComplexityStat.LITERAL.value: 6} + ) df3 = df1.join(df2) - # +3 SELECT * FROM ((ch1) AS SNOWPARK_LEFT INNER JOIN (ch2) AS SNOWPARK_RIGHT) - # +4 ch1 = SELECT "A" AS "l_p8bm_A", "B" AS "l_p8bm_B", "C" AS "C", "D" AS "D" FROM (df1) - # +0 ch2 = SELECT "A" AS "r_2og4_A", "B" AS "r_2og4_B", "E" AS "E" FROM (SELECT $1 AS "A", $2 AS "B", $3 AS "E" FROM VALUES ()) - # ch2 is a re-write flattened version of df2 with aliases - assert_df_subtree_query_complexity(df3, 13 + 7) + # SELECT * FROM (( SELECT "A" AS "l_fkl0_A", "B" AS "l_fkl0_B", "C" AS "C", "D" AS "D" FROM sample_table) AS SNOWPARK_LEFT + # INNER JOIN ( + # SELECT "A" AS "r_co85_A", "B" AS "r_co85_B", "E" AS "E" FROM ( + # SELECT $1 AS "A", $2 AS "B", $3 AS "E" FROM VALUES (1 :: INT, 2 :: INT, 5 :: INT), (3 :: INT, 4 :: INT, 9 :: INT))) AS SNOWPARK_RIGHT) + assert_df_subtree_query_complexity( + df3, + { + ComplexityStat.COLUMN.value: 18, + ComplexityStat.LITERAL.value: 6, + ComplexityStat.JOIN.value: 1, + }, + ) df4 = df1.join(df2, on=((df1["a"] == df2["a"]) & (df1["b"] == df2["b"]))) # SELECT * FROM ((ch1) AS SNOWPARK_LEFT INNER JOIN ( ch2) AS SNOWPARK_RIGHT ON (("l_k7b8_A" = "r_e09m_A") AND ("l_k7b8_B" = "r_e09m_B"))) - assert_df_subtree_query_complexity(df4, get_subtree_query_complexity(df3) + 7) + assert_df_subtree_query_complexity( + df4, + get_cumulative_complexity_stat(df3) + + Counter({ComplexityStat.COLUMN.value: 4, ComplexityStat.LOW_IMPACT.value: 3}), + ) df5 = df1.join(df2, using_columns=["a", "b"]) # SELECT * FROM ( (ch1) AS SNOWPARK_LEFT INNER JOIN (ch2) AS SNOWPARK_RIGHT USING (a, b)) - assert_df_subtree_query_complexity(df5, get_subtree_query_complexity(df3) + 
3) + assert_df_subtree_query_complexity( + df5, + get_cumulative_complexity_stat(df3) + Counter({ComplexityStat.COLUMN.value: 2}), + ) -def test_pivot_and_unpivot(session: Session): +def test_pivot(session: Session): try: session.sql( """create or replace temp table monthly_sales(empid int, amount int, month text) @@ -202,7 +320,15 @@ def test_pivot_and_unpivot(session: Session): session.table("monthly_sales").pivot("month", ["JAN", "FEB"]).sum("amount") ) # SELECT * FROM ( SELECT * FROM monthly_sales) PIVOT (sum("AMOUNT") FOR "MONTH" IN ('JAN', 'FEB')) - assert_df_subtree_query_complexity(df_pivot1, 8) + assert_df_subtree_query_complexity( + df_pivot1, + { + ComplexityStat.PIVOT.value: 1, + ComplexityStat.COLUMN.value: 4, + ComplexityStat.LITERAL.value: 2, + ComplexityStat.FUNCTION.value: 1, + }, + ) df_pivot2 = ( session.table("monthly_sales") @@ -210,21 +336,37 @@ def test_pivot_and_unpivot(session: Session): .sum("amount") ) # SELECT * FROM ( SELECT * FROM monthly_sales) PIVOT (sum("AMOUNT") FOR "MONTH" IN ('JAN', 'FEB', 'MARCH')) - assert_df_subtree_query_complexity(df_pivot2, 9) + assert_df_subtree_query_complexity( + df_pivot2, + { + ComplexityStat.PIVOT.value: 1, + ComplexityStat.COLUMN.value: 4, + ComplexityStat.LITERAL.value: 3, + ComplexityStat.FUNCTION.value: 1, + }, + ) + finally: + Utils.drop_table(session, "monthly_sales") + +def test_unpivot(session: Session): + try: session.sql( """create or replace temp table sales_for_month(empid int, dept varchar, jan int, feb int) as select * from values (1, 'electronics', 100, 200), (2, 'clothes', 100, 300)""" ).collect() + df_unpivot1 = session.table("sales_for_month").unpivot( "sales", "month", ["jan", "feb"] ) # SELECT * FROM ( SELECT * FROM (sales_for_month)) UNPIVOT (sales FOR month IN ("JAN", "FEB")) - assert_df_subtree_query_complexity(df_unpivot1, 7) + assert_df_subtree_query_complexity( + df_unpivot1, + {ComplexityStat.UNPIVOT.value: 1, ComplexityStat.COLUMN.value: 6}, + ) finally: - Utils.drop_table(session, "monthly_sales") Utils.drop_table(session, "sales_for_month") @@ -232,66 +374,120 @@ def test_sample(session: Session, sample_table): df = session.table(sample_table) df_sample_frac = df.sample(0.5) # SELECT * FROM ( SELECT * FROM (sample_table)) SAMPLE (50.0) - assert_df_subtree_query_complexity(df_sample_frac, 3) + assert_df_subtree_query_complexity( + df_sample_frac, + { + ComplexityStat.SAMPLE.value: 1, + ComplexityStat.LITERAL.value: 1, + ComplexityStat.COLUMN.value: 2, + }, + ) df_sample_rows = df.sample(n=1) # SELECT * FROM ( SELECT * FROM (sample_table)) SAMPLE (1 ROWS) - assert_df_subtree_query_complexity(df_sample_rows, 4) - + assert_df_subtree_query_complexity( + df_sample_rows, + { + ComplexityStat.SAMPLE.value: 1, + ComplexityStat.LITERAL.value: 1, + ComplexityStat.COLUMN.value: 2, + }, + ) -@pytest.mark.parametrize("source_from_table", [True, False]) -def test_select_statement_subtree_complexity_estimate( - session: Session, sample_table: str, source_from_table: bool -): - if source_from_table: - df1 = session.table(sample_table) - else: - df1 = session.create_dataframe( - [[1, 2, 3, 4], [5, 6, 7, 8]], schema=["a", "b", "c", "d"] - ) - assert_df_subtree_query_complexity(df1, 1 if source_from_table else 16) +def test_select_statement_with_multiple_operations(session: Session, sample_table: str): + df1 = session.table(sample_table) # add select - # +3 for column + # SELECT "A", "B", "C" FROM sample_table + # note that column stat is 4 even though selected columns is 3. 
This is because we count 1 column
+    # from select * from sample_table, which is flattened out. This is a known limitation, but it is okay
+    # since we are not off by much
     df2 = df1.select("a", "b", "c")
-    assert_df_subtree_query_complexity(df2, 4 if source_from_table else 15)
+    assert_df_subtree_query_complexity(df2, {ComplexityStat.COLUMN.value: 4})

-    # +2 for column (1 less active column)
+    # 1 less active column
     df3 = df2.select("b", "c")
-    assert_df_subtree_query_complexity(df3, 3 if source_from_table else 14)
+    assert_df_subtree_query_complexity(df3, {ComplexityStat.COLUMN.value: 3})

     # add sort
-    # +3 for additional ORDER BY "B" ASC NULLS FIRST
+    # for additional ORDER BY "B" ASC NULLS FIRST
     df4 = df3.sort(col("b").asc())
-    assert_df_subtree_query_complexity(df4, 3 + get_subtree_query_complexity(df3))
+    assert_df_subtree_query_complexity(
+        df4,
+        get_cumulative_complexity_stat(df3)
+        + Counter({ComplexityStat.COLUMN.value: 1, ComplexityStat.ORDER_BY.value: 1}),
+    )

-    # +3 for additional ,"C" ASC NULLS FIRST
+    # for additional ,"C" ASC NULLS FIRST
     df5 = df4.sort(col("c").desc())
-    assert_df_subtree_query_complexity(df5, 2 + get_subtree_query_complexity(df4))
+    assert_df_subtree_query_complexity(
+        df5,
+        get_cumulative_complexity_stat(df4) + Counter({ComplexityStat.COLUMN.value: 1}),
+    )

     # add filter
-    # +4 for WHERE ("B" > 2)
+    # for WHERE ("B" > 2)
     df6 = df5.filter(col("b") > 2)
-    assert_df_subtree_query_complexity(df6, 4 + get_subtree_query_complexity(df5))
+    assert_df_subtree_query_complexity(
+        df6,
+        get_cumulative_complexity_stat(df5)
+        + Counter(
+            {
+                ComplexityStat.FILTER.value: 1,
+                ComplexityStat.COLUMN.value: 1,
+                ComplexityStat.LITERAL.value: 1,
+                ComplexityStat.LOW_IMPACT.value: 1,
+            }
+        ),
+    )

-    # +4 for filter - AND ("C" > 3)
+    # for filter - AND ("C" > 3)
     df7 = df6.filter(col("c") > 3)
-    assert_df_subtree_query_complexity(df7, 4 + get_subtree_query_complexity(df6))
+    assert_df_subtree_query_complexity(
+        df7,
+        get_cumulative_complexity_stat(df6)
+        + Counter(
+            {
+                ComplexityStat.COLUMN.value: 1,
+                ComplexityStat.LITERAL.value: 1,
+                ComplexityStat.LOW_IMPACT.value: 2,
+            }
+        ),
+    )

     # add set operations
     df8 = df3.union_all(df4).union_all(df5)
     assert_df_subtree_query_complexity(
-        df8, 2 + sum(get_subtree_query_complexity(df) for df in [df3, df4, df5])
+        df8,
+        sum(
+            (get_cumulative_complexity_stat(df) for df in [df3, df4, df5]),
+            Counter({ComplexityStat.SET_OPERATION.value: 2}),
+        ),
     )

     # +2 for 2 unions, plus the sum of the individual df complexities
     df9 = df8.union_all(df6).union_all(df7)
     assert_df_subtree_query_complexity(
-        df9, 2 + sum(get_subtree_query_complexity(df) for df in [df6, df7, df8])
+        df9,
+        sum(
+            (get_cumulative_complexity_stat(df) for df in [df6, df7, df8]),
+            Counter({ComplexityStat.SET_OPERATION.value: 2}),
+        ),
     )

-    # +1 for limit
+    # for limit
     df10 = df9.limit(2)
-    assert_df_subtree_query_complexity(df10, 1 + get_subtree_query_complexity(df9))
+    assert_df_subtree_query_complexity(
+        df10,
+        get_cumulative_complexity_stat(df9)
+        + Counter({ComplexityStat.LOW_IMPACT.value: 1}),
+    )
+
+    # for offset
+    df11 = df9.limit(3, offset=1)
+    assert_df_subtree_query_complexity(
+        df11,
+        get_cumulative_complexity_stat(df9)
+        + Counter({ComplexityStat.LOW_IMPACT.value: 2}),
+    )
diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py
index 4c5737ae042..bd00cecb0bb 100644
--- a/tests/integ/test_telemetry.py
+++ b/tests/integ/test_telemetry.py
@@ -593,7 +593,14 @@ def
test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_estimate": 14, + "query_plan_complexity_stat": { + "filter": 1, + "low_impact": 5, + "column": 3, + "literal": 5, + "window": 1, + "order_by": 1, + }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -607,7 +614,14 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_estimate": 14, + "query_plan_complexity_stat": { + "filter": 1, + "low_impact": 5, + "column": 3, + "literal": 5, + "window": 1, + "order_by": 1, + }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -621,7 +635,14 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_estimate": 14, + "query_plan_complexity_stat": { + "filter": 1, + "low_impact": 5, + "column": 3, + "literal": 5, + "window": 1, + "order_by": 1, + }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -635,7 +656,14 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_estimate": 14, + "query_plan_complexity_stat": { + "filter": 1, + "low_impact": 5, + "column": 3, + "literal": 5, + "window": 1, + "order_by": 1, + }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -649,7 +677,14 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_estimate": 14, + "query_plan_complexity_stat": { + "filter": 1, + "low_impact": 5, + "column": 3, + "literal": 5, + "window": 1, + "order_by": 1, + }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -786,7 +821,7 @@ def test_dataframe_stat_functions_api_calls(session): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": 4, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_estimate": 54, + "query_plan_complexity_stat": {"group_by": 1, "column": 6, "literal": 48}, }, { "name": "DataFrameStatFunctions.crosstab", @@ -804,7 +839,7 @@ def test_dataframe_stat_functions_api_calls(session): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": 4, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_estimate": 54, + "query_plan_complexity_stat": {"group_by": 1, "column": 6, "literal": 48}, } ] From fb68aa06b9b682e857afbc5061d4f618adc088a8 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 4 Jun 2024 13:41:40 -0700 Subject: [PATCH 09/37] fix type checks --- .../_internal/analyzer/binary_expression.py | 19 ++++- .../_internal/analyzer/binary_plan_node.py | 21 ++++- .../snowpark/_internal/analyzer/expression.py | 79 +++++++++++-------- .../_internal/analyzer/grouping_set.py | 23 ++++-- .../_internal/analyzer/select_statement.py | 34 +++++--- .../_internal/analyzer/snowflake_plan.py | 23 ++++-- .../_internal/analyzer/snowflake_plan_node.py | 43 +++++++--- 
.../_internal/analyzer/sort_expression.py | 21 ++++- .../_internal/analyzer/table_function.py | 35 +++++--- .../analyzer/table_merge_expression.py | 29 +++++-- .../_internal/analyzer/unary_expression.py | 23 ++++-- .../_internal/analyzer/unary_plan_node.py | 31 +++++--- .../_internal/analyzer/window_expression.py | 34 +++++--- 13 files changed, 292 insertions(+), 123 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 41c2fa32fcf..552fddbdc88 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -2,8 +2,21 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from collections import Counter -from typing import AbstractSet, Dict, Optional +import sys +from typing import AbstractSet, Optional + +# collections.Counter does not pass type checker. Changes with appropriate type hints were made in 3.9+ +if sys.version_info <= (3, 9): + import collections + import typing + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + +else: + from collections import Counter from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import ( @@ -29,7 +42,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.left, self.right) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.LOW_IMPACT.value: 1}) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index 10d0e75502b..cb1bfe2b4d9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -2,8 +2,21 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from collections import Counter -from typing import Dict, List, Optional +import sys +from typing import List, Optional + +# collections.Counter does not pass type checker. 
Changes with appropriate type hints were made in 3.9+ +if sys.version_info <= (3, 9): + import collections + import typing + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + +else: + from collections import Counter from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import Expression @@ -72,7 +85,7 @@ def __init__(self, left: LogicalPlan, right: LogicalPlan) -> None: class SetOperation(BinaryNode): @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # (left) operator (right) return Counter({ComplexityStat.SET_OPERATION.value: 1}) @@ -194,7 +207,7 @@ def sql(self) -> str: return self.join_type.sql @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM (left) AS left_alias join_type_sql JOIN (right) AS right_alias match_cond, using_cond, join_cond estimate = Counter({ComplexityStat.JOIN.value: 1}) if isinstance(self.join_type, UsingJoin) and self.join_type.using_columns: diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 86ab229c6e1..0a9dfdc92de 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -3,10 +3,23 @@ # import copy +import sys import uuid -from collections import Counter from functools import cached_property -from typing import TYPE_CHECKING, AbstractSet, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, AbstractSet, Any, List, Optional, Tuple + +# collections.Counter does not pass type checker. 
Changes with appropriate type hints were made in 3.9+ +if sys.version_info <= (3, 9): + import collections + import typing + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + +else: + from collections import Counter import snowflake.snowpark._internal.utils from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat @@ -84,11 +97,11 @@ def sql(self) -> str: return f"{self.pretty_name}({children_sql})" @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: children = self.children or [] return sum( (child.cumulative_complexity_stat for child in children), @@ -124,7 +137,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return self.plan.cumulative_complexity_stat + self.individual_complexity_stat @@ -137,7 +150,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return ( sum( (expr.cumulative_complexity_stat for expr in self.expressions), @@ -157,11 +170,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.columns, *self.values) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.IN.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return ( self.columns.cumulative_complexity_stat + self.individual_complexity_stat @@ -200,7 +213,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return {self.name} @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.COLUMN.value: 1}) @@ -216,14 +229,14 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: if self.expressions: return Counter({}) # if there are no expressions, we assign column value = 1 to Star return Counter({ComplexityStat.COLUMN.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return self.individual_complexity_stat + sum( (child.individual_complexity_stat for child in self.expressions), Counter({}), @@ -264,7 +277,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return self._dependent_column_names @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.COLUMN.value: 1}) @@ -291,7 +304,7 @@ def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None: self.datatype = infer_type(value) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return 
Counter({ComplexityStat.LITERAL.value: 1}) @@ -343,7 +356,7 @@ def __str__(self) -> str: return self.sql @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter( { ComplexityStat.LITERAL.value: 2 * len(self.values_dict), @@ -362,12 +375,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # expr LIKE pattern return Counter({ComplexityStat.LOW_IMPACT.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return ( self.expr.cumulative_complexity_stat + self.pattern.cumulative_complexity_stat @@ -385,12 +398,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # expr REG_EXP pattern return Counter({ComplexityStat.LOW_IMPACT.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return ( self.expr.cumulative_complexity_stat + self.pattern.cumulative_complexity_stat @@ -408,12 +421,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # expr COLLATE collate_spec return Counter({ComplexityStat.LOW_IMPACT.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return self.expr.cumulative_complexity_stat + self.individual_complexity_stat @@ -427,12 +440,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # the literal corresponds to the contribution from self.field return Counter({ComplexityStat.LITERAL.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: # self.expr ( self.field ) return self.expr.cumulative_complexity_stat + self.individual_complexity_stat @@ -447,12 +460,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # the literal corresponds to the contribution from self.field return Counter({ComplexityStat.LITERAL.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: # self.expr ( self.field ) return self.expr.cumulative_complexity_stat + self.individual_complexity_stat @@ -489,7 +502,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.FUNCTION.value: 1}) @@ -504,12 +517,12 @@ def dependent_column_names(self) -> 
Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, *self.order_by_cols) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # expr WITHIN GROUP (ORDER BY cols) return Counter({ComplexityStat.ORDER_BY.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return ( sum( (col.cumulative_complexity_stat for col in self.order_by_cols), @@ -539,11 +552,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*exps) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.CASE_WHEN.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: estimate = self.individual_complexity_stat + sum( ( condition.cumulative_complexity_stat + value.cumulative_complexity_stat @@ -579,11 +592,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.FUNCTION.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return sum( (expr.cumulative_complexity_stat for expr in self.children), self.individual_complexity_stat, @@ -601,9 +614,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.col) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.FUNCTION.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return self.col.cumulative_complexity_stat + self.individual_complexity_stat diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 380e1ce9509..a1d9463e583 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -2,9 +2,22 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from collections import Counter +import sys from functools import cached_property -from typing import AbstractSet, Dict, List, Optional +from typing import AbstractSet, List, Optional + +# collections.Counter does not pass type checker. 
Changes with appropriate type hints were made in 3.9+ +if sys.version_info <= (3, 9): + import collections + import typing + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + +else: + from collections import Counter from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import ( @@ -23,7 +36,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.group_by_exprs) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.LOW_IMPACT.value: 1}) @@ -45,7 +58,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*flattened_args) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return sum( ( sum((expr.cumulative_complexity_stat for expr in arg), Counter()) @@ -55,5 +68,5 @@ def cumulative_complexity_stat(self) -> Dict[str, int]: ) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.LOW_IMPACT.value: 1}) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 5707369c4e2..43277107818 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -2,8 +2,9 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import sys from abc import ABC, abstractmethod -from collections import Counter, UserDict, defaultdict +from collections import UserDict, defaultdict from copy import copy, deepcopy from enum import Enum from typing import ( @@ -19,6 +20,19 @@ Union, ) +# collections.Counter does not pass type checker. Changes with appropriate type hints were made in 3.9+ +if sys.version_info <= (3, 9): + import collections + import typing + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + +else: + from collections import Counter + import snowflake.snowpark._internal.utils from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.cte_utils import encode_id @@ -201,7 +215,7 @@ def __init__( str, Dict[str, str] ] = defaultdict(dict) self._api_calls = api_calls.copy() if api_calls is not None else None - self._cumulative_complexity_stat: Optional[Dict[str, int]] = None + self._cumulative_complexity_stat: Optional[Counter[str]] = None def __eq__(self, other: "Selectable") -> bool: if self._id is not None and other._id is not None: @@ -293,7 +307,7 @@ def num_duplicate_nodes(self) -> int: return self.snowflake_plan.num_duplicate_nodes @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: """This is the query complexity estimate added by this Selectable node to the overall query plan. For default case, it is the number of active columns. 
Specific cases are handled in child classes with additional @@ -303,10 +317,10 @@ def individual_complexity_stat(self) -> Dict[str, int]: return Counter( {ComplexityStat.COLUMN.value: len(self.column_states.active_columns)} ) - return self.snowflake_plan.source_plan.individual_complexity_stat + return self.snowflake_plan.individual_complexity_stat @property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: """This is sum of individual query complexity estimates for all nodes within a query plan subtree. """ @@ -318,7 +332,7 @@ def cumulative_complexity_stat(self) -> Dict[str, int]: return self._cumulative_complexity_stat @cumulative_complexity_stat.setter - def cumulative_complexity_stat(self, value: Dict[str, int]): + def cumulative_complexity_stat(self, value: Counter[str]): self._cumulative_complexity_stat = value @property @@ -380,7 +394,7 @@ def schema_query(self) -> str: return self.sql_query @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM entity return Counter({ComplexityStat.COLUMN.value: 1}) @@ -440,7 +454,7 @@ def schema_query(self) -> str: return self._schema_query @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: if self.pre_actions: # having pre-actions implies we have a non-select query followed by a # SELECT * FROM table(result_scan(query_id)) statement @@ -707,7 +721,7 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return [self.from_] @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: estimate = Counter() # projection component estimate += ( @@ -1127,7 +1141,7 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return self._nodes @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # we add #set_operands - 1 additional operators in sql query return Counter({ComplexityStat.SET_OPERATION.value: len(self.set_operands) - 1}) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index ba5ada58054..8d9c042f716 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -6,7 +6,7 @@ import re import sys import uuid -from collections import Counter, defaultdict +from collections import defaultdict from functools import cached_property from typing import ( TYPE_CHECKING, @@ -21,6 +21,19 @@ Union, ) +# collections.Counter does not pass type checker. 
Changes with appropriate type hints were made in 3.9+ +if sys.version_info <= (3, 9): + import collections + import typing + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + +else: + from collections import Counter + from snowflake.snowpark._internal.analyzer.table_function import ( GeneratorTableFunction, TableFunctionRelation, @@ -232,7 +245,7 @@ def __init__( self.placeholder_query = placeholder_query # encode an id for CTE optimization self._id = encode_id(queries[-1].sql, queries[-1].params) - self._cumulative_complexity_stat: Optional[Dict[str, int]] = None + self._cumulative_complexity_stat: Optional[Counter[str]] = None def __eq__(self, other: "SnowflakePlan") -> bool: if self._id is not None and other._id is not None: @@ -352,13 +365,13 @@ def num_duplicate_nodes(self) -> int: return len(find_duplicate_subtrees(self)) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: if self.source_plan: return self.source_plan.individual_complexity_stat return Counter() @property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: if self._cumulative_complexity_stat is None: estimate = self.individual_complexity_stat for node in self.children_plan_nodes: @@ -367,7 +380,7 @@ def cumulative_complexity_stat(self) -> Dict[str, int]: return self._cumulative_complexity_stat @cumulative_complexity_stat.setter - def cumulative_complexity_stat(self, value: Dict[str, int]): + def cumulative_complexity_stat(self, value: Counter[str]): self._cumulative_complexity_stat = value def __copy__(self) -> "SnowflakePlan": diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 2038f12a169..a9fd3f375f4 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -4,7 +4,6 @@ # import sys -from collections import Counter from enum import Enum from typing import Any, Dict, List, Optional @@ -13,26 +12,36 @@ from snowflake.snowpark.row import Row from snowflake.snowpark.types import StructType -# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable -# Python 3.9 can use both +# Python 3.8: needs to use typing.Iterable because collections.abc.Iterable is not subscriptable +# needs to create new Counter class from collections.Counter so it can pass type check +# Python 3.9: can use both and type check support is added in collections.Counter from 3.9+ # Python 3.10 needs to use collections.abc.Iterable because typing.Iterable is removed if sys.version_info <= (3, 9): + import collections + import typing from typing import Iterable + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + else: + from collections import Counter from collections.abc import Iterable class LogicalPlan: def __init__(self) -> None: self.children = [] - self._cumulative_complexity_stat: Optional[Dict[str, int]] = None + self._cumulative_complexity_stat: Optional[Counter[str]] = None @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter() @property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: if self._cumulative_complexity_stat is None: estimate = 
self.individual_complexity_stat for node in self.children: @@ -42,7 +51,7 @@ def cumulative_complexity_stat(self) -> Dict[str, int]: return self._cumulative_complexity_stat @cumulative_complexity_stat.setter - def cumulative_complexity_stat(self, value: Dict[str, int]): + def cumulative_complexity_stat(self, value: Counter[str]): self._cumulative_complexity_stat = value @@ -61,7 +70,7 @@ def __init__(self, start: int, end: int, step: int, num_slices: int = 1) -> None self.num_slices = num_slices @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # SELECT ( ROW_NUMBER() OVER ( ORDER BY SEQ8() ) - 1 ) * (step) + (start) AS id FROM ( TABLE (GENERATOR(ROWCOUNT => count))) return Counter( { @@ -80,7 +89,7 @@ def __init__(self, name: str) -> None: self.name = name @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM name return Counter({ComplexityStat.COLUMN.value: 1}) @@ -98,7 +107,7 @@ def __init__( self.schema_query = schema_query @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm) # TODO: use ARRAY_BIND_THRESHOLD return Counter( @@ -139,13 +148,21 @@ def __init__( self.comment = comment @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # CREATE OR REPLACE table_type TABLE table_name (col definition) clustering_expr AS SELECT * FROM (child) estimate = Counter( {ComplexityStat.LOW_IMPACT.value: 1, ComplexityStat.COLUMN.value: 1} ) estimate += ( - sum(expr.cumulative_complexity_stat for expr in self.clustering_exprs) + Counter({ComplexityStat.COLUMN.value: len(self.column_names)}) + if self.column_names + else Counter() + ) + estimate += ( + sum( + (expr.cumulative_complexity_stat for expr in self.clustering_exprs), + Counter(), + ) if self.clustering_exprs else Counter() ) @@ -163,7 +180,7 @@ def __init__( self.children.append(child) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # for limit and offset return ( Counter({ComplexityStat.LOW_IMPACT.value: 2}) diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index ef34a845e86..3875e855620 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -2,15 +2,28 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from collections import Counter +import sys from functools import cached_property -from typing import AbstractSet, Dict, Optional, Type +from typing import AbstractSet, Optional, Type from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, ) +# collections.Counter does not pass type checker. 
Changes with appropriate type hints were made in 3.9+ +if sys.version_info <= (3, 9): + import collections + import typing + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + +else: + from collections import Counter + class NullOrdering: sql: str @@ -59,9 +72,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter() @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return self.child.cumulative_complexity_stat + self.individual_complexity_stat diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index 719912c2dfd..34efe0b1e90 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -3,7 +3,6 @@ # import sys -from collections import Counter from functools import cached_property from typing import Dict, List, Optional @@ -12,12 +11,22 @@ from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder -# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable -# Python 3.9 can use both +# Python 3.8: needs to use typing.Iterable because collections.abc.Iterable is not subscriptable +# needs to create new Counter class from collections.Counter so it can pass type check +# Python 3.9: can use both and type check support is added in collections.Counter from 3.9+ # Python 3.10 needs to use collections.abc.Iterable because typing.Iterable is removed if sys.version_info <= (3, 9): + import collections + import typing from typing import Iterable + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + else: + from collections import Counter from collections.abc import Iterable @@ -34,7 +43,7 @@ def __init__( self.order_spec = order_spec @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: if not self.over: return Counter() estimate = Counter({ComplexityStat.WINDOW.value: 1}) @@ -72,11 +81,11 @@ def __init__( self.api_call_source = api_call_source @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.FUNCTION.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return ( self.partition_spec.cumulative_complexity_stat + self.individual_complexity_stat @@ -97,7 +106,7 @@ def __init__( self.mode = mode @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: return self.individual_complexity_stat + self.input.cumulative_complexity_stat @@ -112,7 +121,7 @@ def __init__( self.args = args @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: estimate = sum( (arg.cumulative_complexity_stat for arg in self.args), self.individual_complexity_stat, @@ -136,7 +145,7 @@ def __init__( self.args = args @cached_property - def 
cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: estimate = sum( (arg.cumulative_complexity_stat for arg in self.args.values()), self.individual_complexity_stat, @@ -156,7 +165,7 @@ def __init__(self, args: Dict[str, Expression], operators: List[str]) -> None: self.operators = operators @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: estimate = sum( (arg.cumulative_complexity_stat for arg in self.args.values()), self.individual_complexity_stat, @@ -176,7 +185,7 @@ def __init__(self, table_function: TableFunctionExpression) -> None: self.table_function = table_function @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM table_function return self.table_function.cumulative_complexity_stat @@ -196,7 +205,7 @@ def __init__( self.right_cols = right_cols if right_cols is not None else ["*"] @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # SELECT left_cols, right_cols FROM child as left_alias JOIN table(func(...)) as right_alias return ( Counter( @@ -219,7 +228,7 @@ def __init__( self.table_function = table_function @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM (child), LATERAL table_func_expression return ( Counter({ComplexityStat.COLUMN.value: 1}) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index aafd07d5ac7..48eab184274 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from collections import Counter +import sys from functools import cached_property from typing import Dict, List, Optional @@ -13,6 +13,19 @@ SnowflakePlan, ) +# collections.Counter does not pass type checker. 
Changes with appropriate type hints were made in 3.9+ +if sys.version_info <= (3, 9): + import collections + import typing + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + +else: + from collections import Counter + class MergeExpression(Expression): def __init__(self, condition: Optional[Expression]) -> None: @@ -20,11 +33,11 @@ def __init__(self, condition: Optional[Expression]) -> None: self.condition = condition @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.LOW_IMPACT.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: # WHEN MATCHED [AND condition] THEN DEL estimate = self.individual_complexity_stat estimate += ( @@ -41,7 +54,7 @@ def __init__( self.assignments = assignments @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: # WHEN MATCHED [AND condition] THEN UPDATE SET COMMA.join(k=v for k,v in assignments) estimate = self.individual_complexity_stat estimate += ( @@ -74,7 +87,7 @@ def __init__( self.values = values @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: # WHEN NOT MATCHED [AND cond] THEN INSERT [(COMMA.join(key))] VALUES (COMMA.join(values)) estimate = self.individual_complexity_stat estimate += ( @@ -105,7 +118,7 @@ def __init__( self.children = [source_data] if source_data else [] @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # UPDATE table_name SET COMMA.join(k, v in assignments) [source_data] [WHERE condition] estimate = sum( ( @@ -134,7 +147,7 @@ def __init__( self.children = [source_data] if source_data else [] @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # DELETE FROM table_name [USING source_data] [WHERE condition] return ( self.condition.cumulative_complexity_stat if self.condition else Counter() @@ -157,7 +170,7 @@ def __init__( self.children = [source] if source else [] @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # MERGE INTO table_name USING (source) ON join_expr clauses return self.join_expr.cumulative_complexity_stat + sum( (clause.cumulative_complexity_stat for clause in self.clauses), Counter() diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index e5082ca57ae..da86a0b9605 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,8 +2,8 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from collections import Counter -from typing import AbstractSet, Dict, Optional +import sys +from typing import AbstractSet, Optional from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import ( @@ -13,6 +13,19 @@ ) from snowflake.snowpark.types import DataType +# collections.Counter does not pass type checker. 
Changes with appropriate type hints were made in 3.9+ +if sys.version_info <= (3, 9): + import collections + import typing + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + +else: + from collections import Counter + class UnaryExpression(Expression): sql_operator: str @@ -36,7 +49,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.LOW_IMPACT.value: 1}) @@ -87,7 +100,7 @@ def __str__(self): return f"{self.child} {self.sql_operator} {self.name}" @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # child AS name return Counter({ComplexityStat.COLUMN.value: 1}) @@ -101,5 +114,5 @@ def __init__(self, child: Expression) -> None: self.name = child.sql @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter() diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 542f58801c9..90b58506cfd 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from collections import Counter +import sys from typing import Dict, List, Optional, Union from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat @@ -14,6 +14,19 @@ from snowflake.snowpark._internal.analyzer.snowflake_plan import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder +# collections.Counter does not pass type checker. 
Changes with appropriate type hints were made in 3.9+ +if sys.version_info <= (3, 9): + import collections + import typing + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + +else: + from collections import Counter + class UnaryNode(LogicalPlan): def __init__(self, child: LogicalPlan) -> None: @@ -36,7 +49,7 @@ def __init__( self.seed = seed @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM (child) SAMPLE (probability) -- if probability is provided # SELECT * FROM (child) SAMPLE (row_count ROWS) -- if not probability but row count is provided return Counter( @@ -54,7 +67,7 @@ def __init__(self, order: List[SortOrder], child: LogicalPlan) -> None: self.order = order @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # child ORDER BY COMMA.join(order) return Counter({ComplexityStat.ORDER_BY.value: 1}) + sum( (col.cumulative_complexity_stat for col in self.order), Counter() @@ -73,7 +86,7 @@ def __init__( self.aggregate_expressions = aggregate_expressions @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: estimate = Counter() if self.grouping_expressions: # GROUP BY grouping_exprs @@ -115,7 +128,7 @@ def __init__( self.default_on_null = default_on_null @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: estimate = Counter() # child estimate adjustment if grouping cols if self.grouping_columns and self.aggregates and self.aggregates[0].children: @@ -164,7 +177,7 @@ def __init__( self.column_list = column_list @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM (child) UNPIVOT (value_column FOR name_column IN (COMMA.join(column_list))) estimate = Counter( {ComplexityStat.UNPIVOT.value: 1, ComplexityStat.COLUMN.value: 3} @@ -185,7 +198,7 @@ def __init__( self.column_map = column_map @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # SELECT * RENAME (before AS after, ...) FROM child return Counter( { @@ -201,7 +214,7 @@ def __init__(self, condition: Expression, child: LogicalPlan) -> None: self.condition = condition @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # child WHERE condition return ( Counter({ComplexityStat.FILTER.value: 1}) @@ -215,7 +228,7 @@ def __init__(self, project_list: List[NamedExpression], child: LogicalPlan) -> N self.project_list = project_list @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: if not self.project_list: return Counter({ComplexityStat.COLUMN.value: 1}) diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 5729b17ea9b..c822ebcdb2e 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -2,9 +2,9 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
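Taken together, these nodes follow one pattern: each plan or expression node contributes a small Counter of SQL features, and the subtree estimate is the elementwise sum of those Counters computed bottom-up. A minimal self-contained sketch of that pattern (a toy Node class with made-up feature names, not the Snowpark types):

    from collections import Counter
    from functools import cached_property
    from typing import List, Optional


    class Node:
        def __init__(self, own: Counter, children: Optional[List["Node"]] = None) -> None:
            self.own = own
            self.children = children or []

        @property
        def individual_complexity_stat(self) -> Counter:
            # contribution of this node alone
            return self.own

        @cached_property
        def cumulative_complexity_stat(self) -> Counter:
            # bottom-up elementwise sum; sum() needs an explicit Counter start
            # value because its default start of 0 cannot be added to a Counter
            return sum(
                (child.cumulative_complexity_stat for child in self.children),
                self.individual_complexity_stat,
            )


    leaf = Node(Counter({"column": 3}))
    root = Node(Counter({"filter": 1, "column": 1}), children=[leaf])
    assert root.cumulative_complexity_stat == Counter({"column": 4, "filter": 1})

Counter addition keeps only positive counts, which is safe here because every contribution is non-negative, and cached_property memoizes the result so repeated traversals of a shared subtree stay cheap.
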
# -from collections import Counter +import sys from functools import cached_property -from typing import AbstractSet, Dict, List, Optional +from typing import AbstractSet, List, Optional from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat from snowflake.snowpark._internal.analyzer.expression import ( @@ -13,6 +13,18 @@ ) from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder +if sys.version_info <= (3, 9): + import collections + import typing + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + +else: + from collections import Counter + class SpecialFrameBoundary(Expression): sql: str @@ -21,7 +33,7 @@ def __init__(self) -> None: super().__init__() @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.LOW_IMPACT.value: 1}) @@ -71,11 +83,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.lower, self.upper) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.LOW_IMPACT.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: # frame_type BETWEEN lower AND upper return ( self.individual_complexity_stat @@ -102,7 +114,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: ) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: estimate = Counter() estimate += ( Counter({ComplexityStat.PARTITION_BY.value: 1}) @@ -117,7 +129,7 @@ def individual_complexity_stat(self) -> Dict[str, int]: return estimate @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: # partition_spec order_by_spec frame_spec return ( self.individual_complexity_stat @@ -144,11 +156,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.window_function, self.window_spec) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: return Counter({ComplexityStat.WINDOW.value: 1}) @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: # window_function OVER ( window_spec ) return ( self.window_function.cumulative_complexity_stat @@ -177,7 +189,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) @property - def individual_complexity_stat(self) -> Dict[str, int]: + def individual_complexity_stat(self) -> Counter[str]: # for func_name estimate = Counter({ComplexityStat.FUNCTION.value: 1}) # for offset @@ -193,7 +205,7 @@ def individual_complexity_stat(self) -> Dict[str, int]: return estimate @cached_property - def cumulative_complexity_stat(self) -> Dict[str, int]: + def cumulative_complexity_stat(self) -> Counter[str]: # func_name (expr [, offset] [, default]) [IGNORE NULLS] estimate = ( self.individual_complexity_stat + self.expr.cumulative_complexity_stat From 8da7ffd5fcfacc5bbeac56b2e6db1f76fb9f92a1 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 4 Jun 2024 14:49:13 -0700 Subject: [PATCH 10/37] fix async job test --- tests/integ/scala/test_async_job_suite.py | 14 
+++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/integ/scala/test_async_job_suite.py b/tests/integ/scala/test_async_job_suite.py index e78f7f8ae69..d06ebc07375 100644 --- a/tests/integ/scala/test_async_job_suite.py +++ b/tests/integ/scala/test_async_job_suite.py @@ -8,6 +8,8 @@ import pytest +from snowflake.snowpark.session import Session + try: import pandas as pd from pandas.testing import assert_frame_equal @@ -23,7 +25,7 @@ random_name_for_temp_object, ) from snowflake.snowpark.exceptions import SnowparkSQLException -from snowflake.snowpark.functions import col, when_matched, when_not_matched +from snowflake.snowpark.functions import col, sproc, when_matched, when_not_matched from snowflake.snowpark.table import DeleteResult, MergeResult, UpdateResult from snowflake.snowpark.types import ( DoubleType, @@ -349,13 +351,19 @@ def test_async_batch_insert(session): reason="TODO(SNOW-932722): Cancel query is not allowed in stored proc", ) def test_async_is_running_and_cancel(session): - async_job = session.sql("select SYSTEM$WAIT(3)").collect_nowait() + def wait(_: Session, sec: int) -> str: + sleep(sec) + return "success" + + sproc(wait, name="wait_sproc", packages=[]) + + async_job = session.sql("call wait_sproc(3)").collect_nowait() while not async_job.is_done(): sleep(1.0) assert async_job.is_done() # set 20s to avoid flakiness - async_job2 = session.sql("select SYSTEM$WAIT(20)").collect_nowait() + async_job2 = session.sql("call wait_sproc(20)").collect_nowait() assert not async_job2.is_done() async_job2.cancel() start = time() From 8474b5e55c5a9b5eeb8672cf201b038444217aea Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 4 Jun 2024 15:24:44 -0700 Subject: [PATCH 11/37] fix async job test --- tests/integ/scala/test_async_job_suite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/scala/test_async_job_suite.py b/tests/integ/scala/test_async_job_suite.py index d06ebc07375..afcd705b2d8 100644 --- a/tests/integ/scala/test_async_job_suite.py +++ b/tests/integ/scala/test_async_job_suite.py @@ -355,7 +355,7 @@ def wait(_: Session, sec: int) -> str: sleep(sec) return "success" - sproc(wait, name="wait_sproc", packages=[]) + sproc(wait, name="wait_sproc", packages=["snowflake-snowpark-python"]) async_job = session.sql("call wait_sproc(3)").collect_nowait() while not async_job.is_done(): From f72ce8c1f1e5c1bae14681f54476ff052451d56f Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 4 Jun 2024 20:12:32 -0700 Subject: [PATCH 12/37] remove change added in error --- src/snowflake/snowpark/_internal/server_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 88c286055b0..b1c592be7fb 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -656,7 +656,7 @@ def get_result_and_metadata( def get_result_query_id(self, plan: SnowflakePlan, **kwargs) -> str: # get the iterator such that the data is not fetched - result_set, _ = self.get_result_set(plan, ignore_results=True, **kwargs) + result_set, _ = self.get_result_set(plan, iter=True, **kwargs) return result_set["sfqid"] @_Decorator.wrap_exception From 310fbd23334716f65112012db14064270661f584 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 4 Jun 2024 20:18:25 -0700 Subject: [PATCH 13/37] remove change added in error --- 
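The fixed test above exercises the full lifecycle of an asynchronous query. As a usage sketch (assuming an open Snowpark `session` and the `wait_sproc` procedure registered in the test):

    from time import sleep

    job = session.sql("call wait_sproc(3)").collect_nowait()
    while not job.is_done():
        sleep(1.0)
    print(job.result())  # safe to fetch rows once is_done() returns True

    slow_job = session.sql("call wait_sproc(20)").collect_nowait()
    slow_job.cancel()  # stops the server-side query without waiting for it

collect_nowait returns an AsyncJob handle immediately; the wait therefore has to happen inside the procedure body, which is why the test moved away from a bare SYSTEM$WAIT expression.
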
src/snowflake/snowpark/_internal/server_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index b1c592be7fb..6eb10d56365 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -656,7 +656,7 @@ def get_result_and_metadata( def get_result_query_id(self, plan: SnowflakePlan, **kwargs) -> str: # get the iterator such that the data is not fetched - result_set, _ = self.get_result_set(plan, iter=True, **kwargs) + result_set, _ = self.get_result_set(plan, to_iter=True, **kwargs) return result_set["sfqid"] @_Decorator.wrap_exception From e79f2779f3ccb1cf2ec3d6f5ff889683e0441a3e Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 4 Jun 2024 20:22:25 -0700 Subject: [PATCH 14/37] add description on async fix --- tests/integ/scala/test_async_job_suite.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/integ/scala/test_async_job_suite.py b/tests/integ/scala/test_async_job_suite.py index afcd705b2d8..6e01cb7e00a 100644 --- a/tests/integ/scala/test_async_job_suite.py +++ b/tests/integ/scala/test_async_job_suite.py @@ -351,12 +351,14 @@ def test_async_batch_insert(session): reason="TODO(SNOW-932722): Cancel query is not allowed in stored proc", ) def test_async_is_running_and_cancel(session): + # creating a sproc here because describe query on SYSTEM$WAIT() + # triggers the wait and the async job fails because we don't hit + # the correct time boundaries + @sproc(name="wait_sproc", packages=["snowflake-snowpark-python"]) def wait(_: Session, sec: int) -> str: sleep(sec) return "success" - sproc(wait, name="wait_sproc", packages=["snowflake-snowpark-python"]) - async_job = session.sql("call wait_sproc(3)").collect_nowait() while not async_job.is_done(): sleep(1.0) From d5451ce31cb147ba9d503f04875ae9ada7b2840c Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 5 Jun 2024 07:24:51 -0700 Subject: [PATCH 15/37] move Counter typing into complexity stat --- .../_internal/analyzer/binary_expression.py | 19 ++++--------------- .../_internal/analyzer/binary_plan_node.py | 19 ++++--------------- .../_internal/analyzer/complexity_stat.py | 14 ++++++++++++++ .../snowpark/_internal/analyzer/expression.py | 19 ++++--------------- .../_internal/analyzer/grouping_set.py | 19 ++++--------------- .../_internal/analyzer/select_statement.py | 19 ++++--------------- .../_internal/analyzer/snowflake_plan.py | 14 +------------- .../_internal/analyzer/snowflake_plan_node.py | 19 ++++++------------- .../_internal/analyzer/sort_expression.py | 15 +-------------- .../_internal/analyzer/table_function.py | 19 ++++++------------- .../analyzer/table_merge_expression.py | 19 ++++--------------- .../_internal/analyzer/unary_expression.py | 19 ++++--------------- .../_internal/analyzer/unary_plan_node.py | 19 ++++--------------- .../_internal/analyzer/window_expression.py | 18 ++++-------------- 14 files changed, 64 insertions(+), 187 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 552fddbdc88..d1041670fdb 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -2,23 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
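The shim being centralized in this patch deserves one standalone illustration. On Python 3.8, collections.Counter is neither subscriptable at runtime nor generic for the type checker, so annotations like Counter[str] fail; a small subclass that also derives from typing.Counter[KT] fixes both, while 3.9+ can use the stdlib class directly. A sketch with a hypothetical helper function:

    import sys

    if sys.version_info < (3, 9):
        import collections
        import typing

        KT = typing.TypeVar("KT")

        # deriving from typing.Counter[KT] makes the class generic for the
        # type checker and subscriptable at runtime on Python 3.8
        class Counter(collections.Counter, typing.Counter[KT]):
            pass

    else:
        from collections import Counter  # natively generic from 3.9 onward


    def count_features(features: list) -> Counter[str]:
        return Counter(features)


    print(count_features(["filter", "filter", "column"]))
    # Counter({'filter': 2, 'column': 1})

Note that because sys.version_info is a five-element tuple, the `<= (3, 9)` guard used in the earlier patches and the `< (3, 9)` used in complexity_stat.py behave identically on released interpreters ((3, 9, 0, 'final', 0) compares greater than (3, 9)); the strict form simply states the intent more directly.
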
# -import sys from typing import AbstractSet, Optional -# collections.Counter does not pass type checker. Changes with appropriate type hints were made in 3.9+ -if sys.version_info <= (3, 9): - import collections - import typing - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - -else: - from collections import Counter - -from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat +from snowflake.snowpark._internal.analyzer.complexity_stat import ( + ComplexityStat, + Counter, +) from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index cb1bfe2b4d9..0e5aeec1a0b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -2,23 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -import sys from typing import List, Optional -# collections.Counter does not pass type checker. Changes with appropriate type hints were made in 3.9+ -if sys.version_info <= (3, 9): - import collections - import typing - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - -else: - from collections import Counter - -from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat +from snowflake.snowpark._internal.analyzer.complexity_stat import ( + ComplexityStat, + Counter, +) from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages diff --git a/src/snowflake/snowpark/_internal/analyzer/complexity_stat.py b/src/snowflake/snowpark/_internal/analyzer/complexity_stat.py index 975506e7c2a..8af1756dd9a 100644 --- a/src/snowflake/snowpark/_internal/analyzer/complexity_stat.py +++ b/src/snowflake/snowpark/_internal/analyzer/complexity_stat.py @@ -2,8 +2,22 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import sys from enum import Enum +# collections.Counter does not pass type checker. Changes with appropriate type hints were made in 3.9+ +if sys.version_info < (3, 9): + import collections + import typing + + KT = typing.TypeVar("KT") + + class Counter(collections.Counter, typing.Counter[KT]): + pass + +else: + from collections import Counter # noqa + class ComplexityStat(Enum): FILTER = "filter" diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 0a9dfdc92de..11fad56e020 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -3,26 +3,15 @@ # import copy -import sys import uuid from functools import cached_property from typing import TYPE_CHECKING, AbstractSet, Any, List, Optional, Tuple -# collections.Counter does not pass type checker. 
Changes with appropriate type hints were made in 3.9+ -if sys.version_info <= (3, 9): - import collections - import typing - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - -else: - from collections import Counter - import snowflake.snowpark._internal.utils -from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat +from snowflake.snowpark._internal.analyzer.complexity_stat import ( + ComplexityStat, + Counter, +) if TYPE_CHECKING: from snowflake.snowpark._internal.analyzer.snowflake_plan import ( diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index a1d9463e583..7a5d2f2b316 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -2,24 +2,13 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -import sys from functools import cached_property from typing import AbstractSet, List, Optional -# collections.Counter does not pass type checker. Changes with appropriate type hints were made in 3.9+ -if sys.version_info <= (3, 9): - import collections - import typing - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - -else: - from collections import Counter - -from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat +from snowflake.snowpark._internal.analyzer.complexity_stat import ( + ComplexityStat, + Counter, +) from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 43277107818..fa036e9cf5c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -20,21 +20,11 @@ Union, ) -# collections.Counter does not pass type checker. Changes with appropriate type hints were made in 3.9+ -if sys.version_info <= (3, 9): - import collections - import typing - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - -else: - from collections import Counter - import snowflake.snowpark._internal.utils -from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat +from snowflake.snowpark._internal.analyzer.complexity_stat import ( + ComplexityStat, + Counter, +) from snowflake.snowpark._internal.analyzer.cte_utils import encode_id from snowflake.snowpark._internal.analyzer.table_function import ( TableFunctionExpression, @@ -50,7 +40,6 @@ class Counter(collections.Counter, typing.Counter[KT]): Analyzer, ) # pragma: no cover -import sys from snowflake.snowpark._internal.analyzer import analyzer_utils from snowflake.snowpark._internal.analyzer.analyzer_utils import ( diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 8d9c042f716..5d9c5fad144 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -21,19 +21,7 @@ Union, ) -# collections.Counter does not pass type checker. 
Changes with appropriate type hints were made in 3.9+ -if sys.version_info <= (3, 9): - import collections - import typing - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - -else: - from collections import Counter - +from snowflake.snowpark._internal.analyzer.sort_expression import Counter from snowflake.snowpark._internal.analyzer.table_function import ( GeneratorTableFunction, TableFunctionRelation, diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index a9fd3f375f4..89783c0e16e 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -7,27 +7,20 @@ from enum import Enum from typing import Any, Dict, List, Optional -from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat +from snowflake.snowpark._internal.analyzer.complexity_stat import ( + ComplexityStat, + Counter, +) from snowflake.snowpark._internal.analyzer.expression import Attribute, Expression from snowflake.snowpark.row import Row from snowflake.snowpark.types import StructType -# Python 3.8: needs to use typing.Iterable because collections.abc.Iterable is not subscriptable -# needs to create new Counter class from collections.Counter so it can pass type check -# Python 3.9: can use both and type check support is added in collections.Counter from 3.9+ +# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable +# Python 3.9 can use both # Python 3.10 needs to use collections.abc.Iterable because typing.Iterable is removed if sys.version_info <= (3, 9): - import collections - import typing from typing import Iterable - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - else: - from collections import Counter from collections.abc import Iterable diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index 3875e855620..5b21e7a65a5 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -2,7 +2,6 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -import sys from functools import cached_property from typing import AbstractSet, Optional, Type @@ -10,19 +9,7 @@ Expression, derive_dependent_columns, ) - -# collections.Counter does not pass type checker. 
Changes with appropriate type hints were made in 3.9+ -if sys.version_info <= (3, 9): - import collections - import typing - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - -else: - from collections import Counter +from snowflake.snowpark._internal.analyzer.table_merge_expression import Counter class NullOrdering: diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index 34efe0b1e90..dd6f0f7ca35 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -6,27 +6,20 @@ from functools import cached_property from typing import Dict, List, Optional -from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat +from snowflake.snowpark._internal.analyzer.complexity_stat import ( + ComplexityStat, + Counter, +) from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder -# Python 3.8: needs to use typing.Iterable because collections.abc.Iterable is not subscriptable -# needs to create new Counter class from collections.Counter so it can pass type check -# Python 3.9: can use both and type check support is added in collections.Counter from 3.9+ +# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable +# Python 3.9 can use both # Python 3.10 needs to use collections.abc.Iterable because typing.Iterable is removed if sys.version_info <= (3, 9): - import collections - import typing from typing import Iterable - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - else: - from collections import Counter from collections.abc import Iterable diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index 48eab184274..dde0d04b49c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -2,30 +2,19 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -import sys from functools import cached_property from typing import Dict, List, Optional -from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat +from snowflake.snowpark._internal.analyzer.complexity_stat import ( + ComplexityStat, + Counter, +) from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.snowflake_plan import ( LogicalPlan, SnowflakePlan, ) -# collections.Counter does not pass type checker. 
Changes with appropriate type hints were made in 3.9+ -if sys.version_info <= (3, 9): - import collections - import typing - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - -else: - from collections import Counter - class MergeExpression(Expression): def __init__(self, condition: Optional[Expression]) -> None: diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index da86a0b9605..32fb7775473 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,10 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -import sys from typing import AbstractSet, Optional -from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat +from snowflake.snowpark._internal.analyzer.complexity_stat import ( + ComplexityStat, + Counter, +) from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, @@ -13,19 +15,6 @@ ) from snowflake.snowpark.types import DataType -# collections.Counter does not pass type checker. Changes with appropriate type hints were made in 3.9+ -if sys.version_info <= (3, 9): - import collections - import typing - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - -else: - from collections import Counter - class UnaryExpression(Expression): sql_operator: str diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 90b58506cfd..730b0118b16 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -2,10 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -import sys from typing import Dict, List, Optional, Union -from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat +from snowflake.snowpark._internal.analyzer.complexity_stat import ( + ComplexityStat, + Counter, +) from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, @@ -14,19 +16,6 @@ from snowflake.snowpark._internal.analyzer.snowflake_plan import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder -# collections.Counter does not pass type checker. Changes with appropriate type hints were made in 3.9+ -if sys.version_info <= (3, 9): - import collections - import typing - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - -else: - from collections import Counter - class UnaryNode(LogicalPlan): def __init__(self, child: LogicalPlan) -> None: diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index c822ebcdb2e..6501e9e5ef0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -2,29 +2,19 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# -import sys from functools import cached_property from typing import AbstractSet, List, Optional -from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat +from snowflake.snowpark._internal.analyzer.complexity_stat import ( + ComplexityStat, + Counter, +) from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder -if sys.version_info <= (3, 9): - import collections - import typing - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - -else: - from collections import Counter - class SpecialFrameBoundary(Expression): sql: str From 3b639bb84eb7f4d5ab78cfeaaf66e5a0211baf55 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 5 Jun 2024 07:29:35 -0700 Subject: [PATCH 16/37] rename file --- .../snowpark/_internal/analyzer/binary_expression.py | 8 ++++---- .../snowpark/_internal/analyzer/binary_plan_node.py | 4 ++-- .../snowpark/_internal/analyzer/expression.py | 2 +- .../snowpark/_internal/analyzer/grouping_set.py | 8 ++++---- .../{complexity_stat.py => materialization_utils.py} | 0 .../snowpark/_internal/analyzer/select_statement.py | 5 ++--- .../_internal/analyzer/snowflake_plan_node.py | 4 ++-- .../snowpark/_internal/analyzer/table_function.py | 4 ++-- .../_internal/analyzer/table_merge_expression.py | 4 ++-- .../snowpark/_internal/analyzer/unary_expression.py | 8 ++++---- .../snowpark/_internal/analyzer/unary_plan_node.py | 8 ++++---- .../snowpark/_internal/analyzer/window_expression.py | 8 ++++---- tests/integ/test_materialization_suite.py | 11 ++++++----- 13 files changed, 37 insertions(+), 37 deletions(-) rename src/snowflake/snowpark/_internal/analyzer/{complexity_stat.py => materialization_utils.py} (100%) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index d1041670fdb..0d346ab9783 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -4,14 +4,14 @@ from typing import AbstractSet, Optional -from snowflake.snowpark._internal.analyzer.complexity_stat import ( - ComplexityStat, - Counter, -) from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, ) +from snowflake.snowpark._internal.analyzer.materialization_utils import ( + ComplexityStat, + Counter, +) class BinaryExpression(Expression): diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index 0e5aeec1a0b..7b10c361941 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -4,11 +4,11 @@ from typing import List, Optional -from snowflake.snowpark._internal.analyzer.complexity_stat import ( +from snowflake.snowpark._internal.analyzer.expression import Expression +from snowflake.snowpark._internal.analyzer.materialization_utils import ( ComplexityStat, Counter, ) -from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 
11fad56e020..889c234df74 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, AbstractSet, Any, List, Optional, Tuple import snowflake.snowpark._internal.utils -from snowflake.snowpark._internal.analyzer.complexity_stat import ( +from snowflake.snowpark._internal.analyzer.materialization_utils import ( ComplexityStat, Counter, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 7a5d2f2b316..4bf7a227ae8 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -5,14 +5,14 @@ from functools import cached_property from typing import AbstractSet, List, Optional -from snowflake.snowpark._internal.analyzer.complexity_stat import ( - ComplexityStat, - Counter, -) from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, ) +from snowflake.snowpark._internal.analyzer.materialization_utils import ( + ComplexityStat, + Counter, +) class GroupingSet(Expression): diff --git a/src/snowflake/snowpark/_internal/analyzer/complexity_stat.py b/src/snowflake/snowpark/_internal/analyzer/materialization_utils.py similarity index 100% rename from src/snowflake/snowpark/_internal/analyzer/complexity_stat.py rename to src/snowflake/snowpark/_internal/analyzer/materialization_utils.py diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index fa036e9cf5c..5a1238f3d22 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -21,11 +21,11 @@ ) import snowflake.snowpark._internal.utils -from snowflake.snowpark._internal.analyzer.complexity_stat import ( +from snowflake.snowpark._internal.analyzer.cte_utils import encode_id +from snowflake.snowpark._internal.analyzer.materialization_utils import ( ComplexityStat, Counter, ) -from snowflake.snowpark._internal.analyzer.cte_utils import encode_id from snowflake.snowpark._internal.analyzer.table_function import ( TableFunctionExpression, TableFunctionJoin, @@ -40,7 +40,6 @@ Analyzer, ) # pragma: no cover - from snowflake.snowpark._internal.analyzer import analyzer_utils from snowflake.snowpark._internal.analyzer.analyzer_utils import ( result_scan_statement, diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 89783c0e16e..d3e8d01279e 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -7,11 +7,11 @@ from enum import Enum from typing import Any, Dict, List, Optional -from snowflake.snowpark._internal.analyzer.complexity_stat import ( +from snowflake.snowpark._internal.analyzer.expression import Attribute, Expression +from snowflake.snowpark._internal.analyzer.materialization_utils import ( ComplexityStat, Counter, ) -from snowflake.snowpark._internal.analyzer.expression import Attribute, Expression from snowflake.snowpark.row import Row from snowflake.snowpark.types import StructType diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index dd6f0f7ca35..cfc20244e4b 100644 --- 
a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -6,11 +6,11 @@ from functools import cached_property from typing import Dict, List, Optional -from snowflake.snowpark._internal.analyzer.complexity_stat import ( +from snowflake.snowpark._internal.analyzer.expression import Expression +from snowflake.snowpark._internal.analyzer.materialization_utils import ( ComplexityStat, Counter, ) -from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index dde0d04b49c..6e8111e6faf 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -5,11 +5,11 @@ from functools import cached_property from typing import Dict, List, Optional -from snowflake.snowpark._internal.analyzer.complexity_stat import ( +from snowflake.snowpark._internal.analyzer.expression import Expression +from snowflake.snowpark._internal.analyzer.materialization_utils import ( ComplexityStat, Counter, ) -from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.snowflake_plan import ( LogicalPlan, SnowflakePlan, diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index 32fb7775473..a5c05c4c05e 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -4,15 +4,15 @@ from typing import AbstractSet, Optional -from snowflake.snowpark._internal.analyzer.complexity_stat import ( - ComplexityStat, - Counter, -) from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, derive_dependent_columns, ) +from snowflake.snowpark._internal.analyzer.materialization_utils import ( + ComplexityStat, + Counter, +) from snowflake.snowpark.types import DataType diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 730b0118b16..507e011b4dd 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -4,15 +4,15 @@ from typing import Dict, List, Optional, Union -from snowflake.snowpark._internal.analyzer.complexity_stat import ( - ComplexityStat, - Counter, -) from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, ScalarSubquery, ) +from snowflake.snowpark._internal.analyzer.materialization_utils import ( + ComplexityStat, + Counter, +) from snowflake.snowpark._internal.analyzer.snowflake_plan import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 6501e9e5ef0..82677eec59f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -5,14 +5,14 @@ from functools import cached_property from typing import AbstractSet, List, Optional 
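One detail worth calling out across these files: Expression subclasses memoize with @cached_property, while the plan classes (LogicalPlan, SnowflakePlan, Selectable) use a plain property over an Optional backing field plus a setter, presumably so a later compilation pass can reset or overwrite the memoized Counter. A reduced sketch of that second style (toy Plan class, hypothetical attribute names):

    from collections import Counter
    from typing import List, Optional


    class Plan:
        def __init__(self, children: Optional[List["Plan"]] = None) -> None:
            self.children = children or []
            self._cumulative: Optional[Counter] = None  # None means "not computed yet"

        @property
        def cumulative(self) -> Counter:
            if self._cumulative is None:
                total = Counter({"column": 1})
                for child in self.children:
                    total += child.cumulative
                self._cumulative = total
            return self._cumulative

        @cumulative.setter
        def cumulative(self, value: Counter) -> None:
            # lets an optimizer pass overwrite the memoized value after a rewrite
            self._cumulative = value

Assigning through the setter, or setting the backing field back to None, invalidates the cache; cached_property offers no equivalent reset short of deleting the attribute from the instance dict.
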
-from snowflake.snowpark._internal.analyzer.complexity_stat import (
-    ComplexityStat,
-    Counter,
-)
 from snowflake.snowpark._internal.analyzer.expression import (
     Expression,
     derive_dependent_columns,
 )
+from snowflake.snowpark._internal.analyzer.materialization_utils import (
+    ComplexityStat,
+    Counter,
+)
 from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder

diff --git a/tests/integ/test_materialization_suite.py b/tests/integ/test_materialization_suite.py
index 1cc4679f281..ed7c0846778 100644
--- a/tests/integ/test_materialization_suite.py
+++ b/tests/integ/test_materialization_suite.py
@@ -2,12 +2,13 @@
 # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
 #

-from collections import Counter
-from typing import Dict

 import pytest

-from snowflake.snowpark._internal.analyzer.complexity_stat import ComplexityStat
+from snowflake.snowpark._internal.analyzer.materialization_utils import (
+    ComplexityStat,
+    Counter,
+)
 from snowflake.snowpark._internal.analyzer.select_statement import (
     SET_EXCEPT,
     SET_INTERSECT,
@@ -50,11 +51,11 @@ def sample_table(session):
     Utils.drop_table(session, table_name)


-def get_cumulative_complexity_stat(df: DataFrame) -> Dict[str, int]:
+def get_cumulative_complexity_stat(df: DataFrame) -> Counter[str]:
     return df._plan.cumulative_complexity_stat


-def assert_df_subtree_query_complexity(df: DataFrame, estimate: Dict[str, int]):
+def assert_df_subtree_query_complexity(df: DataFrame, estimate: Counter[str]):
     assert (
         get_cumulative_complexity_stat(df) == estimate
     ), f"query = {df.queries['queries'][-1]}"

From 909ed6700809bc27e60387024d2be76b5857e4a5 Mon Sep 17 00:00:00 2001
From: Afroz Alam
Date: Wed, 5 Jun 2024 08:04:14 -0700
Subject: [PATCH 17/37] address feedback

---
 .../snowpark/_internal/analyzer/expression.py | 24 ++++++++-------
 .../analyzer/materialization_utils.py         |  8 +++++
 .../_internal/analyzer/select_statement.py    |  4 +--
 .../_internal/analyzer/snowflake_plan_node.py |  6 ++++
 .../_internal/analyzer/unary_plan_node.py     | 29 ++++++++++++-------
 5 files changed, 48 insertions(+), 23 deletions(-)

diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py
index 889c234df74..eea8846632c 100644
--- a/src/snowflake/snowpark/_internal/analyzer/expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/expression.py
@@ -87,10 +87,16 @@ def sql(self) -> str:

     @property
     def individual_complexity_stat(self) -> Counter[str]:
-        return Counter({})
+        """Returns the individual contribution of the expression node towards the overall
+        compilation complexity of the generated sql.
+        """
+        return Counter()

     @cached_property
     def cumulative_complexity_stat(self) -> Counter[str]:
+        """Returns the aggregated complexity statistic of the subtree rooted at this
+        expression node. The statistic of the current node is included in the final aggregate.
+        """
         children = self.children or []
         return sum(
             (child.cumulative_complexity_stat for child in children),
@@ -143,7 +149,7 @@ def cumulative_complexity_stat(self) -> Counter[str]:
         return (
             sum(
                 (expr.cumulative_complexity_stat for expr in self.expressions),
-                Counter({}),
+                Counter(),
             )
             + self.individual_complexity_stat
         )
@@ -169,7 +175,7 @@ def cumulative_complexity_stat(self) -> Counter[str]:
             + self.individual_complexity_stat
             + sum(
                 (expr.cumulative_complexity_stat for expr in self.values),
-                Counter({}),
+                Counter(),
             )
         )
@@ -220,7 +226,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
     @property
     def individual_complexity_stat(self) -> Counter[str]:
         if self.expressions:
-            return Counter({})
+            return Counter()
         # if there are no expressions, we assign column value = 1 to Star
         return Counter({ComplexityStat.COLUMN.value: 1})

     @cached_property
     def cumulative_complexity_stat(self) -> Counter[str]:
         return self.individual_complexity_stat + sum(
             (child.individual_complexity_stat for child in self.expressions),
-            Counter({}),
+            Counter(),
         )
@@ -515,7 +521,7 @@ def cumulative_complexity_stat(self) -> Counter[str]:
         return (
             sum(
                 (col.cumulative_complexity_stat for col in self.order_by_cols),
-                Counter({}),
+                Counter(),
             )
             + self.individual_complexity_stat
             + self.expr.cumulative_complexity_stat
@@ -551,12 +557,10 @@ def cumulative_complexity_stat(self) -> Counter[str]:
                 condition.cumulative_complexity_stat + value.cumulative_complexity_stat
                 for condition, value in self.branches
             ),
-            Counter({}),
+            Counter(),
         )
         estimate += (
-            self.else_value.cumulative_complexity_stat
-            if self.else_value
-            else Counter({})
+            self.else_value.cumulative_complexity_stat if self.else_value else Counter()
         )
         return estimate

diff --git a/src/snowflake/snowpark/_internal/analyzer/materialization_utils.py b/src/snowflake/snowpark/_internal/analyzer/materialization_utils.py
index 8af1756dd9a..a27b4f8c043 100644
--- a/src/snowflake/snowpark/_internal/analyzer/materialization_utils.py
+++ b/src/snowflake/snowpark/_internal/analyzer/materialization_utils.py
@@ -20,6 +20,14 @@ class Counter(collections.Counter, typing.Counter[KT]):
     from collections import Counter  # noqa


 class ComplexityStat(Enum):
+    """This enum class is used to account for different types of sql
+    text generated by expressions and logical plan nodes. A bottom up
+    aggregation of the number of occurrences of each enum type is
+    done in the Expression and LogicalPlan classes to calculate an estimate
+    of the overall query complexity in the context of compiling the
+    generated sql.
+    """
+
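Note that every call site feeds these members into the counters through their .value strings, as in Counter({ComplexityStat.COLUMN.value: 1}), rather than using the enum objects themselves as keys. A small sketch of what that buys, with the enum trimmed to two members for brevity:

import json
from collections import Counter
from enum import Enum

class ComplexityStat(Enum):
    FILTER = "filter"
    COLUMN = "column"

# string keys merge cleanly under Counter addition, and the result is a
# plain str -> int mapping that json.dumps can serialize directly
stat = Counter({ComplexityStat.COLUMN.value: 2}) + Counter(
    {ComplexityStat.FILTER.value: 1}
)
assert stat == Counter({"column": 2, "filter": 1})
assert json.loads(json.dumps(stat)) == {"column": 2, "filter": 1}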
     FILTER = "filter"
     ORDER_BY = "order_by"
     JOIN = "join"
diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
index 5a1238f3d22..4d1c3e71ad3 100644
--- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py
+++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -444,8 +444,8 @@ def schema_query(self) -> str:
     @property
     def individual_complexity_stat(self) -> Counter[str]:
         if self.pre_actions:
-            # having pre-actions implies we have a non-select query followed by a
-            # SELECT * FROM table(result_scan(query_id)) statement
+            # Currently having pre-actions implies we have a non-select query followed
+            # by a SELECT * FROM table(result_scan(query_id)) statement
             return Counter({ComplexityStat.COLUMN.value: 1})

         # no pre-action implies the best estimate we have is of # active columns
diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
index d3e8d01279e..b1f202763c0 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
@@ -31,10 +31,16 @@ def __init__(self) -> None:

     @property
     def individual_complexity_stat(self) -> Counter[str]:
+        """Returns the individual contribution of the logical plan node towards the
+        overall compilation complexity of the generated sql.
+        """
         return Counter()

     @property
     def cumulative_complexity_stat(self) -> Counter[str]:
+        """Returns the aggregated complexity statistic of the subtree rooted at this
+        logical plan node. The statistic of the current node is included in the final
+        aggregate.
+        """
         if self._cumulative_complexity_stat is None:
             estimate = self.individual_complexity_stat
             for node in self.children:
diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py
index 507e011b4dd..20ce2e4c336 100644
--- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py
+++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py
@@ -87,13 +87,15 @@ def individual_complexity_stat(self) -> Counter[str]:
             # LIMIT 1
             estimate += Counter({ComplexityStat.LOW_IMPACT.value: 1})

-        get_complexity_stat = (
-            lambda expr: expr.cumulative_complexity_stat
-            if hasattr(expr, "cumulative_complexity_stat")
-            else Counter({ComplexityStat.COLUMN.value: 1})
-        )
         estimate += sum(
-            (get_complexity_stat(expr) for expr in self.aggregate_expressions),
+            (
+                getattr(
+                    expr,
+                    "cumulative_complexity_stat",
+                    Counter({ComplexityStat.COLUMN.value: 1}),
+                )
+                for expr in self.aggregate_expressions
+            ),
             Counter(),
         )
         return estimate
@@ -221,12 +223,17 @@ def individual_complexity_stat(self) -> Counter[str]:
         if not self.project_list:
             return Counter({ComplexityStat.COLUMN.value: 1})

-        get_complexity_stat = (
-            lambda col: col.cumulative_complexity_stat
-            if hasattr(col, "cumulative_complexity_stat")
-            else Counter({ComplexityStat.COLUMN.value: 1})
-        )
-        return sum((get_complexity_stat(col) for col in self.project_list), Counter())
+        return sum(
+            (
+                getattr(
+                    col,
+                    "cumulative_complexity_stat",
+                    Counter({ComplexityStat.COLUMN.value: 1}),
+                )
+                for col in self.project_list
+            ),
+            Counter(),
+        )

From ed83869cb1338cc43a0a5c107160635b93ea36b5 Mon Sep 17 00:00:00 2001
From: Afroz Alam
Date: Wed, 5 Jun 2024 08:12:40 -0700
Subject: [PATCH 18/37] fix bad imports

---
 src/snowflake/snowpark/_internal/analyzer/select_statement.py |
4 ++++ src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py | 2 +- src/snowflake/snowpark/_internal/analyzer/sort_expression.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 4d1c3e71ad3..9c113064d86 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -515,6 +515,10 @@ def schema_query(self) -> str: def query_params(self) -> Optional[Sequence[Any]]: return self._query_params + @property + def individual_complexity_stat(self) -> Counter[str]: + return self.snowflake_plan.individual_complexity_stat + class SelectStatement(Selectable): """The main logic plan to be used by a DataFrame. diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 5d9c5fad144..5825f0ba7c0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -21,7 +21,7 @@ Union, ) -from snowflake.snowpark._internal.analyzer.sort_expression import Counter +from snowflake.snowpark._internal.analyzer.materialization_utils import Counter from snowflake.snowpark._internal.analyzer.table_function import ( GeneratorTableFunction, TableFunctionRelation, diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index 5b21e7a65a5..2d108a12123 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -9,7 +9,7 @@ Expression, derive_dependent_columns, ) -from snowflake.snowpark._internal.analyzer.table_merge_expression import Counter +from snowflake.snowpark._internal.analyzer.materialization_utils import Counter class NullOrdering: From 026e428b32902e218d9fc3b4e38bfde625f82384 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 5 Jun 2024 08:46:20 -0700 Subject: [PATCH 19/37] fix type hints --- src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 20ce2e4c336..86350229853 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -93,7 +93,7 @@ def individual_complexity_stat(self) -> Counter[str]: expr, "cumulative_complexity_stat", Counter({ComplexityStat.COLUMN.value: 1}), - ) + ) # type: ignore for expr in self.aggregate_expressions ), Counter(), @@ -229,7 +229,7 @@ def individual_complexity_stat(self) -> Counter[str]: col, "cumulative_complexity_stat", Counter({ComplexityStat.COLUMN.value: 1}), - ) + ) # type: ignore for col in self.project_list ), Counter(), From df9bfac9480e2f4acfbb0b6db7f9f2edd187fdfa Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 5 Jun 2024 09:53:52 -0700 Subject: [PATCH 20/37] rename --- .../_internal/analyzer/binary_expression.py | 4 +- .../_internal/analyzer/binary_plan_node.py | 8 +- .../snowpark/_internal/analyzer/expression.py | 39 ++--- .../_internal/analyzer/grouping_set.py | 6 +- .../analyzer/materialization_utils.py | 10 +- .../_internal/analyzer/select_statement.py | 20 +-- .../_internal/analyzer/snowflake_plan_node.py | 24 +-- 
.../_internal/analyzer/sort_expression.py | 4 - .../_internal/analyzer/table_function.py | 18 +- .../analyzer/table_merge_expression.py | 4 +- .../_internal/analyzer/unary_expression.py | 6 +- .../_internal/analyzer/unary_plan_node.py | 34 ++-- .../_internal/analyzer/window_expression.py | 18 +- tests/integ/test_materialization_suite.py | 164 +++++++++--------- 14 files changed, 175 insertions(+), 184 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 0d346ab9783..0aefb87c73b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -9,7 +9,7 @@ derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.materialization_utils import ( - ComplexityStat, + PlanNodeCategory, Counter, ) @@ -32,7 +32,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + return Counter({PlanNodeCategory.OTHERS.value: 1}) class BinaryArithmeticExpression(BinaryExpression): diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index 7b10c361941..5967d6c59e3 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -6,7 +6,7 @@ from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.materialization_utils import ( - ComplexityStat, + PlanNodeCategory, Counter, ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan @@ -76,7 +76,7 @@ class SetOperation(BinaryNode): @property def individual_complexity_stat(self) -> Counter[str]: # (left) operator (right) - return Counter({ComplexityStat.SET_OPERATION.value: 1}) + return Counter({PlanNodeCategory.SET_OPERATION.value: 1}) class Except(SetOperation): @@ -198,10 +198,10 @@ def sql(self) -> str: @property def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM (left) AS left_alias join_type_sql JOIN (right) AS right_alias match_cond, using_cond, join_cond - estimate = Counter({ComplexityStat.JOIN.value: 1}) + estimate = Counter({PlanNodeCategory.JOIN.value: 1}) if isinstance(self.join_type, UsingJoin) and self.join_type.using_columns: estimate += Counter( - {ComplexityStat.COLUMN.value: len(self.join_type.using_columns)} + {PlanNodeCategory.COLUMN.value: len(self.join_type.using_columns)} ) estimate += ( self.join_condition.cumulative_complexity_stat diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index eea8846632c..c50c91ea812 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -9,7 +9,7 @@ import snowflake.snowpark._internal.utils from snowflake.snowpark._internal.analyzer.materialization_utils import ( - ComplexityStat, + PlanNodeCategory, Counter, ) @@ -166,7 +166,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.IN.value: 1}) + return Counter({PlanNodeCategory.IN.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -209,7 +209,7 @@ def 
dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.COLUMN.value: 1}) + return Counter({PlanNodeCategory.COLUMN.value: 1}) class Star(Expression): @@ -228,7 +228,7 @@ def individual_complexity_stat(self) -> Counter[str]: if self.expressions: return Counter() # if there are no expressions, we assign column value = 1 to Star - return Counter({ComplexityStat.COLUMN.value: 1}) + return Counter({PlanNodeCategory.COLUMN.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -273,7 +273,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.COLUMN.value: 1}) + return Counter({PlanNodeCategory.COLUMN.value: 1}) class Literal(Expression): @@ -300,7 +300,7 @@ def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.LITERAL.value: 1}) + return Counter({PlanNodeCategory.LITERAL.value: 1}) class Interval(Expression): @@ -352,12 +352,7 @@ def __str__(self) -> str: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter( - { - ComplexityStat.LITERAL.value: 2 * len(self.values_dict), - ComplexityStat.LOW_IMPACT.value: 1, - } - ) + return Counter({PlanNodeCategory.OTHERS.value: 1}) class Like(Expression): @@ -372,7 +367,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: # expr LIKE pattern - return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + return Counter({PlanNodeCategory.OTHERS.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -395,7 +390,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: # expr REG_EXP pattern - return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + return Counter({PlanNodeCategory.OTHERS.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -418,7 +413,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: # expr COLLATE collate_spec - return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + return Counter({PlanNodeCategory.OTHERS.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -437,7 +432,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: # the literal corresponds to the contribution from self.field - return Counter({ComplexityStat.LITERAL.value: 1}) + return Counter({PlanNodeCategory.LITERAL.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -457,7 +452,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: # the literal corresponds to the contribution from self.field - return Counter({ComplexityStat.LITERAL.value: 1}) + return Counter({PlanNodeCategory.LITERAL.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -498,7 +493,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.FUNCTION.value: 1}) + return 
Counter({PlanNodeCategory.FUNCTION.value: 1}) class WithinGroup(Expression): @@ -514,7 +509,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: # expr WITHIN GROUP (ORDER BY cols) - return Counter({ComplexityStat.ORDER_BY.value: 1}) + return Counter({PlanNodeCategory.ORDER_BY.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -548,7 +543,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.CASE_WHEN.value: 1}) + return Counter({PlanNodeCategory.CASE_WHEN.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -586,7 +581,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.FUNCTION.value: 1}) + return Counter({PlanNodeCategory.FUNCTION.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -608,7 +603,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.FUNCTION.value: 1}) + return Counter({PlanNodeCategory.FUNCTION.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 4bf7a227ae8..6de4ba4fc46 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -10,7 +10,7 @@ derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.materialization_utils import ( - ComplexityStat, + PlanNodeCategory, Counter, ) @@ -26,7 +26,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + return Counter({PlanNodeCategory.OTHERS.value: 1}) class Cube(GroupingSet): @@ -58,4 +58,4 @@ def cumulative_complexity_stat(self) -> Counter[str]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + return Counter({PlanNodeCategory.OTHERS.value: 1}) diff --git a/src/snowflake/snowpark/_internal/analyzer/materialization_utils.py b/src/snowflake/snowpark/_internal/analyzer/materialization_utils.py index a27b4f8c043..cbd8abe72ae 100644 --- a/src/snowflake/snowpark/_internal/analyzer/materialization_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/materialization_utils.py @@ -19,7 +19,7 @@ class Counter(collections.Counter, typing.Counter[KT]): from collections import Counter # noqa -class ComplexityStat(Enum): +class PlanNodeCategory(Enum): """This enum class is used to account for different types of sql text generated by expressions and logical plan nodes. 
A bottom up
     aggregation of the number of occurrences of each enum type is
@@ -39,8 +39,8 @@ class ComplexityStat(Enum):
     GROUP_BY = "group_by"
     PARTITION_BY = "partition_by"
     CASE_WHEN = "case_when"
-    LITERAL = "literal"
-    COLUMN = "column"
-    FUNCTION = "function"
+    LITERAL = "literal"  # covers all literals like numbers, constant strings, etc.
+    COLUMN = "column"  # covers all cases where a table column is referenced
+    FUNCTION = "function"  # covers all Snowflake built-in functions, table functions and UDXFs
     IN = "in"
-    LOW_IMPACT = "low_impact"
+    OTHERS = "others"
diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
index 9c113064d86..95f048339d4 100644
--- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py
+++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -23,7 +23,7 @@
 import snowflake.snowpark._internal.utils
 from snowflake.snowpark._internal.analyzer.cte_utils import encode_id
 from snowflake.snowpark._internal.analyzer.materialization_utils import (
-    ComplexityStat,
+    PlanNodeCategory,
     Counter,
 )
 from snowflake.snowpark._internal.analyzer.table_function import (
@@ -303,7 +303,7 @@ def individual_complexity_stat(self) -> Counter[str]:
         """
         if isinstance(self.snowflake_plan.source_plan, Selectable):
             return Counter(
-                {ComplexityStat.COLUMN.value: len(self.column_states.active_columns)}
+                {PlanNodeCategory.COLUMN.value: len(self.column_states.active_columns)}
             )
         return self.snowflake_plan.individual_complexity_stat
@@ -384,7 +384,7 @@ def schema_query(self) -> str:
     @property
     def individual_complexity_stat(self) -> Counter[str]:
         # SELECT * FROM entity
-        return Counter({ComplexityStat.COLUMN.value: 1})
+        return Counter({PlanNodeCategory.COLUMN.value: 1})

     @property
     def query_params(self) -> Optional[Sequence[Any]]:
@@ -446,11 +446,11 @@ def individual_complexity_stat(self) -> Counter[str]:
     if self.pre_actions:
         # Currently having pre-actions implies we have a non-select query followed
         # by a SELECT * FROM table(result_scan(query_id)) statement
-        return Counter({ComplexityStat.COLUMN.value: 1})
+        return Counter({PlanNodeCategory.COLUMN.value: 1})

     # no pre-action implies the best estimate we have is of # active columns
     return Counter(
-        {ComplexityStat.COLUMN.value: len(self.column_states.active_columns)}
+        {PlanNodeCategory.COLUMN.value: len(self.column_states.active_columns)}
     )
@@ -726,7 +726,7 @@ def individual_complexity_stat(self) -> Counter[str]:

         # filter component - add +1 for WHERE clause and sum of expression complexity for where expression
         estimate += (
-            Counter({ComplexityStat.FILTER.value: 1})
+            Counter({PlanNodeCategory.FILTER.value: 1})
             + self.where.cumulative_complexity_stat
             if self.where
             else Counter()
@@ -736,7 +736,7 @@
         estimate += (
             sum(
                 (expr.cumulative_complexity_stat for expr in self.order_by),
-                Counter({ComplexityStat.ORDER_BY.value: 1}),
+                Counter({PlanNodeCategory.ORDER_BY.value: 1}),
             )
             if self.order_by
             else Counter()
         )

         # limit/offset component
         estimate += (
-            Counter({ComplexityStat.LOW_IMPACT.value: 1}) if self.limit_ else Counter()
+            Counter({PlanNodeCategory.OTHERS.value: 1}) if self.limit_ else Counter()
         )
         estimate += (
-            Counter({ComplexityStat.LOW_IMPACT.value: 1}) if self.offset else Counter()
+            Counter({PlanNodeCategory.OTHERS.value: 1}) if self.offset else Counter()
         )
         return
estimate @@ -1135,7 +1135,7 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: @property def individual_complexity_stat(self) -> Counter[str]: # we add #set_operands - 1 additional operators in sql query - return Counter({ComplexityStat.SET_OPERATION.value: len(self.set_operands) - 1}) + return Counter({PlanNodeCategory.SET_OPERATION.value: len(self.set_operands) - 1}) class DeriveColumnDependencyError(Exception): diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index b1f202763c0..c220bf973c3 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -9,7 +9,7 @@ from snowflake.snowpark._internal.analyzer.expression import Attribute, Expression from snowflake.snowpark._internal.analyzer.materialization_utils import ( - ComplexityStat, + PlanNodeCategory, Counter, ) from snowflake.snowpark.row import Row @@ -73,11 +73,11 @@ def individual_complexity_stat(self) -> Counter[str]: # SELECT ( ROW_NUMBER() OVER ( ORDER BY SEQ8() ) - 1 ) * (step) + (start) AS id FROM ( TABLE (GENERATOR(ROWCOUNT => count))) return Counter( { - ComplexityStat.WINDOW.value: 1, - ComplexityStat.ORDER_BY.value: 1, - ComplexityStat.LITERAL.value: 3, # step, start, count - ComplexityStat.COLUMN.value: 1, # id column - ComplexityStat.LOW_IMPACT.value: 2, # ROW_NUMBER, GENERATOR + PlanNodeCategory.WINDOW.value: 1, + PlanNodeCategory.ORDER_BY.value: 1, + PlanNodeCategory.LITERAL.value: 3, # step, start, count + PlanNodeCategory.COLUMN.value: 1, # id column + PlanNodeCategory.OTHERS.value: 2, # ROW_NUMBER, GENERATOR } ) @@ -90,7 +90,7 @@ def __init__(self, name: str) -> None: @property def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM name - return Counter({ComplexityStat.COLUMN.value: 1}) + return Counter({PlanNodeCategory.COLUMN.value: 1}) class SnowflakeValues(LeafNode): @@ -111,8 +111,8 @@ def individual_complexity_stat(self) -> Counter[str]: # TODO: use ARRAY_BIND_THRESHOLD return Counter( { - ComplexityStat.COLUMN.value: len(self.output), - ComplexityStat.LITERAL.value: len(self.data) * len(self.output), + PlanNodeCategory.COLUMN.value: len(self.output), + PlanNodeCategory.LITERAL.value: len(self.data) * len(self.output), } ) @@ -150,10 +150,10 @@ def __init__( def individual_complexity_stat(self) -> Counter[str]: # CREATE OR REPLACE table_type TABLE table_name (col definition) clustering_expr AS SELECT * FROM (child) estimate = Counter( - {ComplexityStat.LOW_IMPACT.value: 1, ComplexityStat.COLUMN.value: 1} + {PlanNodeCategory.OTHERS.value: 1, PlanNodeCategory.COLUMN.value: 1} ) estimate += ( - Counter({ComplexityStat.COLUMN.value: len(self.column_names)}) + Counter({PlanNodeCategory.COLUMN.value: len(self.column_names)}) if self.column_names else Counter() ) @@ -182,7 +182,7 @@ def __init__( def individual_complexity_stat(self) -> Counter[str]: # for limit and offset return ( - Counter({ComplexityStat.LOW_IMPACT.value: 2}) + Counter({PlanNodeCategory.OTHERS.value: 2}) + self.limit_expr.cumulative_complexity_stat + self.offset_expr.cumulative_complexity_stat ) diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index 2d108a12123..bfbe4fc64ad 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -58,10 +58,6 @@ def 
__init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) - @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter() - @cached_property def cumulative_complexity_stat(self) -> Counter[str]: return self.child.cumulative_complexity_stat + self.individual_complexity_stat diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index cfc20244e4b..415c5d4d477 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -8,7 +8,7 @@ from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.materialization_utils import ( - ComplexityStat, + PlanNodeCategory, Counter, ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan @@ -39,11 +39,11 @@ def __init__( def cumulative_complexity_stat(self) -> Counter[str]: if not self.over: return Counter() - estimate = Counter({ComplexityStat.WINDOW.value: 1}) + estimate = Counter({PlanNodeCategory.WINDOW.value: 1}) estimate += ( sum( (expr.cumulative_complexity_stat for expr in self.partition_spec), - Counter({ComplexityStat.PARTITION_BY.value: 1}), + Counter({PlanNodeCategory.PARTITION_BY.value: 1}), ) if self.partition_spec else Counter() @@ -51,7 +51,7 @@ def cumulative_complexity_stat(self) -> Counter[str]: estimate += ( sum( (expr.cumulative_complexity_stat for expr in self.order_spec), - Counter({ComplexityStat.ORDER_BY.value: 1}), + Counter({PlanNodeCategory.ORDER_BY.value: 1}), ) if self.order_spec else Counter() @@ -75,7 +75,7 @@ def __init__( @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.FUNCTION.value: 1}) + return Counter({PlanNodeCategory.FUNCTION.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -168,7 +168,7 @@ def cumulative_complexity_stat(self) -> Counter[str]: if self.partition_spec else Counter() ) - estimate += Counter({ComplexityStat.COLUMN.value: len(self.operators)}) + estimate += Counter({PlanNodeCategory.COLUMN.value: len(self.operators)}) return estimate @@ -203,9 +203,9 @@ def individual_complexity_stat(self) -> Counter[str]: return ( Counter( { - ComplexityStat.COLUMN.value: len(self.left_cols) + PlanNodeCategory.COLUMN.value: len(self.left_cols) + len(self.right_cols), - ComplexityStat.JOIN.value: 1, + PlanNodeCategory.JOIN.value: 1, } ) + self.table_function.cumulative_complexity_stat @@ -224,6 +224,6 @@ def __init__( def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM (child), LATERAL table_func_expression return ( - Counter({ComplexityStat.COLUMN.value: 1}) + Counter({PlanNodeCategory.COLUMN.value: 1}) + self.table_function.cumulative_complexity_stat ) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index 6e8111e6faf..683687c013d 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -7,7 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.materialization_utils import ( - ComplexityStat, + PlanNodeCategory, Counter, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import ( @@ -23,7 +23,7 @@ def __init__(self, 
condition: Optional[Expression]) -> None: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + return Counter({PlanNodeCategory.OTHERS.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index a5c05c4c05e..51656310da8 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -10,7 +10,7 @@ derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.materialization_utils import ( - ComplexityStat, + PlanNodeCategory, Counter, ) from snowflake.snowpark.types import DataType @@ -39,7 +39,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + return Counter({PlanNodeCategory.OTHERS.value: 1}) class Cast(UnaryExpression): @@ -91,7 +91,7 @@ def __str__(self): @property def individual_complexity_stat(self) -> Counter[str]: # child AS name - return Counter({ComplexityStat.COLUMN.value: 1}) + return Counter({PlanNodeCategory.COLUMN.value: 1}) class UnresolvedAlias(UnaryExpression, NamedExpression): diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 86350229853..91e72a968bf 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -10,7 +10,7 @@ ScalarSubquery, ) from snowflake.snowpark._internal.analyzer.materialization_utils import ( - ComplexityStat, + PlanNodeCategory, Counter, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import LogicalPlan @@ -43,9 +43,9 @@ def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM (child) SAMPLE (row_count ROWS) -- if not probability but row count is provided return Counter( { - ComplexityStat.SAMPLE.value: 1, - ComplexityStat.LITERAL.value: 1, - ComplexityStat.COLUMN.value: 1, + PlanNodeCategory.SAMPLE.value: 1, + PlanNodeCategory.LITERAL.value: 1, + PlanNodeCategory.COLUMN.value: 1, } ) @@ -58,7 +58,7 @@ def __init__(self, order: List[SortOrder], child: LogicalPlan) -> None: @property def individual_complexity_stat(self) -> Counter[str]: # child ORDER BY COMMA.join(order) - return Counter({ComplexityStat.ORDER_BY.value: 1}) + sum( + return Counter({PlanNodeCategory.ORDER_BY.value: 1}) + sum( (col.cumulative_complexity_stat for col in self.order), Counter() ) @@ -79,20 +79,20 @@ def individual_complexity_stat(self) -> Counter[str]: estimate = Counter() if self.grouping_expressions: # GROUP BY grouping_exprs - estimate += Counter({ComplexityStat.GROUP_BY.value: 1}) + sum( + estimate += Counter({PlanNodeCategory.GROUP_BY.value: 1}) + sum( (expr.cumulative_complexity_stat for expr in self.grouping_expressions), Counter(), ) else: # LIMIT 1 - estimate += Counter({ComplexityStat.LOW_IMPACT.value: 1}) + estimate += Counter({PlanNodeCategory.OTHERS.value: 1}) estimate += sum( ( getattr( expr, "cumulative_complexity_stat", - Counter({ComplexityStat.COLUMN.value: 1}), + Counter({PlanNodeCategory.COLUMN.value: 1}), ) # type: ignore for expr in self.aggregate_expressions ), @@ -139,8 +139,8 @@ def individual_complexity_stat(self) -> Counter[str]: (val.cumulative_complexity_stat for val in self.pivot_values), 
Counter() ) else: - # if pivot values is None, then we add LOW_IMPACT for ANY - estimate += Counter({ComplexityStat.LOW_IMPACT.value: 1}) + # if pivot values is None, then we add OTHERS for ANY + estimate += Counter({PlanNodeCategory.OTHERS.value: 1}) # aggregate estimate estimate += sum( @@ -149,7 +149,7 @@ def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM (child) PIVOT (aggregate FOR pivot_col in values) estimate += Counter( - {ComplexityStat.COLUMN.value: 2, ComplexityStat.PIVOT.value: 1} + {PlanNodeCategory.COLUMN.value: 2, PlanNodeCategory.PIVOT.value: 1} ) return estimate @@ -171,7 +171,7 @@ def __init__( def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM (child) UNPIVOT (value_column FOR name_column IN (COMMA.join(column_list))) estimate = Counter( - {ComplexityStat.UNPIVOT.value: 1, ComplexityStat.COLUMN.value: 3} + {PlanNodeCategory.UNPIVOT.value: 1, PlanNodeCategory.COLUMN.value: 3} ) estimate += sum( (expr.cumulative_complexity_stat for expr in self.column_list), Counter() @@ -193,8 +193,8 @@ def individual_complexity_stat(self) -> Counter[str]: # SELECT * RENAME (before AS after, ...) FROM child return Counter( { - ComplexityStat.COLUMN.value: 1 + 2 * len(self.column_map), - ComplexityStat.LOW_IMPACT.value: 1 + len(self.column_map), + PlanNodeCategory.COLUMN.value: 1 + 2 * len(self.column_map), + PlanNodeCategory.OTHERS.value: 1 + len(self.column_map), } ) @@ -208,7 +208,7 @@ def __init__(self, condition: Expression, child: LogicalPlan) -> None: def individual_complexity_stat(self) -> Counter[str]: # child WHERE condition return ( - Counter({ComplexityStat.FILTER.value: 1}) + Counter({PlanNodeCategory.FILTER.value: 1}) + self.condition.cumulative_complexity_stat ) @@ -221,14 +221,14 @@ def __init__(self, project_list: List[NamedExpression], child: LogicalPlan) -> N @property def individual_complexity_stat(self) -> Counter[str]: if not self.project_list: - return Counter({ComplexityStat.COLUMN.value: 1}) + return Counter({PlanNodeCategory.COLUMN.value: 1}) return sum( ( getattr( col, "cumulative_complexity_stat", - Counter({ComplexityStat.COLUMN.value: 1}), + Counter({PlanNodeCategory.COLUMN.value: 1}), ) # type: ignore for col in self.project_list ), diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 82677eec59f..140f1a1519d 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -10,7 +10,7 @@ derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.materialization_utils import ( - ComplexityStat, + PlanNodeCategory, Counter, ) from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -24,7 +24,7 @@ def __init__(self) -> None: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + return Counter({PlanNodeCategory.OTHERS.value: 1}) class UnboundedPreceding(SpecialFrameBoundary): @@ -74,7 +74,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.LOW_IMPACT.value: 1}) + return Counter({PlanNodeCategory.OTHERS.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -107,12 +107,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def individual_complexity_stat(self) -> Counter[str]: estimate = Counter() 
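One detail from the Aggregate and Project hunks above is worth isolating before the window-spec hunk continues: the lambda-to-getattr rewrite keeps a fallback of one column for members of aggregate_expressions or project_list that do not expose cumulative_complexity_stat at all, such as a bare named expression. A hedged sketch of the idiom, with both classes invented for the example:

from collections import Counter

class PricedExpression:
    # stands in for an Expression that computes its own subtree counter
    cumulative_complexity_stat = Counter({"column": 1, "function": 1})

class BareNamedExpression:
    # stands in for a NamedExpression that carries no complexity property
    pass

project_list = [PricedExpression(), BareNamedExpression()]
estimate = sum(
    (
        # missing attribute -> price the member as a single column
        getattr(e, "cumulative_complexity_stat", Counter({"column": 1}))
        for e in project_list
    ),
    Counter(),
)
assert estimate == Counter({"column": 2, "function": 1})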
estimate += ( - Counter({ComplexityStat.PARTITION_BY.value: 1}) + Counter({PlanNodeCategory.PARTITION_BY.value: 1}) if self.partition_spec else Counter() ) estimate += ( - Counter({ComplexityStat.ORDER_BY.value: 1}) + Counter({PlanNodeCategory.ORDER_BY.value: 1}) if self.order_spec else Counter() ) @@ -147,7 +147,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({ComplexityStat.WINDOW.value: 1}) + return Counter({PlanNodeCategory.WINDOW.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -181,14 +181,14 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: # for func_name - estimate = Counter({ComplexityStat.FUNCTION.value: 1}) + estimate = Counter({PlanNodeCategory.FUNCTION.value: 1}) # for offset estimate += ( - Counter({ComplexityStat.LITERAL.value: 1}) if self.offset else Counter() + Counter({PlanNodeCategory.LITERAL.value: 1}) if self.offset else Counter() ) # for ignore nulls estimate += ( - Counter({ComplexityStat.LOW_IMPACT.value: 1}) + Counter({PlanNodeCategory.OTHERS.value: 1}) if self.ignore_nulls else Counter() ) diff --git a/tests/integ/test_materialization_suite.py b/tests/integ/test_materialization_suite.py index ed7c0846778..75bb02ef0c0 100644 --- a/tests/integ/test_materialization_suite.py +++ b/tests/integ/test_materialization_suite.py @@ -6,7 +6,7 @@ import pytest from snowflake.snowpark._internal.analyzer.materialization_utils import ( - ComplexityStat, + PlanNodeCategory, Counter, ) from snowflake.snowpark._internal.analyzer.select_statement import ( @@ -65,20 +65,20 @@ def test_create_dataframe_from_values(session: Session): df1 = session.create_dataframe([[1], [2], [3]], schema=["a"]) # SELECT "A" FROM ( SELECT $1 AS "A" FROM VALUES (1 :: INT), (2 :: INT), (3 :: INT)) assert_df_subtree_query_complexity( - df1, {ComplexityStat.LITERAL.value: 3, ComplexityStat.COLUMN.value: 2} + df1, {PlanNodeCategory.LITERAL.value: 3, PlanNodeCategory.COLUMN.value: 2} ) df2 = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", "b"]) # SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT), (5 :: INT, 6 :: INT)) assert_df_subtree_query_complexity( - df2, {ComplexityStat.LITERAL.value: 6, ComplexityStat.COLUMN.value: 4} + df2, {PlanNodeCategory.LITERAL.value: 6, PlanNodeCategory.COLUMN.value: 4} ) def test_session_table(session: Session, sample_table: str): df = session.table(sample_table) # select * from sample_table - assert_df_subtree_query_complexity(df, {ComplexityStat.COLUMN.value: 1}) + assert_df_subtree_query_complexity(df, {PlanNodeCategory.COLUMN.value: 1}) def test_range_statement(session: Session): @@ -87,11 +87,11 @@ def test_range_statement(session: Session): assert_df_subtree_query_complexity( df, { - ComplexityStat.COLUMN.value: 1, - ComplexityStat.LITERAL.value: 3, - ComplexityStat.LOW_IMPACT.value: 2, - ComplexityStat.ORDER_BY.value: 1, - ComplexityStat.WINDOW.value: 1, + PlanNodeCategory.COLUMN.value: 1, + PlanNodeCategory.LITERAL.value: 3, + PlanNodeCategory.OTHERS.value: 2, + PlanNodeCategory.ORDER_BY.value: 1, + PlanNodeCategory.WINDOW.value: 1, }, ) @@ -103,9 +103,9 @@ def test_generator_table_function(session: Session): assert_df_subtree_query_complexity( df1, { - ComplexityStat.COLUMN.value: 2, - ComplexityStat.FUNCTION.value: 1, - ComplexityStat.LITERAL.value: 1, + 
PlanNodeCategory.COLUMN.value: 2, + PlanNodeCategory.FUNCTION.value: 1, + PlanNodeCategory.LITERAL.value: 1, }, ) @@ -114,7 +114,7 @@ def test_generator_table_function(session: Session): assert_df_subtree_query_complexity( df2, get_cumulative_complexity_stat(df1) - + Counter({ComplexityStat.ORDER_BY.value: 1, ComplexityStat.COLUMN.value: 1}), + + Counter({PlanNodeCategory.ORDER_BY.value: 1, PlanNodeCategory.COLUMN.value: 1}), ) @@ -123,7 +123,7 @@ def test_join_table_function(session: Session): "select 'James' as name, 'address1 address2 address3' as addresses" ) # SelectSQL chooses num active columns as the best estimate - assert_df_subtree_query_complexity(df1, {ComplexityStat.COLUMN.value: 2}) + assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN.value: 2}) split_to_table = table_function("split_to_table") df2 = df1.select(split_to_table(col("addresses"), lit(" "))) @@ -134,10 +134,10 @@ def test_join_table_function(session: Session): assert_df_subtree_query_complexity( df2, { - ComplexityStat.COLUMN.value: 9, - ComplexityStat.JOIN.value: 1, - ComplexityStat.FUNCTION.value: 1, - ComplexityStat.LITERAL.value: 1, + PlanNodeCategory.COLUMN.value: 9, + PlanNodeCategory.JOIN.value: 1, + PlanNodeCategory.FUNCTION.value: 1, + PlanNodeCategory.LITERAL.value: 1, }, ) @@ -159,7 +159,7 @@ def test_set_operators(session: Session, sample_table: str, set_operator: str): # ( SELECT * FROM SNOWPARK_TEMP_TABLE_9DJO2Y35IT) set_operator ( SELECT * FROM SNOWPARK_TEMP_TABLE_9DJO2Y35IT) assert_df_subtree_query_complexity( - df, {ComplexityStat.COLUMN.value: 2, ComplexityStat.SET_OPERATION.value: 1} + df, {PlanNodeCategory.COLUMN.value: 2, PlanNodeCategory.SET_OPERATION.value: 1} ) @@ -174,38 +174,38 @@ def test_agg(session: Session, sample_table: str): assert_df_subtree_query_complexity( df1, { - ComplexityStat.COLUMN.value: 3, - ComplexityStat.LOW_IMPACT.value: 1, - ComplexityStat.FUNCTION.value: 1, + PlanNodeCategory.COLUMN.value: 3, + PlanNodeCategory.OTHERS.value: 1, + PlanNodeCategory.FUNCTION.value: 1, }, ) # SELECT (avg("A") + 1 :: INT) AS "ADD(AVG(A), LITERAL())" FROM ( SELECT * FROM sample_table) LIMIT 1 assert_df_subtree_query_complexity( df2, { - ComplexityStat.COLUMN.value: 3, - ComplexityStat.LOW_IMPACT.value: 2, - ComplexityStat.FUNCTION.value: 1, - ComplexityStat.LITERAL.value: 1, + PlanNodeCategory.COLUMN.value: 3, + PlanNodeCategory.OTHERS.value: 2, + PlanNodeCategory.FUNCTION.value: 1, + PlanNodeCategory.LITERAL.value: 1, }, ) # SELECT avg("A") AS "AVG(A)", avg(("B" + 1 :: INT)) AS "AVG_B" FROM ( SELECT * FROM sample_table) LIMIT 1 assert_df_subtree_query_complexity( df3, { - ComplexityStat.COLUMN.value: 5, - ComplexityStat.LOW_IMPACT.value: 2, - ComplexityStat.FUNCTION.value: 2, - ComplexityStat.LITERAL.value: 1, + PlanNodeCategory.COLUMN.value: 5, + PlanNodeCategory.OTHERS.value: 2, + PlanNodeCategory.FUNCTION.value: 2, + PlanNodeCategory.LITERAL.value: 1, }, ) # SELECT "A", "B", avg("C") AS "AVG(C)" FROM ( SELECT * FROM SNOWPARK_TEMP_TABLE_EV1NO4AID6) GROUP BY "A", "B" assert_df_subtree_query_complexity( df4, { - ComplexityStat.COLUMN.value: 7, - ComplexityStat.GROUP_BY.value: 1, - ComplexityStat.FUNCTION.value: 1, + PlanNodeCategory.COLUMN.value: 7, + PlanNodeCategory.GROUP_BY.value: 1, + PlanNodeCategory.FUNCTION.value: 1, }, ) @@ -232,13 +232,13 @@ def test_window_function(session: Session): df1, Counter( { - ComplexityStat.PARTITION_BY.value: 1, - ComplexityStat.ORDER_BY.value: 1, - ComplexityStat.WINDOW.value: 1, - ComplexityStat.FUNCTION.value: 1, - 
ComplexityStat.COLUMN.value: 5, - ComplexityStat.LITERAL.value: 1, - ComplexityStat.LOW_IMPACT.value: 2, + PlanNodeCategory.PARTITION_BY.value: 1, + PlanNodeCategory.ORDER_BY.value: 1, + PlanNodeCategory.WINDOW.value: 1, + PlanNodeCategory.FUNCTION.value: 1, + PlanNodeCategory.COLUMN.value: 5, + PlanNodeCategory.LITERAL.value: 1, + PlanNodeCategory.OTHERS.value: 2, } ), ) @@ -251,11 +251,11 @@ def test_window_function(session: Session): get_cumulative_complexity_stat(df1) + Counter( { - ComplexityStat.ORDER_BY.value: 1, - ComplexityStat.WINDOW.value: 1, - ComplexityStat.FUNCTION.value: 1, - ComplexityStat.COLUMN.value: 3, - ComplexityStat.LOW_IMPACT.value: 3, + PlanNodeCategory.ORDER_BY.value: 1, + PlanNodeCategory.WINDOW.value: 1, + PlanNodeCategory.FUNCTION.value: 1, + PlanNodeCategory.COLUMN.value: 3, + PlanNodeCategory.OTHERS.value: 3, } ), ) @@ -266,11 +266,11 @@ def test_window_function(session: Session): def test_join_statement(session: Session, sample_table: str): # SELECT * FROM table df1 = session.table(sample_table) - assert_df_subtree_query_complexity(df1, {ComplexityStat.COLUMN.value: 1}) + assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN.value: 1}) # SELECT A, B, E FROM (SELECT $1 AS "A", $2 AS "B", $3 AS "E" FROM VALUES (1 :: INT, 2 :: INT, 5 :: INT), (3 :: INT, 4 :: INT, 9 :: INT)) df2 = session.create_dataframe([[1, 2, 5], [3, 4, 9]], schema=["a", "b", "e"]) assert_df_subtree_query_complexity( - df2, {ComplexityStat.COLUMN.value: 6, ComplexityStat.LITERAL.value: 6} + df2, {PlanNodeCategory.COLUMN.value: 6, PlanNodeCategory.LITERAL.value: 6} ) df3 = df1.join(df2) @@ -281,9 +281,9 @@ def test_join_statement(session: Session, sample_table: str): assert_df_subtree_query_complexity( df3, { - ComplexityStat.COLUMN.value: 18, - ComplexityStat.LITERAL.value: 6, - ComplexityStat.JOIN.value: 1, + PlanNodeCategory.COLUMN.value: 18, + PlanNodeCategory.LITERAL.value: 6, + PlanNodeCategory.JOIN.value: 1, }, ) @@ -292,14 +292,14 @@ def test_join_statement(session: Session, sample_table: str): assert_df_subtree_query_complexity( df4, get_cumulative_complexity_stat(df3) - + Counter({ComplexityStat.COLUMN.value: 4, ComplexityStat.LOW_IMPACT.value: 3}), + + Counter({PlanNodeCategory.COLUMN.value: 4, PlanNodeCategory.OTHERS.value: 3}), ) df5 = df1.join(df2, using_columns=["a", "b"]) # SELECT * FROM ( (ch1) AS SNOWPARK_LEFT INNER JOIN (ch2) AS SNOWPARK_RIGHT USING (a, b)) assert_df_subtree_query_complexity( df5, - get_cumulative_complexity_stat(df3) + Counter({ComplexityStat.COLUMN.value: 2}), + get_cumulative_complexity_stat(df3) + Counter({PlanNodeCategory.COLUMN.value: 2}), ) @@ -324,10 +324,10 @@ def test_pivot(session: Session): assert_df_subtree_query_complexity( df_pivot1, { - ComplexityStat.PIVOT.value: 1, - ComplexityStat.COLUMN.value: 4, - ComplexityStat.LITERAL.value: 2, - ComplexityStat.FUNCTION.value: 1, + PlanNodeCategory.PIVOT.value: 1, + PlanNodeCategory.COLUMN.value: 4, + PlanNodeCategory.LITERAL.value: 2, + PlanNodeCategory.FUNCTION.value: 1, }, ) @@ -340,10 +340,10 @@ def test_pivot(session: Session): assert_df_subtree_query_complexity( df_pivot2, { - ComplexityStat.PIVOT.value: 1, - ComplexityStat.COLUMN.value: 4, - ComplexityStat.LITERAL.value: 3, - ComplexityStat.FUNCTION.value: 1, + PlanNodeCategory.PIVOT.value: 1, + PlanNodeCategory.COLUMN.value: 4, + PlanNodeCategory.LITERAL.value: 3, + PlanNodeCategory.FUNCTION.value: 1, }, ) finally: @@ -365,7 +365,7 @@ def test_unpivot(session: Session): # SELECT * FROM ( SELECT * FROM (sales_for_month)) UNPIVOT 
(sales FOR month IN ("JAN", "FEB")) assert_df_subtree_query_complexity( df_unpivot1, - {ComplexityStat.UNPIVOT.value: 1, ComplexityStat.COLUMN.value: 6}, + {PlanNodeCategory.UNPIVOT.value: 1, PlanNodeCategory.COLUMN.value: 6}, ) finally: Utils.drop_table(session, "sales_for_month") @@ -378,9 +378,9 @@ def test_sample(session: Session, sample_table): assert_df_subtree_query_complexity( df_sample_frac, { - ComplexityStat.SAMPLE.value: 1, - ComplexityStat.LITERAL.value: 1, - ComplexityStat.COLUMN.value: 2, + PlanNodeCategory.SAMPLE.value: 1, + PlanNodeCategory.LITERAL.value: 1, + PlanNodeCategory.COLUMN.value: 2, }, ) @@ -389,9 +389,9 @@ def test_sample(session: Session, sample_table): assert_df_subtree_query_complexity( df_sample_rows, { - ComplexityStat.SAMPLE.value: 1, - ComplexityStat.LITERAL.value: 1, - ComplexityStat.COLUMN.value: 2, + PlanNodeCategory.SAMPLE.value: 1, + PlanNodeCategory.LITERAL.value: 1, + PlanNodeCategory.COLUMN.value: 2, }, ) @@ -405,11 +405,11 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl # from select * from sample_table which is flattened out. This is a known limitation but is okay # since we are not off my much df2 = df1.select("a", "b", "c") - assert_df_subtree_query_complexity(df2, {ComplexityStat.COLUMN.value: 4}) + assert_df_subtree_query_complexity(df2, {PlanNodeCategory.COLUMN.value: 4}) # 1 less active column df3 = df2.select("b", "c") - assert_df_subtree_query_complexity(df3, {ComplexityStat.COLUMN.value: 3}) + assert_df_subtree_query_complexity(df3, {PlanNodeCategory.COLUMN.value: 3}) # add sort # for additional ORDER BY "B" ASC NULLS FIRST @@ -417,14 +417,14 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl assert_df_subtree_query_complexity( df4, get_cumulative_complexity_stat(df3) - + Counter({ComplexityStat.COLUMN.value: 1, ComplexityStat.ORDER_BY.value: 1}), + + Counter({PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.ORDER_BY.value: 1}), ) # for additional ,"C" ASC NULLS FIRST df5 = df4.sort(col("c").desc()) assert_df_subtree_query_complexity( df5, - get_cumulative_complexity_stat(df4) + Counter({ComplexityStat.COLUMN.value: 1}), + get_cumulative_complexity_stat(df4) + Counter({PlanNodeCategory.COLUMN.value: 1}), ) # add filter @@ -435,10 +435,10 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl get_cumulative_complexity_stat(df5) + Counter( { - ComplexityStat.FILTER.value: 1, - ComplexityStat.COLUMN.value: 1, - ComplexityStat.LITERAL.value: 1, - ComplexityStat.LOW_IMPACT.value: 1, + PlanNodeCategory.FILTER.value: 1, + PlanNodeCategory.COLUMN.value: 1, + PlanNodeCategory.LITERAL.value: 1, + PlanNodeCategory.OTHERS.value: 1, } ), ) @@ -450,9 +450,9 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl get_cumulative_complexity_stat(df6) + Counter( { - ComplexityStat.COLUMN.value: 1, - ComplexityStat.LITERAL.value: 1, - ComplexityStat.LOW_IMPACT.value: 2, + PlanNodeCategory.COLUMN.value: 1, + PlanNodeCategory.LITERAL.value: 1, + PlanNodeCategory.OTHERS.value: 2, } ), ) @@ -463,7 +463,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df8, sum( (get_cumulative_complexity_stat(df) for df in [df3, df4, df5]), - Counter({ComplexityStat.SET_OPERATION.value: 2}), + Counter({PlanNodeCategory.SET_OPERATION.value: 2}), ), ) @@ -473,7 +473,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df9, sum( (get_cumulative_complexity_stat(df) for df in [df6, df7, 
df8]), - Counter({ComplexityStat.SET_OPERATION.value: 2}), + Counter({PlanNodeCategory.SET_OPERATION.value: 2}), ), ) @@ -482,7 +482,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl assert_df_subtree_query_complexity( df10, get_cumulative_complexity_stat(df9) - + Counter({ComplexityStat.LOW_IMPACT.value: 1}), + + Counter({PlanNodeCategory.OTHERS.value: 1}), ) # for offset @@ -490,5 +490,5 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl assert_df_subtree_query_complexity( df11, get_cumulative_complexity_stat(df9) - + Counter({ComplexityStat.LOW_IMPACT.value: 2}), + + Counter({PlanNodeCategory.OTHERS.value: 2}), ) From ced5d78bfb5a97094c3ae0a7586c02f57375ac84 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 5 Jun 2024 09:55:07 -0700 Subject: [PATCH 21/37] rename file --- src/snowflake/snowpark/_internal/analyzer/binary_expression.py | 2 +- src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py | 2 +- src/snowflake/snowpark/_internal/analyzer/expression.py | 2 +- src/snowflake/snowpark/_internal/analyzer/grouping_set.py | 2 +- .../{materialization_utils.py => query_plan_analysis_utils.py} | 0 src/snowflake/snowpark/_internal/analyzer/select_statement.py | 2 +- src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py | 2 +- .../snowpark/_internal/analyzer/snowflake_plan_node.py | 2 +- src/snowflake/snowpark/_internal/analyzer/sort_expression.py | 2 +- src/snowflake/snowpark/_internal/analyzer/table_function.py | 2 +- .../snowpark/_internal/analyzer/table_merge_expression.py | 2 +- src/snowflake/snowpark/_internal/analyzer/unary_expression.py | 2 +- src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py | 2 +- src/snowflake/snowpark/_internal/analyzer/window_expression.py | 2 +- tests/integ/test_materialization_suite.py | 2 +- 15 files changed, 14 insertions(+), 14 deletions(-) rename src/snowflake/snowpark/_internal/analyzer/{materialization_utils.py => query_plan_analysis_utils.py} (100%) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 0aefb87c73b..189eaad6655 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -8,7 +8,7 @@ Expression, derive_dependent_columns, ) -from snowflake.snowpark._internal.analyzer.materialization_utils import ( +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, Counter, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index 5967d6c59e3..e187900d6ea 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -5,7 +5,7 @@ from typing import List, Optional from snowflake.snowpark._internal.analyzer.expression import Expression -from snowflake.snowpark._internal.analyzer.materialization_utils import ( +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, Counter, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index c50c91ea812..66945077459 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, AbstractSet, Any, List, 
Optional, Tuple import snowflake.snowpark._internal.utils -from snowflake.snowpark._internal.analyzer.materialization_utils import ( +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, Counter, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 6de4ba4fc46..22ee756b74e 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -9,7 +9,7 @@ Expression, derive_dependent_columns, ) -from snowflake.snowpark._internal.analyzer.materialization_utils import ( +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, Counter, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/materialization_utils.py b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py similarity index 100% rename from src/snowflake/snowpark/_internal/analyzer/materialization_utils.py rename to src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 95f048339d4..67220574a61 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -22,7 +22,7 @@ import snowflake.snowpark._internal.utils from snowflake.snowpark._internal.analyzer.cte_utils import encode_id -from snowflake.snowpark._internal.analyzer.materialization_utils import ( +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, Counter, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 5825f0ba7c0..9ad31e32b27 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -21,7 +21,7 @@ Union, ) -from snowflake.snowpark._internal.analyzer.materialization_utils import Counter +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import Counter from snowflake.snowpark._internal.analyzer.table_function import ( GeneratorTableFunction, TableFunctionRelation, diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index c220bf973c3..11bc722f065 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import Attribute, Expression -from snowflake.snowpark._internal.analyzer.materialization_utils import ( +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, Counter, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index bfbe4fc64ad..e766d1ed027 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -9,7 +9,7 @@ Expression, derive_dependent_columns, ) -from snowflake.snowpark._internal.analyzer.materialization_utils import Counter +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import Counter class NullOrdering: 
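For orientation, the rename above is purely mechanical: the module contents are unchanged and only the import path moves. A minimal sketch of what a caller looks like after this patch (the `stat` variable is a hypothetical illustration, not part of the diff):

    from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
        Counter,
        PlanNodeCategory,
    )

    # same utilities as before the rename, e.g. tallying plan node categories
    stat = Counter({PlanNodeCategory.COLUMN.value: 2})
    stat += Counter({PlanNodeCategory.FILTER.value: 1})
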
diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index 415c5d4d477..d55ebf5ed51 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -7,7 +7,7 @@ from typing import Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import Expression -from snowflake.snowpark._internal.analyzer.materialization_utils import ( +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, Counter, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index 683687c013d..26e377e0c0a 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import Expression -from snowflake.snowpark._internal.analyzer.materialization_utils import ( +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, Counter, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index 51656310da8..00948f9e302 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -9,7 +9,7 @@ NamedExpression, derive_dependent_columns, ) -from snowflake.snowpark._internal.analyzer.materialization_utils import ( +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, Counter, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 91e72a968bf..fedb03d9c6f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -9,7 +9,7 @@ NamedExpression, ScalarSubquery, ) -from snowflake.snowpark._internal.analyzer.materialization_utils import ( +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, Counter, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 140f1a1519d..e62e7776ad6 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -9,7 +9,7 @@ Expression, derive_dependent_columns, ) -from snowflake.snowpark._internal.analyzer.materialization_utils import ( +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, Counter, ) diff --git a/tests/integ/test_materialization_suite.py b/tests/integ/test_materialization_suite.py index 75bb02ef0c0..766f213170d 100644 --- a/tests/integ/test_materialization_suite.py +++ b/tests/integ/test_materialization_suite.py @@ -5,7 +5,7 @@ import pytest -from snowflake.snowpark._internal.analyzer.materialization_utils import ( +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, Counter, ) From 995f107d670873e3e18e170a1b2e0aa55396608c Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 5 Jun 2024 11:32:02 -0700 Subject: [PATCH 22/37] refactor --- 
.../_internal/analyzer/binary_expression.py | 4 +- .../_internal/analyzer/binary_plan_node.py | 14 ++-- .../snowpark/_internal/analyzer/expression.py | 75 ++++++++++--------- .../_internal/analyzer/grouping_set.py | 8 +- .../analyzer/query_plan_analysis_utils.py | 3 +- .../_internal/analyzer/select_statement.py | 30 ++++---- .../_internal/analyzer/snowflake_plan.py | 6 +- .../_internal/analyzer/snowflake_plan_node.py | 28 +++---- .../_internal/analyzer/table_function.py | 41 ++++------ .../analyzer/table_merge_expression.py | 34 ++++----- .../_internal/analyzer/unary_expression.py | 12 +-- .../_internal/analyzer/unary_plan_node.py | 42 +++++------ .../_internal/analyzer/window_expression.py | 28 +++---- src/snowflake/snowpark/_internal/telemetry.py | 4 +- ...n_suite.py => test_query_plan_analysis.py} | 42 +++++++---- tests/integ/test_telemetry.py | 14 ++-- 16 files changed, 197 insertions(+), 188 deletions(-) rename tests/integ/{test_materialization_suite.py => test_query_plan_analysis.py} (92%) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 189eaad6655..41bf92633f9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -31,8 +31,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.left, self.right) @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.OTHERS.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT class BinaryArithmeticExpression(BinaryExpression): diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index e187900d6ea..b49410dbd66 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -74,9 +74,9 @@ def __init__(self, left: LogicalPlan, right: LogicalPlan) -> None: class SetOperation(BinaryNode): @property - def individual_complexity_stat(self) -> Counter[str]: + def plan_node_category(self) -> Counter[str]: # (left) operator (right) - return Counter({PlanNodeCategory.SET_OPERATION.value: 1}) + return PlanNodeCategory.SET_OPERATION class Except(SetOperation): @@ -198,19 +198,19 @@ def sql(self) -> str: @property def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM (left) AS left_alias join_type_sql JOIN (right) AS right_alias match_cond, using_cond, join_cond - estimate = Counter({PlanNodeCategory.JOIN.value: 1}) + stat = Counter({PlanNodeCategory.JOIN.value: 1}) if isinstance(self.join_type, UsingJoin) and self.join_type.using_columns: - estimate += Counter( + stat += Counter( {PlanNodeCategory.COLUMN.value: len(self.join_type.using_columns)} ) - estimate += ( + stat += ( self.join_condition.cumulative_complexity_stat if self.join_condition else Counter() ) - estimate += ( + stat += ( self.match_condition.cumulative_complexity_stat if self.match_condition else Counter() ) - return estimate + return stat diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 66945077459..6f32f43f73e 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -85,12 +85,16 @@ def sql(self) -> str: ) return 
f"{self.pretty_name}({children_sql})" + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.OTHERS + @property def individual_complexity_stat(self) -> Counter[str]: """Returns the individual contribution of the expression node towards the overall compilation complexity of the generated sql. """ - return Counter() + return Counter({self.plan_node_category.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -133,7 +137,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @cached_property def cumulative_complexity_stat(self) -> Counter[str]: - return self.plan.cumulative_complexity_stat + self.individual_complexity_stat + return self.plan.cumulative_complexity_stat class MultipleExpression(Expression): @@ -151,7 +155,6 @@ def cumulative_complexity_stat(self) -> Counter[str]: (expr.cumulative_complexity_stat for expr in self.expressions), Counter(), ) - + self.individual_complexity_stat ) @@ -165,8 +168,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.columns, *self.values) @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.IN.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.IN @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -208,8 +211,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return {self.name} @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.COLUMN.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.COLUMN class Star(Expression): @@ -272,8 +275,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return self._dependent_column_names @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.COLUMN.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.COLUMN class Literal(Expression): @@ -299,8 +302,8 @@ def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None: self.datatype = infer_type(value) @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.LITERAL.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LITERAL class Interval(Expression): @@ -351,8 +354,8 @@ def __str__(self) -> str: return self.sql @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.OTHERS.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LITERAL class Like(Expression): @@ -365,9 +368,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) @property - def individual_complexity_stat(self) -> Counter[str]: + def plan_node_category(self) -> PlanNodeCategory: # expr LIKE pattern - return Counter({PlanNodeCategory.OTHERS.value: 1}) + return PlanNodeCategory.LOW_IMPACT @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -388,9 +391,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) @property - def individual_complexity_stat(self) -> Counter[str]: + def plan_node_category(self) -> PlanNodeCategory: # expr REG_EXP pattern - return Counter({PlanNodeCategory.OTHERS.value: 1}) + return PlanNodeCategory.LOW_IMPACT 
@cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -411,9 +414,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) @property - def individual_complexity_stat(self) -> Counter[str]: + def plan_node_category(self) -> PlanNodeCategory: # expr COLLATE collate_spec - return Counter({PlanNodeCategory.OTHERS.value: 1}) + return PlanNodeCategory.LOW_IMPACT @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -430,9 +433,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) @property - def individual_complexity_stat(self) -> Counter[str]: + def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field - return Counter({PlanNodeCategory.LITERAL.value: 1}) + return PlanNodeCategory.LITERAL @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -450,9 +453,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) @property - def individual_complexity_stat(self) -> Counter[str]: + def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field - return Counter({PlanNodeCategory.LITERAL.value: 1}) + return PlanNodeCategory.LITERAL @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -492,8 +495,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.FUNCTION.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.FUNCTION class WithinGroup(Expression): @@ -507,9 +510,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, *self.order_by_cols) @property - def individual_complexity_stat(self) -> Counter[str]: + def plan_node_category(self) -> PlanNodeCategory: # expr WITHIN GROUP (ORDER BY cols) - return Counter({PlanNodeCategory.ORDER_BY.value: 1}) + return PlanNodeCategory.ORDER_BY @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -542,22 +545,22 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*exps) @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.CASE_WHEN.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.CASE_WHEN @cached_property def cumulative_complexity_stat(self) -> Counter[str]: - estimate = self.individual_complexity_stat + sum( + stat = self.individual_complexity_stat + sum( ( condition.cumulative_complexity_stat + value.cumulative_complexity_stat for condition, value in self.branches ), Counter(), ) - estimate += ( + stat += ( self.else_value.cumulative_complexity_stat if self.else_value else Counter() ) - return estimate + return stat class SnowflakeUDF(Expression): @@ -580,8 +583,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.FUNCTION.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.FUNCTION @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -602,8 +605,8 @@ def dependent_column_names(self) -> 
Optional[AbstractSet[str]]:
        return derive_dependent_columns(self.col)

    @property
-    def individual_complexity_stat(self) -> Counter[str]:
-        return Counter({PlanNodeCategory.FUNCTION.value: 1})
+    def plan_node_category(self) -> PlanNodeCategory:
+        return PlanNodeCategory.FUNCTION

    @cached_property
    def cumulative_complexity_stat(self) -> Counter[str]:
diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py
index 22ee756b74e..33afcd3cef2 100644
--- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py
+++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py
@@ -25,8 +25,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
        return derive_dependent_columns(*self.group_by_exprs)

    @property
-    def individual_complexity_stat(self) -> Counter[str]:
-        return Counter({PlanNodeCategory.OTHERS.value: 1})
+    def plan_node_category(self) -> PlanNodeCategory:
+        return PlanNodeCategory.LOW_IMPACT


 class Cube(GroupingSet):
@@ -57,5 +57,5 @@ def cumulative_complexity_stat(self) -> Counter[str]:
        )

    @property
-    def individual_complexity_stat(self) -> Counter[str]:
-        return Counter({PlanNodeCategory.OTHERS.value: 1})
+    def plan_node_category(self) -> PlanNodeCategory:
+        return PlanNodeCategory.LOW_IMPACT
diff --git a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py
index cbd8abe72ae..a341daf5a10 100644
--- a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py
+++ b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py
@@ -23,7 +23,7 @@ class PlanNodeCategory(Enum):
    """This enum class is used to account for different types of sql text
    generated by expressions and logical plan nodes. A bottom up aggregation
    of the number of occurrences of each enum type is
-    done in Expression and LogicalPlan class to calculate and estimate
+    done in Expression and LogicalPlan classes to compute a stat
    of overall query complexity in the context of compiling the generated sql.
    """
@@ -43,4 +43,5 @@ class PlanNodeCategory(Enum):
    COLUMN = "column"  # covers all cases where a table column is referred
    FUNCTION = "function"  # cover all snowflake built-in function, table functions and UDXFs
    IN = "in"
+    LOW_IMPACT = "low_impact"
    OTHERS = "others"
diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
index 67220574a61..3288ae19647 100644
--- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py
+++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -296,7 +296,7 @@ def num_duplicate_nodes(self) -> int:

    @property
    def individual_complexity_stat(self) -> Counter[str]:
-        """This is the query complexity estimate added by this Selectable node
+        """This is the query complexity stat added by this Selectable node
        to the overall query plan. For the default case, it is the number of active
        columns. Specific cases are handled in child classes with additional
        explanation.
@@ -309,14 +309,14 @@ def individual_complexity_stat(self) -> Counter[str]:

    @property
    def cumulative_complexity_stat(self) -> Counter[str]:
-        """This is sum of individual query complexity estimates for all nodes
+        """This is the sum of individual query complexity stats for all nodes
        within a query plan subtree.
""" if self._cumulative_complexity_stat is None: - estimate = self.individual_complexity_stat + stat = self.individual_complexity_stat for node in self.children_plan_nodes: - estimate += node.cumulative_complexity_stat - self._cumulative_complexity_stat = estimate + stat += node.cumulative_complexity_stat + self._cumulative_complexity_stat = stat return self._cumulative_complexity_stat @cumulative_complexity_stat.setter @@ -448,7 +448,7 @@ def individual_complexity_stat(self) -> Counter[str]: # by a SELECT * FROM table(result_scan(query_id)) statement return Counter({PlanNodeCategory.COLUMN.value: 1}) - # no pre-action implies the best estimate we have is of # active columns + # no pre-action implies the best stat we have is of # active columns return Counter( {PlanNodeCategory.COLUMN.value: len(self.column_states.active_columns)} ) @@ -714,9 +714,9 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: @property def individual_complexity_stat(self) -> Counter[str]: - estimate = Counter() + stat = Counter() # projection component - estimate += ( + stat += ( sum( (expr.cumulative_complexity_stat for expr in self.projection), Counter() ) @@ -725,7 +725,7 @@ def individual_complexity_stat(self) -> Counter[str]: ) # filter component - add +1 for WHERE clause and sum of expression complexity for where expression - estimate += ( + stat += ( Counter({PlanNodeCategory.FILTER.value: 1}) + self.where.cumulative_complexity_stat if self.where @@ -733,7 +733,7 @@ def individual_complexity_stat(self) -> Counter[str]: ) # order by component - add complexity for each sort expression - estimate += ( + stat += ( sum( (expr.cumulative_complexity_stat for expr in self.order_by), Counter({PlanNodeCategory.ORDER_BY.value: 1}), @@ -743,13 +743,13 @@ def individual_complexity_stat(self) -> Counter[str]: ) # limit/offset component - estimate += ( - Counter({PlanNodeCategory.OTHERS.value: 1}) if self.limit_ else Counter() + stat += ( + Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) if self.limit_ else Counter() ) - estimate += ( - Counter({PlanNodeCategory.OTHERS.value: 1}) if self.offset else Counter() + stat += ( + Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) if self.offset else Counter() ) - return estimate + return stat def to_subqueryable(self) -> "Selectable": """When this SelectStatement's subquery is not subqueryable (can't be used in `from` clause of the sql), diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 9ad31e32b27..a93f3d68e41 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -361,10 +361,10 @@ def individual_complexity_stat(self) -> Counter[str]: @property def cumulative_complexity_stat(self) -> Counter[str]: if self._cumulative_complexity_stat is None: - estimate = self.individual_complexity_stat + stat = self.individual_complexity_stat for node in self.children_plan_nodes: - estimate += node.cumulative_complexity_stat - self._cumulative_complexity_stat = estimate + stat += node.cumulative_complexity_stat + self._cumulative_complexity_stat = stat return self._cumulative_complexity_stat @cumulative_complexity_stat.setter diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 11bc722f065..c8fc9a23790 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ 
b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -29,12 +29,16 @@ def __init__(self) -> None: self.children = [] self._cumulative_complexity_stat: Optional[Counter[str]] = None + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.OTHERS + @property def individual_complexity_stat(self) -> Counter[str]: """Returns the individual contribution of the logical plan node towards the overall compilation complexity of the generated sql. """ - return Counter() + return Counter({self.plan_node_category.value: 1}) @property def cumulative_complexity_stat(self) -> Counter[str]: @@ -42,11 +46,11 @@ def cumulative_complexity_stat(self) -> Counter[str]: logical plan node. Statistic of current node is included in the final aggregate. """ if self._cumulative_complexity_stat is None: - estimate = self.individual_complexity_stat + stat = self.individual_complexity_stat for node in self.children: - estimate += node.cumulative_complexity_stat + stat += node.cumulative_complexity_stat - self._cumulative_complexity_stat = estimate + self._cumulative_complexity_stat = stat return self._cumulative_complexity_stat @cumulative_complexity_stat.setter @@ -77,7 +81,7 @@ def individual_complexity_stat(self) -> Counter[str]: PlanNodeCategory.ORDER_BY.value: 1, PlanNodeCategory.LITERAL.value: 3, # step, start, count PlanNodeCategory.COLUMN.value: 1, # id column - PlanNodeCategory.OTHERS.value: 2, # ROW_NUMBER, GENERATOR + PlanNodeCategory.LOW_IMPACT.value: 2, # ROW_NUMBER, GENERATOR } ) @@ -148,16 +152,14 @@ def __init__( @property def individual_complexity_stat(self) -> Counter[str]: - # CREATE OR REPLACE table_type TABLE table_name (col definition) clustering_expr AS SELECT * FROM (child) - estimate = Counter( - {PlanNodeCategory.OTHERS.value: 1, PlanNodeCategory.COLUMN.value: 1} - ) - estimate += ( + # CREATE OR REPLACE table_type TABLE table_name (col definition) clustering_expr AS SELECT * FROM (query) + stat = Counter({PlanNodeCategory.COLUMN.value: 1}) + stat += ( Counter({PlanNodeCategory.COLUMN.value: len(self.column_names)}) if self.column_names else Counter() ) - estimate += ( + stat += ( sum( (expr.cumulative_complexity_stat for expr in self.clustering_exprs), Counter(), @@ -165,7 +167,7 @@ def individual_complexity_stat(self) -> Counter[str]: if self.clustering_exprs else Counter() ) - return estimate + return stat class Limit(LogicalPlan): @@ -182,7 +184,7 @@ def __init__( def individual_complexity_stat(self) -> Counter[str]: # for limit and offset return ( - Counter({PlanNodeCategory.OTHERS.value: 2}) + Counter({PlanNodeCategory.LOW_IMPACT.value: 2}) + self.limit_expr.cumulative_complexity_stat + self.offset_expr.cumulative_complexity_stat ) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index d55ebf5ed51..8d5134e5e77 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -39,8 +39,8 @@ def __init__( def cumulative_complexity_stat(self) -> Counter[str]: if not self.over: return Counter() - estimate = Counter({PlanNodeCategory.WINDOW.value: 1}) - estimate += ( + stat = Counter({PlanNodeCategory.WINDOW.value: 1}) + stat += ( sum( (expr.cumulative_complexity_stat for expr in self.partition_spec), Counter({PlanNodeCategory.PARTITION_BY.value: 1}), @@ -48,7 +48,7 @@ def cumulative_complexity_stat(self) -> Counter[str]: if self.partition_spec else Counter() ) - estimate += ( + stat += 
( sum( (expr.cumulative_complexity_stat for expr in self.order_spec), Counter({PlanNodeCategory.ORDER_BY.value: 1}), @@ -56,7 +56,7 @@ def cumulative_complexity_stat(self) -> Counter[str]: if self.order_spec else Counter() ) - return estimate + return stat class TableFunctionExpression(Expression): @@ -74,17 +74,8 @@ def __init__( self.api_call_source = api_call_source @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.FUNCTION.value: 1}) - - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: - return ( - self.partition_spec.cumulative_complexity_stat - + self.individual_complexity_stat - if self.partition_spec - else self.individual_complexity_stat - ) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.FUNCTION class FlattenFunction(TableFunctionExpression): @@ -115,16 +106,16 @@ def __init__( @cached_property def cumulative_complexity_stat(self) -> Counter[str]: - estimate = sum( + stat = sum( (arg.cumulative_complexity_stat for arg in self.args), self.individual_complexity_stat, ) - estimate += ( + stat += ( self.partition_spec.cumulative_complexity_stat if self.partition_spec else Counter() ) - return estimate + return stat class NamedArgumentsTableFunction(TableFunctionExpression): @@ -139,16 +130,16 @@ def __init__( @cached_property def cumulative_complexity_stat(self) -> Counter[str]: - estimate = sum( + stat = sum( (arg.cumulative_complexity_stat for arg in self.args.values()), self.individual_complexity_stat, ) - estimate += ( + stat += ( self.partition_spec.cumulative_complexity_stat if self.partition_spec else Counter() ) - return estimate + return stat class GeneratorTableFunction(TableFunctionExpression): @@ -159,17 +150,17 @@ def __init__(self, args: Dict[str, Expression], operators: List[str]) -> None: @cached_property def cumulative_complexity_stat(self) -> Counter[str]: - estimate = sum( + stat = sum( (arg.cumulative_complexity_stat for arg in self.args.values()), self.individual_complexity_stat, ) - estimate += ( + stat += ( self.partition_spec.cumulative_complexity_stat if self.partition_spec else Counter() ) - estimate += Counter({PlanNodeCategory.COLUMN.value: len(self.operators)}) - return estimate + stat += Counter({PlanNodeCategory.COLUMN.value: len(self.operators)}) + return stat class TableFunctionRelation(LogicalPlan): diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index 26e377e0c0a..c9244483ed4 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -22,17 +22,17 @@ def __init__(self, condition: Optional[Expression]) -> None: self.condition = condition @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.OTHERS.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT @cached_property def cumulative_complexity_stat(self) -> Counter[str]: # WHEN MATCHED [AND condition] THEN DEL - estimate = self.individual_complexity_stat - estimate += ( + stat = self.individual_complexity_stat + stat += ( self.condition.cumulative_complexity_stat if self.condition else Counter() ) - return estimate + return stat class UpdateMergeExpression(MergeExpression): @@ -45,11 +45,11 @@ def __init__( @cached_property def cumulative_complexity_stat(self) -> Counter[str]: # WHEN MATCHED [AND condition] 
THEN UPDATE SET COMMA.join(k=v for k,v in assignments) - estimate = self.individual_complexity_stat - estimate += ( + stat = self.individual_complexity_stat + stat += ( self.condition.cumulative_complexity_stat if self.condition else Counter() ) - estimate += sum( + stat += sum( ( key_expr.cumulative_complexity_stat + val_expr.cumulative_complexity_stat @@ -57,7 +57,7 @@ def cumulative_complexity_stat(self) -> Counter[str]: ), Counter(), ) - return estimate + return stat class DeleteMergeExpression(MergeExpression): @@ -78,17 +78,17 @@ def __init__( @cached_property def cumulative_complexity_stat(self) -> Counter[str]: # WHEN NOT MATCHED [AND cond] THEN INSERT [(COMMA.join(key))] VALUES (COMMA.join(values)) - estimate = self.individual_complexity_stat - estimate += ( + stat = self.individual_complexity_stat + stat += ( self.condition.cumulative_complexity_stat if self.condition else Counter() ) - estimate += sum( + stat += sum( (key.cumulative_complexity_stat for key in self.keys), Counter() ) - estimate += sum( + stat += sum( (val.cumulative_complexity_stat for val in self.values), Counter() ) - return estimate + return stat class TableUpdate(LogicalPlan): @@ -109,17 +109,17 @@ def __init__( @property def individual_complexity_stat(self) -> Counter[str]: # UPDATE table_name SET COMMA.join(k, v in assignments) [source_data] [WHERE condition] - estimate = sum( + stat = sum( ( k.cumulative_complexity_stat + v.cumulative_complexity_stat for k, v in self.assignments.items() ), Counter(), ) - estimate += ( + stat += ( self.condition.cumulative_complexity_stat if self.condition else Counter() ) - return estimate + return stat class TableDelete(LogicalPlan): diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index 00948f9e302..04ccbdc3f0c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -38,8 +38,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.OTHERS.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT class Cast(UnaryExpression): @@ -89,9 +89,9 @@ def __str__(self): return f"{self.child} {self.sql_operator} {self.name}" @property - def individual_complexity_stat(self) -> Counter[str]: + def plan_node_category(self) -> PlanNodeCategory: # child AS name - return Counter({PlanNodeCategory.COLUMN.value: 1}) + return PlanNodeCategory.COLUMN class UnresolvedAlias(UnaryExpression, NamedExpression): @@ -103,5 +103,5 @@ def __init__(self, child: Expression) -> None: self.name = child.sql @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter() + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.OTHERS diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index fedb03d9c6f..3cc5d93e0c9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -76,18 +76,18 @@ def __init__( @property def individual_complexity_stat(self) -> Counter[str]: - estimate = Counter() + stat = Counter() if self.grouping_expressions: # GROUP BY grouping_exprs - estimate += Counter({PlanNodeCategory.GROUP_BY.value: 1}) + 
sum( + stat += Counter({PlanNodeCategory.GROUP_BY.value: 1}) + sum( (expr.cumulative_complexity_stat for expr in self.grouping_expressions), Counter(), ) else: # LIMIT 1 - estimate += Counter({PlanNodeCategory.OTHERS.value: 1}) + stat += Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) - estimate += sum( + stat += sum( ( getattr( expr, @@ -98,7 +98,7 @@ def individual_complexity_stat(self) -> Counter[str]: ), Counter(), ) - return estimate + return stat class Pivot(UnaryNode): @@ -120,38 +120,38 @@ def __init__( @property def individual_complexity_stat(self) -> Counter[str]: - estimate = Counter() - # child estimate adjustment if grouping cols + stat = Counter() + # child stat adjustment if grouping cols if self.grouping_columns and self.aggregates and self.aggregates[0].children: # for additional projecting cols when grouping cols is not empty - estimate += sum( + stat += sum( (col.cumulative_complexity_stat for col in self.grouping_columns), Counter(), ) - estimate += self.pivot_column.cumulative_complexity_stat - estimate += self.aggregates[0].children[0].cumulative_complexity_stat + stat += self.pivot_column.cumulative_complexity_stat + stat += self.aggregates[0].children[0].cumulative_complexity_stat # pivot col if isinstance(self.pivot_values, ScalarSubquery): - estimate += self.pivot_values.cumulative_complexity_stat + stat += self.pivot_values.cumulative_complexity_stat elif isinstance(self.pivot_values, List): - estimate += sum( + stat += sum( (val.cumulative_complexity_stat for val in self.pivot_values), Counter() ) else: # if pivot values is None, then we add OTHERS for ANY - estimate += Counter({PlanNodeCategory.OTHERS.value: 1}) + stat += Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) - # aggregate estimate - estimate += sum( + # aggregate stat + stat += sum( (expr.cumulative_complexity_stat for expr in self.aggregates), Counter() ) # SELECT * FROM (child) PIVOT (aggregate FOR pivot_col in values) - estimate += Counter( + stat += Counter( {PlanNodeCategory.COLUMN.value: 2, PlanNodeCategory.PIVOT.value: 1} ) - return estimate + return stat class Unpivot(UnaryNode): @@ -170,13 +170,13 @@ def __init__( @property def individual_complexity_stat(self) -> Counter[str]: # SELECT * FROM (child) UNPIVOT (value_column FOR name_column IN (COMMA.join(column_list))) - estimate = Counter( + stat = Counter( {PlanNodeCategory.UNPIVOT.value: 1, PlanNodeCategory.COLUMN.value: 3} ) - estimate += sum( + stat += sum( (expr.cumulative_complexity_stat for expr in self.column_list), Counter() ) - return estimate + return stat class Rename(UnaryNode): @@ -194,7 +194,7 @@ def individual_complexity_stat(self) -> Counter[str]: return Counter( { PlanNodeCategory.COLUMN.value: 1 + 2 * len(self.column_map), - PlanNodeCategory.OTHERS.value: 1 + len(self.column_map), + PlanNodeCategory.LOW_IMPACT.value: 1 + len(self.column_map), } ) diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index e62e7776ad6..ff5aec38a5a 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -24,7 +24,7 @@ def __init__(self) -> None: @property def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.OTHERS.value: 1}) + return Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) class UnboundedPreceding(SpecialFrameBoundary): @@ -74,7 +74,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def 
individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.OTHERS.value: 1}) + return Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -105,18 +105,18 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: - estimate = Counter() - estimate += ( + stat = Counter() + stat += ( Counter({PlanNodeCategory.PARTITION_BY.value: 1}) if self.partition_spec else Counter() ) - estimate += ( + stat += ( Counter({PlanNodeCategory.ORDER_BY.value: 1}) if self.order_spec else Counter() ) - return estimate + return stat @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -181,29 +181,29 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @property def individual_complexity_stat(self) -> Counter[str]: # for func_name - estimate = Counter({PlanNodeCategory.FUNCTION.value: 1}) + stat = Counter({PlanNodeCategory.FUNCTION.value: 1}) # for offset - estimate += ( + stat += ( Counter({PlanNodeCategory.LITERAL.value: 1}) if self.offset else Counter() ) # for ignore nulls - estimate += ( - Counter({PlanNodeCategory.OTHERS.value: 1}) + stat += ( + Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) if self.ignore_nulls else Counter() ) - return estimate + return stat @cached_property def cumulative_complexity_stat(self) -> Counter[str]: # func_name (expr [, offset] [, default]) [IGNORE NULLS] - estimate = ( + stat = ( self.individual_complexity_stat + self.expr.cumulative_complexity_stat ) - estimate += ( + stat += ( self.default.cumulative_complexity_stat if self.default else Counter() ) - return estimate + return stat class Lag(RankRelatedFunctionExpression): diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 801af7d5524..48cc53db23f 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -68,7 +68,7 @@ class TelemetryField(Enum): # dataframe query stats QUERY_PLAN_HEIGHT = "query_plan_height" QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes" - QUERY_PLAN_COMPLEXITY_STAT = "query_plan_complexity_stat" + QUERY_PLAN_STAT = "query_plan_stat" # These DataFrame APIs call other DataFrame APIs @@ -161,7 +161,7 @@ def wrap(*args, **kwargs): api_calls[0][ TelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value ] = plan.num_duplicate_nodes - api_calls[0][TelemetryField.QUERY_PLAN_COMPLEXITY_STAT.value] = dict( + api_calls[0][TelemetryField.QUERY_PLAN_STAT.value] = dict( plan.cumulative_complexity_stat ) except Exception: diff --git a/tests/integ/test_materialization_suite.py b/tests/integ/test_query_plan_analysis.py similarity index 92% rename from tests/integ/test_materialization_suite.py rename to tests/integ/test_query_plan_analysis.py index 766f213170d..4ab2d268bad 100644 --- a/tests/integ/test_materialization_suite.py +++ b/tests/integ/test_query_plan_analysis.py @@ -89,7 +89,7 @@ def test_range_statement(session: Session): { PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.LITERAL.value: 3, - PlanNodeCategory.OTHERS.value: 2, + PlanNodeCategory.LOW_IMPACT.value: 2, PlanNodeCategory.ORDER_BY.value: 1, PlanNodeCategory.WINDOW.value: 1, }, @@ -114,7 +114,7 @@ def test_generator_table_function(session: Session): assert_df_subtree_query_complexity( df2, get_cumulative_complexity_stat(df1) - + Counter({PlanNodeCategory.ORDER_BY.value: 1, PlanNodeCategory.COLUMN.value: 1}), + + 
Counter({PlanNodeCategory.ORDER_BY.value: 1, PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.OTHERS.value: 1}), ) @@ -126,11 +126,12 @@ def test_join_table_function(session: Session): assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN.value: 2}) split_to_table = table_function("split_to_table") - df2 = df1.select(split_to_table(col("addresses"), lit(" "))) + # SELECT "SEQ", "INDEX", "VALUE" FROM ( # SELECT T_RIGHT."SEQ", T_RIGHT."INDEX", T_RIGHT."VALUE" FROM # (select 'James' as name, 'address1 address2 address3' as addresses) AS T_LEFT # JOIN TABLE (split_to_table("ADDRESSES", ' ') ) AS T_RIGHT) + df2 = df1.select(split_to_table(col("addresses"), lit(" "))) assert_df_subtree_query_complexity( df2, { @@ -141,6 +142,15 @@ def test_join_table_function(session: Session): }, ) + # SELECT T_LEFT.*, T_RIGHT.* FROM (select 'James' as name, 'address1 address2 address3' as addresses) AS T_LEFT JOIN TABLE (split_to_table("ADDRESS", ' ') ) AS T_RIGHT + df3 = df1.join_table_function(split_to_table(col("address"), lit(" "))) + assert_df_subtree_query_complexity(df3, { + PlanNodeCategory.COLUMN.value: 5, + PlanNodeCategory.JOIN.value: 1, + PlanNodeCategory.FUNCTION.value: 1, + PlanNodeCategory.LITERAL.value: 1, + }) + @pytest.mark.parametrize( "set_operator", [SET_UNION, SET_UNION_ALL, SET_EXCEPT, SET_INTERSECT] @@ -175,7 +185,7 @@ def test_agg(session: Session, sample_table: str): df1, { PlanNodeCategory.COLUMN.value: 3, - PlanNodeCategory.OTHERS.value: 1, + PlanNodeCategory.LOW_IMPACT.value: 1, PlanNodeCategory.FUNCTION.value: 1, }, ) @@ -184,7 +194,7 @@ def test_agg(session: Session, sample_table: str): df2, { PlanNodeCategory.COLUMN.value: 3, - PlanNodeCategory.OTHERS.value: 2, + PlanNodeCategory.LOW_IMPACT.value: 2, PlanNodeCategory.FUNCTION.value: 1, PlanNodeCategory.LITERAL.value: 1, }, @@ -194,7 +204,7 @@ def test_agg(session: Session, sample_table: str): df3, { PlanNodeCategory.COLUMN.value: 5, - PlanNodeCategory.OTHERS.value: 2, + PlanNodeCategory.LOW_IMPACT.value: 2, PlanNodeCategory.FUNCTION.value: 2, PlanNodeCategory.LITERAL.value: 1, }, @@ -238,7 +248,8 @@ def test_window_function(session: Session): PlanNodeCategory.FUNCTION.value: 1, PlanNodeCategory.COLUMN.value: 5, PlanNodeCategory.LITERAL.value: 1, - PlanNodeCategory.OTHERS.value: 2, + PlanNodeCategory.LOW_IMPACT.value: 2, + PlanNodeCategory.OTHERS.value: 1, } ), ) @@ -255,7 +266,8 @@ def test_window_function(session: Session): PlanNodeCategory.WINDOW.value: 1, PlanNodeCategory.FUNCTION.value: 1, PlanNodeCategory.COLUMN.value: 3, - PlanNodeCategory.OTHERS.value: 3, + PlanNodeCategory.LOW_IMPACT.value: 3, + PlanNodeCategory.OTHERS.value: 1, } ), ) @@ -292,7 +304,7 @@ def test_join_statement(session: Session, sample_table: str): assert_df_subtree_query_complexity( df4, get_cumulative_complexity_stat(df3) - + Counter({PlanNodeCategory.COLUMN.value: 4, PlanNodeCategory.OTHERS.value: 3}), + + Counter({PlanNodeCategory.COLUMN.value: 4, PlanNodeCategory.LOW_IMPACT.value: 3}), ) df5 = df1.join(df2, using_columns=["a", "b"]) @@ -417,14 +429,14 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl assert_df_subtree_query_complexity( df4, get_cumulative_complexity_stat(df3) - + Counter({PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.ORDER_BY.value: 1}), + + Counter({PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.ORDER_BY.value: 1, PlanNodeCategory.OTHERS.value: 1}), ) # for additional ,"C" ASC NULLS FIRST df5 = df4.sort(col("c").desc()) assert_df_subtree_query_complexity( df5, - 
get_cumulative_complexity_stat(df4) + Counter({PlanNodeCategory.COLUMN.value: 1}), + get_cumulative_complexity_stat(df4) + Counter({PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.OTHERS.value: 1}), ) # add filter @@ -438,7 +450,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl PlanNodeCategory.FILTER.value: 1, PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.LITERAL.value: 1, - PlanNodeCategory.OTHERS.value: 1, + PlanNodeCategory.LOW_IMPACT.value: 1, } ), ) @@ -452,7 +464,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl { PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.LITERAL.value: 1, - PlanNodeCategory.OTHERS.value: 2, + PlanNodeCategory.LOW_IMPACT.value: 2, } ), ) @@ -482,7 +494,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl assert_df_subtree_query_complexity( df10, get_cumulative_complexity_stat(df9) - + Counter({PlanNodeCategory.OTHERS.value: 1}), + + Counter({PlanNodeCategory.LOW_IMPACT.value: 1}), ) # for offset @@ -490,5 +502,5 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl assert_df_subtree_query_complexity( df11, get_cumulative_complexity_stat(df9) - + Counter({PlanNodeCategory.OTHERS.value: 2}), + + Counter({PlanNodeCategory.LOW_IMPACT.value: 2}), ) diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py index bd00cecb0bb..71ce5084144 100644 --- a/tests/integ/test_telemetry.py +++ b/tests/integ/test_telemetry.py @@ -593,7 +593,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_stat": { + "query_plan_stat": { "filter": 1, "low_impact": 5, "column": 3, @@ -614,7 +614,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_stat": { + "query_plan_stat": { "filter": 1, "low_impact": 5, "column": 3, @@ -635,7 +635,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_stat": { + "query_plan_stat": { "filter": 1, "low_impact": 5, "column": 3, @@ -656,7 +656,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_stat": { + "query_plan_stat": { "filter": 1, "low_impact": 5, "column": 3, @@ -677,7 +677,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_stat": { + "query_plan_stat": { "filter": 1, "low_impact": 5, "column": 3, @@ -821,7 +821,7 @@ def test_dataframe_stat_functions_api_calls(session): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": 4, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_stat": {"group_by": 1, "column": 6, "literal": 48}, + "query_plan_stat": {"group_by": 1, "column": 6, "literal": 48}, }, { "name": "DataFrameStatFunctions.crosstab", @@ -839,7 +839,7 @@ def 
test_dataframe_stat_functions_api_calls(session): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": 4, "query_plan_num_duplicate_nodes": 0, - "query_plan_complexity_stat": {"group_by": 1, "column": 6, "literal": 48}, + "query_plan_stat": {"group_by": 1, "column": 6, "literal": 48}, } ] From 2b2f8a58de2185c5e8259a7cabb0f87542405270 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 5 Jun 2024 11:35:57 -0700 Subject: [PATCH 23/37] fix lint and type hints --- .../_internal/analyzer/binary_expression.py | 1 - .../_internal/analyzer/binary_plan_node.py | 4 +-- .../snowpark/_internal/analyzer/expression.py | 10 +++--- .../_internal/analyzer/grouping_set.py | 2 +- .../analyzer/query_plan_analysis_utils.py | 4 ++- .../_internal/analyzer/select_statement.py | 14 +++++--- .../_internal/analyzer/snowflake_plan_node.py | 2 +- .../_internal/analyzer/table_function.py | 2 +- .../analyzer/table_merge_expression.py | 10 ++---- .../_internal/analyzer/unary_expression.py | 1 - .../_internal/analyzer/unary_plan_node.py | 2 +- .../_internal/analyzer/window_expression.py | 10 ++---- tests/integ/test_query_plan_analysis.py | 35 ++++++++++++++----- 13 files changed, 56 insertions(+), 41 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 41bf92633f9..3ed969caada 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -10,7 +10,6 @@ ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - Counter, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index b49410dbd66..17d8f34464c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -6,8 +6,8 @@ from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - PlanNodeCategory, Counter, + PlanNodeCategory, ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages @@ -74,7 +74,7 @@ def __init__(self, left: LogicalPlan, right: LogicalPlan) -> None: class SetOperation(BinaryNode): @property - def plan_node_category(self) -> Counter[str]: + def plan_node_category(self) -> PlanNodeCategory: # (left) operator (right) return PlanNodeCategory.SET_OPERATION diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 6f32f43f73e..54d72a617d2 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -9,8 +9,8 @@ import snowflake.snowpark._internal.utils from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - PlanNodeCategory, Counter, + PlanNodeCategory, ) if TYPE_CHECKING: @@ -150,11 +150,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @cached_property def cumulative_complexity_stat(self) -> Counter[str]: - return ( - sum( - (expr.cumulative_complexity_stat for expr in self.expressions), - Counter(), - ) + return sum( + (expr.cumulative_complexity_stat for expr in self.expressions), + Counter(), ) diff --git 
a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 33afcd3cef2..c3e4e7f026c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -10,8 +10,8 @@ derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - PlanNodeCategory, Counter, + PlanNodeCategory, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py index a341daf5a10..55fafb53ef7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py @@ -41,7 +41,9 @@ class PlanNodeCategory(Enum): CASE_WHEN = "case_when" LITERAL = "literal" # cover all literals like numbers, constant strings, etc COLUMN = "column" # covers all cases where a table column is referred - FUNCTION = "function" # cover all snowflake built-in function, table functions and UDXFs + FUNCTION = ( + "function" # cover all snowflake built-in function, table functions and UDXFs + ) IN = "in" LOW_IMPACT = "low_impact" OTHERS = "others" diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 3288ae19647..1d64b474a37 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -23,8 +23,8 @@ import snowflake.snowpark._internal.utils from snowflake.snowpark._internal.analyzer.cte_utils import encode_id from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - PlanNodeCategory, Counter, + PlanNodeCategory, ) from snowflake.snowpark._internal.analyzer.table_function import ( TableFunctionExpression, @@ -744,10 +744,14 @@ def individual_complexity_stat(self) -> Counter[str]: # limit/offset component stat += ( - Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) if self.limit_ else Counter() + Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) + if self.limit_ + else Counter() ) stat += ( - Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) if self.offset else Counter() + Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) + if self.offset + else Counter() ) return stat @@ -1135,7 +1139,9 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: @property def individual_complexity_stat(self) -> Counter[str]: # we add #set_operands - 1 additional operators in sql query - return Counter({PlanNodeCategory.SET_OPERATION.value: len(self.set_operands) - 1}) + return Counter( + {PlanNodeCategory.SET_OPERATION.value: len(self.set_operands) - 1} + ) class DeriveColumnDependencyError(Exception): diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index c8fc9a23790..0fcc6a3d6d6 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -9,8 +9,8 @@ from snowflake.snowpark._internal.analyzer.expression import Attribute, Expression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - PlanNodeCategory, Counter, + PlanNodeCategory, ) from snowflake.snowpark.row import Row from snowflake.snowpark.types import StructType diff --git 
a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index 8d5134e5e77..24f47900f4b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -8,8 +8,8 @@ from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - PlanNodeCategory, Counter, + PlanNodeCategory, ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index c9244483ed4..94adc5c0159 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -7,8 +7,8 @@ from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - PlanNodeCategory, Counter, + PlanNodeCategory, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import ( LogicalPlan, @@ -82,12 +82,8 @@ def cumulative_complexity_stat(self) -> Counter[str]: stat += ( self.condition.cumulative_complexity_stat if self.condition else Counter() ) - stat += sum( - (key.cumulative_complexity_stat for key in self.keys), Counter() - ) - stat += sum( - (val.cumulative_complexity_stat for val in self.values), Counter() - ) + stat += sum((key.cumulative_complexity_stat for key in self.keys), Counter()) + stat += sum((val.cumulative_complexity_stat for val in self.values), Counter()) return stat diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index 04ccbdc3f0c..bf55421e6ee 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -11,7 +11,6 @@ ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - Counter, ) from snowflake.snowpark.types import DataType diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 3cc5d93e0c9..fb0016d1918 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -10,8 +10,8 @@ ScalarSubquery, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - PlanNodeCategory, Counter, + PlanNodeCategory, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index ff5aec38a5a..b985131407f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -10,8 +10,8 @@ derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - PlanNodeCategory, Counter, + PlanNodeCategory, ) from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -197,12 +197,8 @@ def individual_complexity_stat(self) -> 
Counter[str]: @cached_property def cumulative_complexity_stat(self) -> Counter[str]: # func_name (expr [, offset] [, default]) [IGNORE NULLS] - stat = ( - self.individual_complexity_stat + self.expr.cumulative_complexity_stat - ) - stat += ( - self.default.cumulative_complexity_stat if self.default else Counter() - ) + stat = self.individual_complexity_stat + self.expr.cumulative_complexity_stat + stat += self.default.cumulative_complexity_stat if self.default else Counter() return stat diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py index 4ab2d268bad..1b71132ec22 100644 --- a/tests/integ/test_query_plan_analysis.py +++ b/tests/integ/test_query_plan_analysis.py @@ -6,8 +6,8 @@ import pytest from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - PlanNodeCategory, Counter, + PlanNodeCategory, ) from snowflake.snowpark._internal.analyzer.select_statement import ( SET_EXCEPT, @@ -114,7 +114,13 @@ def test_generator_table_function(session: Session): assert_df_subtree_query_complexity( df2, get_cumulative_complexity_stat(df1) - + Counter({PlanNodeCategory.ORDER_BY.value: 1, PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.OTHERS.value: 1}), + + Counter( + { + PlanNodeCategory.ORDER_BY.value: 1, + PlanNodeCategory.COLUMN.value: 1, + PlanNodeCategory.OTHERS.value: 1, + } + ), ) @@ -144,12 +150,15 @@ def test_join_table_function(session: Session): # SELECT T_LEFT.*, T_RIGHT.* FROM (select 'James' as name, 'address1 address2 address3' as addresses) AS T_LEFT JOIN TABLE (split_to_table("ADDRESS", ' ') ) AS T_RIGHT df3 = df1.join_table_function(split_to_table(col("address"), lit(" "))) - assert_df_subtree_query_complexity(df3, { + assert_df_subtree_query_complexity( + df3, + { PlanNodeCategory.COLUMN.value: 5, PlanNodeCategory.JOIN.value: 1, PlanNodeCategory.FUNCTION.value: 1, PlanNodeCategory.LITERAL.value: 1, - }) + }, + ) @pytest.mark.parametrize( @@ -304,14 +313,17 @@ def test_join_statement(session: Session, sample_table: str): assert_df_subtree_query_complexity( df4, get_cumulative_complexity_stat(df3) - + Counter({PlanNodeCategory.COLUMN.value: 4, PlanNodeCategory.LOW_IMPACT.value: 3}), + + Counter( + {PlanNodeCategory.COLUMN.value: 4, PlanNodeCategory.LOW_IMPACT.value: 3} + ), ) df5 = df1.join(df2, using_columns=["a", "b"]) # SELECT * FROM ( (ch1) AS SNOWPARK_LEFT INNER JOIN (ch2) AS SNOWPARK_RIGHT USING (a, b)) assert_df_subtree_query_complexity( df5, - get_cumulative_complexity_stat(df3) + Counter({PlanNodeCategory.COLUMN.value: 2}), + get_cumulative_complexity_stat(df3) + + Counter({PlanNodeCategory.COLUMN.value: 2}), ) @@ -429,14 +441,21 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl assert_df_subtree_query_complexity( df4, get_cumulative_complexity_stat(df3) - + Counter({PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.ORDER_BY.value: 1, PlanNodeCategory.OTHERS.value: 1}), + + Counter( + { + PlanNodeCategory.COLUMN.value: 1, + PlanNodeCategory.ORDER_BY.value: 1, + PlanNodeCategory.OTHERS.value: 1, + } + ), ) # for additional ,"C" ASC NULLS FIRST df5 = df4.sort(col("c").desc()) assert_df_subtree_query_complexity( df5, - get_cumulative_complexity_stat(df4) + Counter({PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.OTHERS.value: 1}), + get_cumulative_complexity_stat(df4) + + Counter({PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.OTHERS.value: 1}), ) # add filter From 7f636e6b8df74ea9d24b2784a1f44ba06cef042e Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 5 Jun 2024 
11:42:54 -0700 Subject: [PATCH 24/37] fix classification for Interval expression --- src/snowflake/snowpark/_internal/analyzer/expression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 54d72a617d2..6a19a978782 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -353,7 +353,7 @@ def __str__(self) -> str: @property def plan_node_category(self) -> PlanNodeCategory: - return PlanNodeCategory.LITERAL + return PlanNodeCategory.LOW_IMPACT class Like(Expression): From e40129d1ff5cadc6b4b273e9691568e22cbc72f8 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 5 Jun 2024 17:57:10 -0700 Subject: [PATCH 25/37] added some unit tests --- .../_internal/analyzer/select_statement.py | 24 +-- .../_internal/analyzer/window_expression.py | 12 +- tests/integ/test_query_plan_analysis.py | 27 ++- tests/unit/conftest.py | 25 +++ tests/unit/test_query_plan_analysis.py | 165 ++++++++++++++++++ 5 files changed, 222 insertions(+), 31 deletions(-) create mode 100644 tests/unit/test_query_plan_analysis.py diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 1d64b474a37..a5ff71222a9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -294,24 +294,8 @@ def plan_height(self) -> int: def num_duplicate_nodes(self) -> int: return self.snowflake_plan.num_duplicate_nodes - @property - def individual_complexity_stat(self) -> Counter[str]: - """This is the query complexity stat added by this Selectable node - to the overall query plan. For default case, it is the number of active - columns. Specific cases are handled in child classes with additional - explanation. - """ - if isinstance(self.snowflake_plan.source_plan, Selectable): - return Counter( - {PlanNodeCategory.COLUMN.value: len(self.column_states.active_columns)} - ) - return self.snowflake_plan.individual_complexity_stat - @property def cumulative_complexity_stat(self) -> Counter[str]: - """This is sum of individual query complexity stats for all nodes - within a query plan subtree. 
- """ if self._cumulative_complexity_stat is None: stat = self.individual_complexity_stat for node in self.children_plan_nodes: @@ -382,9 +366,9 @@ def schema_query(self) -> str: return self.sql_query @property - def individual_complexity_stat(self) -> Counter[str]: + def plan_node_category(self) -> PlanNodeCategory: # SELECT * FROM entity - return Counter({PlanNodeCategory.COLUMN.value: 1}) + return PlanNodeCategory.COLUMN @property def query_params(self) -> Optional[Sequence[Any]]: @@ -1057,6 +1041,10 @@ def schema_query(self) -> str: def query_params(self) -> Optional[Sequence[Any]]: return self.snowflake_plan.queries[-1].params + @property + def individual_complexity_stat(self) -> Counter[str]: + return self.snowflake_plan.individual_complexity_stat + class SetOperand: def __init__(self, selectable: Selectable, operator: Optional[str] = None) -> None: diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index b985131407f..d58a40eaaa0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -23,8 +23,8 @@ def __init__(self) -> None: super().__init__() @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT class UnboundedPreceding(SpecialFrameBoundary): @@ -73,8 +73,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.lower, self.upper) @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT @cached_property def cumulative_complexity_stat(self) -> Counter[str]: @@ -146,8 +146,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.window_function, self.window_spec) @property - def individual_complexity_stat(self) -> Counter[str]: - return Counter({PlanNodeCategory.WINDOW.value: 1}) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.WINDOW @cached_property def cumulative_complexity_stat(self) -> Counter[str]: diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py index 1b71132ec22..41752d5e994 100644 --- a/tests/integ/test_query_plan_analysis.py +++ b/tests/integ/test_query_plan_analysis.py @@ -76,8 +76,8 @@ def test_create_dataframe_from_values(session: Session): def test_session_table(session: Session, sample_table: str): - df = session.table(sample_table) # select * from sample_table + df = session.table(sample_table) assert_df_subtree_query_complexity(df, {PlanNodeCategory.COLUMN.value: 1}) @@ -148,15 +148,24 @@ def test_join_table_function(session: Session): }, ) - # SELECT T_LEFT.*, T_RIGHT.* FROM (select 'James' as name, 'address1 address2 address3' as addresses) AS T_LEFT JOIN TABLE (split_to_table("ADDRESS", ' ') ) AS T_RIGHT - df3 = df1.join_table_function(split_to_table(col("address"), lit(" "))) + # SELECT T_LEFT.*, T_RIGHT.* FROM (select 'James' as name, 'address1 address2 address3' as addresses) AS T_LEFT + # JOIN TABLE (split_to_table("ADDRESS", ' ') OVER (PARTITION BY "LAST_NAME" ORDER BY "FIRST_NAME" ASC NULLS FIRST)) AS T_RIGHT + df3 = df1.join_table_function( + split_to_table(col("address"), lit(" ")).over( + 
partition_by="last_name", order_by="first_name"
+        )
+    )
     assert_df_subtree_query_complexity(
         df3,
         {
-            PlanNodeCategory.COLUMN.value: 5,
+            PlanNodeCategory.COLUMN.value: 7,
             PlanNodeCategory.JOIN.value: 1,
             PlanNodeCategory.FUNCTION.value: 1,
             PlanNodeCategory.LITERAL.value: 1,
+            PlanNodeCategory.PARTITION_BY.value: 1,
+            PlanNodeCategory.ORDER_BY.value: 1,
+            PlanNodeCategory.WINDOW.value: 1,
+            PlanNodeCategory.OTHERS.value: 1,
         },
     )
@@ -421,13 +430,17 @@ def test_sample(session: Session, sample_table):
 
 
 def test_select_statement_with_multiple_operations(session: Session, sample_table: str):
-    df1 = session.table(sample_table)
+    df = session.table(sample_table)
 
     # add select
-    # SELECT "A", "B", "C" FROM sample_table
-    # note that column stat is 4 even though the number of selected columns is 3. This is because we count 1 column
+    # SELECT "A", "B", "C", "D" FROM sample_table
+    # note that column stat is 5 even though the number of selected columns is 4. This is because we count 1 column
     # from select * from sample_table which is flattened out. This is a known limitation but is okay
     # since we are not off by much
+    df1 = df.select(df["*"])
+    assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN.value: 5})
+
+    # SELECT "A", "B", "C" FROM sample_table
     df2 = df1.select("a", "b", "c")
     assert_df_subtree_query_complexity(df2, {PlanNodeCategory.COLUMN.value: 4})
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 6419b77d42f..aab964bad09 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -8,7 +8,10 @@
 
 from snowflake.connector import SnowflakeConnection
 from snowflake.connector.cursor import SnowflakeCursor
+from snowflake.snowpark._internal.analyzer.analyzer import Analyzer
+from snowflake.snowpark._internal.analyzer.snowflake_plan import Query
 from snowflake.snowpark._internal.server_connection import ServerConnection
+from snowflake.snowpark.session import Session
 
 
 @pytest.fixture
@@ -21,3 +24,25 @@ def mock_server_connection() -> ServerConnection:
     )
     fake_snowflake_connection.is_closed.return_value = False
     return ServerConnection({}, fake_snowflake_connection)
+
+
+@pytest.fixture(scope="module")
+def mock_analyzer() -> Analyzer:
+    fake_analyzer = mock.create_autospec(Analyzer)
+    return fake_analyzer
+
+
+@pytest.fixture(scope="module")
+def mock_session(mock_analyzer) -> Session:
+    fake_session = mock.create_autospec(Session)
+    fake_session._analyzer = mock_analyzer
+    mock_analyzer.session = fake_session
+    return fake_session
+
+
+@pytest.fixture(scope="module")
+def mock_query():
+    fake_query = mock.create_autospec(Query)
+    fake_query.sql = "dummy sql"
+    fake_query.params = "dummy params"
+    return fake_query
diff --git a/tests/unit/test_query_plan_analysis.py b/tests/unit/test_query_plan_analysis.py
new file mode 100644
index 00000000000..669db905155
--- /dev/null
+++ b/tests/unit/test_query_plan_analysis.py
@@ -0,0 +1,165 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+# + +from collections import Counter +from unittest import mock + +import pytest + +from snowflake.snowpark._internal.analyzer.expression import ( + Attribute, + Expression, + NamedExpression, +) +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, +) +from snowflake.snowpark._internal.analyzer.select_statement import ( + Selectable, + SelectableEntity, + SelectSnowflakePlan, + SelectSQL, + SelectStatement, + SelectTableFunction, + SetStatement, +) +from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan +from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan +from snowflake.snowpark._internal.analyzer.unary_plan_node import Project +from snowflake.snowpark.types import IntegerType + + +@pytest.mark.parametrize("node_type", [LogicalPlan, SnowflakePlan, Selectable]) +def test_assign_custom_cumulative_complexity_stat( + mock_session, mock_analyzer, mock_query, node_type +): + def get_node_for_type(node_type): + if node_type == LogicalPlan: + return LogicalPlan() + if node_type == SnowflakePlan: + return SnowflakePlan( + [mock_query], "", source_plan=LogicalPlan(), session=mock_session + ) + return SelectSnowflakePlan( + SnowflakePlan( + [mock_query], "", source_plan=LogicalPlan(), session=mock_session + ), + analyzer=mock_analyzer, + ) + + def set_children(node, node_type, children): + if node_type == LogicalPlan: + node.children = children + elif node_type == SnowflakePlan: + node.source_plan.children = children + else: + node.snowflake_plan.source_plan.children = children + + nodes = [get_node_for_type(node_type) for _ in range(7)] + + """ + o o + / \\ / \ + o o x o + /|\ + o o o -> + | + o + """ + set_children(nodes[0], node_type, [nodes[1], nodes[2]]) + set_children(nodes[1], node_type, [nodes[3], nodes[4], nodes[5]]) + set_children(nodes[2], node_type, []) + set_children(nodes[3], node_type, []) + set_children(nodes[4], node_type, [nodes[6]]) + set_children(nodes[5], node_type, []) + set_children(nodes[6], node_type, []) + + assert nodes[0].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 7} + assert nodes[1].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 5} + assert nodes[2].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 1} + assert nodes[3].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 1} + assert nodes[4].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 2} + assert nodes[5].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 1} + assert nodes[6].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 1} + + nodes[1].cumulative_complexity_stat = Counter({PlanNodeCategory.COLUMN.value: 1}) + + # assert that only value that is reset is changed + assert nodes[0].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 7} + assert nodes[1].cumulative_complexity_stat == {PlanNodeCategory.COLUMN.value: 1} + assert nodes[2].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 1} + + +def test_selectable_entity_individual_complexity_stat(mock_analyzer): + plan_node = SelectableEntity(entity_name="dummy entity", analyzer=mock_analyzer) + assert plan_node.individual_complexity_stat == {PlanNodeCategory.COLUMN.value: 1} + + +def test_select_sql_individual_complexity_stat(mock_session, mock_analyzer): + plan_node = SelectSQL( + "non-select statement", convert_to_select=True, analyzer=mock_analyzer + ) + assert plan_node.individual_complexity_stat == {PlanNodeCategory.COLUMN.value: 1} + + def 
mocked_get_result_attributes(sql): + return [Attribute("A", IntegerType()), Attribute("B", IntegerType())] + + def mocked_analyze( + attr: Attribute, df_aliased_col_name_to_real_col_name, parse_local_name + ): + return attr.name + + with mock.patch.object( + mock_session, "_get_result_attributes", side_effect=mocked_get_result_attributes + ): + with mock.patch.object(mock_analyzer, "analyze", side_effect=mocked_analyze): + plan_node = SelectSQL("select 1 as A, 2 as B", analyzer=mock_analyzer) + assert plan_node.individual_complexity_stat == { + PlanNodeCategory.COLUMN.value: 2 + } + + +def test_select_snowflake_plan_individual_complexity_stat( + mock_session, mock_analyzer, mock_query +): + source_plan = Project([NamedExpression(), NamedExpression()], LogicalPlan()) + snowflake_plan = SnowflakePlan( + [mock_query], "", source_plan=source_plan, session=mock_session + ) + plan_node = SelectSnowflakePlan(snowflake_plan, analyzer=mock_analyzer) + assert plan_node.individual_complexity_stat == {PlanNodeCategory.COLUMN.value: 2} + + +@pytest.mark.parametrize( + "attribute,value,expected_stat", + [ + ("projection", [NamedExpression()], {}), + ("order_by", [Expression()], {}), + ("where", Expression(), {}), + ("limit_", 10, {}), + ("offset", 2, {}), + ], +) +def test_select_statement_individual_complexity_stat( + mock_analyzer, attribute, value, expected_stat +): + from_ = mock.create_autospec(Selectable) + from_.pre_actions = None + plan_node = SelectStatement( + from_=mock.create_autospec(Selectable), analyzer=mock_analyzer + ) + setattr(plan_node, attribute, value) + assert plan_node.individual_complexity_stat == expected_stat + + +def test_select_table_function_individual_complexity_stat(): + pass + + +def test_set_statement_individual_complexity_stat(): + pass + + +def test_snowflake_create_table(): + pass From 8607aa4182bd7592e3d4da521febc2f4fe94d2ce Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 5 Jun 2024 19:49:19 -0700 Subject: [PATCH 26/37] add unit test --- .../_internal/analyzer/select_statement.py | 10 ++- tests/unit/test_query_plan_analysis.py | 78 +++++++++++++++---- 2 files changed, 70 insertions(+), 18 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index a5ff71222a9..62f20aec486 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -702,7 +702,15 @@ def individual_complexity_stat(self) -> Counter[str]: # projection component stat += ( sum( - (expr.cumulative_complexity_stat for expr in self.projection), Counter() + ( + getattr( + expr, + "cumulative_complexity_stat", + Counter({PlanNodeCategory.COLUMN.value: 1}), + ) + for expr in self.projection + ), + Counter(), ) if self.projection else Counter() diff --git a/tests/unit/test_query_plan_analysis.py b/tests/unit/test_query_plan_analysis.py index 669db905155..c8e92ca4da5 100644 --- a/tests/unit/test_query_plan_analysis.py +++ b/tests/unit/test_query_plan_analysis.py @@ -7,6 +7,12 @@ import pytest +from snowflake.snowpark._internal.analyzer.analyzer_utils import ( + EXCEPT, + INTERSECT, + UNION, + UNION_ALL, +) from snowflake.snowpark._internal.analyzer.expression import ( Attribute, Expression, @@ -22,10 +28,12 @@ SelectSQL, SelectStatement, SelectTableFunction, + SetOperand, SetStatement, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan from snowflake.snowpark._internal.analyzer.snowflake_plan_node 
import LogicalPlan +from snowflake.snowpark._internal.analyzer.table_function import TableFunctionExpression from snowflake.snowpark._internal.analyzer.unary_plan_node import Project from snowflake.snowpark.types import IntegerType @@ -134,11 +142,20 @@ def test_select_snowflake_plan_individual_complexity_stat( @pytest.mark.parametrize( "attribute,value,expected_stat", [ - ("projection", [NamedExpression()], {}), - ("order_by", [Expression()], {}), - ("where", Expression(), {}), - ("limit_", 10, {}), - ("offset", 2, {}), + ("projection", [NamedExpression()], {PlanNodeCategory.COLUMN.value: 1}), + ("projection", [Expression()], {PlanNodeCategory.OTHERS.value: 1}), + ( + "order_by", + [Expression()], + {PlanNodeCategory.OTHERS.value: 1, PlanNodeCategory.ORDER_BY.value: 1}, + ), + ( + "where", + Expression(), + {PlanNodeCategory.OTHERS.value: 1, PlanNodeCategory.FILTER.value: 1}, + ), + ("limit_", 10, {PlanNodeCategory.LOW_IMPACT.value: 1}), + ("offset", 2, {PlanNodeCategory.LOW_IMPACT.value: 1}), ], ) def test_select_statement_individual_complexity_stat( @@ -146,20 +163,47 @@ def test_select_statement_individual_complexity_stat( ): from_ = mock.create_autospec(Selectable) from_.pre_actions = None - plan_node = SelectStatement( - from_=mock.create_autospec(Selectable), analyzer=mock_analyzer - ) + from_.post_actions = None + from_.expr_to_alias = {} + from_.df_aliased_col_name_to_real_col_name = {} + + plan_node = SelectStatement(from_=from_, analyzer=mock_analyzer) setattr(plan_node, attribute, value) assert plan_node.individual_complexity_stat == expected_stat -def test_select_table_function_individual_complexity_stat(): - pass - - -def test_set_statement_individual_complexity_stat(): - pass - +def test_select_table_function_individual_complexity_stat( + mock_analyzer, mock_session, mock_query +): + func_expr = mock.create_autospec(TableFunctionExpression) + source_plan = Project([NamedExpression(), NamedExpression()], LogicalPlan()) + snowflake_plan = SnowflakePlan( + [mock_query], "", source_plan=source_plan, session=mock_session + ) -def test_snowflake_create_table(): - pass + def mocked_resolve(*args, **kwargs): + return snowflake_plan + + with mock.patch.object(mock_analyzer, "resolve", side_effect=mocked_resolve): + plan_node = SelectTableFunction(func_expr, analyzer=mock_analyzer) + assert plan_node.individual_complexity_stat == { + PlanNodeCategory.COLUMN.value: 2 + } + + +@pytest.mark.parametrize("set_operator", [UNION, UNION_ALL, INTERSECT, EXCEPT]) +def test_set_statement_individual_complexity_stat(mock_analyzer, set_operator): + mock_selectable = mock.create_autospec(Selectable) + mock_selectable.pre_actions = None + mock_selectable.post_actions = None + mock_selectable.expr_to_alias = {} + mock_selectable.df_aliased_col_name_to_real_col_name = {} + set_operands = [ + SetOperand(mock_selectable, set_operator), + SetOperand(mock_selectable, set_operator), + ] + plan_node = SetStatement(*set_operands, analyzer=mock_analyzer) + + assert plan_node.individual_complexity_stat == { + PlanNodeCategory.SET_OPERATION.value: 1 + } From 383082019f4c485a1b43965f2eac42fc919295c5 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 6 Jun 2024 08:07:16 -0700 Subject: [PATCH 27/37] fix typing --- src/snowflake/snowpark/_internal/analyzer/select_statement.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 62f20aec486..c69c447f9d5 100644 --- 
a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -707,7 +707,7 @@ def individual_complexity_stat(self) -> Counter[str]: expr, "cumulative_complexity_stat", Counter({PlanNodeCategory.COLUMN.value: 1}), - ) + ) # type: ignore for expr in self.projection ), Counter(), From 68855ac4dbeb62f880fe68514b52289b73a27299 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 6 Jun 2024 08:43:16 -0700 Subject: [PATCH 28/37] align with doc --- .../snowpark/_internal/analyzer/unary_expression.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index bf55421e6ee..23d0379e1b9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -10,6 +10,7 @@ derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + Counter, PlanNodeCategory, ) from snowflake.snowpark.types import DataType @@ -102,5 +103,6 @@ def __init__(self, child: Expression) -> None: self.name = child.sql @property - def plan_node_category(self) -> PlanNodeCategory: - return PlanNodeCategory.OTHERS + def individual_complexity_stat(self) -> Counter[str]: + # this is a wrapper around child + return Counter() From 1299ee96829a9726ae23d754d67bc0dedec9c4a9 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 12 Jun 2024 08:29:30 -0700 Subject: [PATCH 29/37] fix SelectSQL --- .../_internal/analyzer/select_statement.py | 12 ++---------- tests/integ/scala/test_async_job_suite.py | 16 +++------------- tests/integ/test_query_plan_analysis.py | 6 +++--- 3 files changed, 8 insertions(+), 26 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index c69c447f9d5..6606262c9e6 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -426,16 +426,8 @@ def schema_query(self) -> str: return self._schema_query @property - def individual_complexity_stat(self) -> Counter[str]: - if self.pre_actions: - # Currently having pre-actions implies we have a non-select query followed - # by a SELECT * FROM table(result_scan(query_id)) statement - return Counter({PlanNodeCategory.COLUMN.value: 1}) - - # no pre-action implies the best stat we have is of # active columns - return Counter( - {PlanNodeCategory.COLUMN.value: len(self.column_states.active_columns)} - ) + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.COLUMN def to_subqueryable(self) -> "SelectSQL": """Convert this SelectSQL to a new one that can be used as a subquery. 
Refer to __init__.""" diff --git a/tests/integ/scala/test_async_job_suite.py b/tests/integ/scala/test_async_job_suite.py index 6e01cb7e00a..e78f7f8ae69 100644 --- a/tests/integ/scala/test_async_job_suite.py +++ b/tests/integ/scala/test_async_job_suite.py @@ -8,8 +8,6 @@ import pytest -from snowflake.snowpark.session import Session - try: import pandas as pd from pandas.testing import assert_frame_equal @@ -25,7 +23,7 @@ random_name_for_temp_object, ) from snowflake.snowpark.exceptions import SnowparkSQLException -from snowflake.snowpark.functions import col, sproc, when_matched, when_not_matched +from snowflake.snowpark.functions import col, when_matched, when_not_matched from snowflake.snowpark.table import DeleteResult, MergeResult, UpdateResult from snowflake.snowpark.types import ( DoubleType, @@ -351,21 +349,13 @@ def test_async_batch_insert(session): reason="TODO(SNOW-932722): Cancel query is not allowed in stored proc", ) def test_async_is_running_and_cancel(session): - # creating a sproc here because describe query on SYSTEM$WAIT() - # triggers the wait and the async job fails because we don't hit - # the correct time boundaries - @sproc(name="wait_sproc", packages=["snowflake-snowpark-python"]) - def wait(_: Session, sec: int) -> str: - sleep(sec) - return "success" - - async_job = session.sql("call wait_sproc(3)").collect_nowait() + async_job = session.sql("select SYSTEM$WAIT(3)").collect_nowait() while not async_job.is_done(): sleep(1.0) assert async_job.is_done() # set 20s to avoid flakiness - async_job2 = session.sql("call wait_sproc(20)").collect_nowait() + async_job2 = session.sql("select SYSTEM$WAIT(20)").collect_nowait() assert not async_job2.is_done() async_job2.cancel() start = time() diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py index 41752d5e994..e4c9b4d266f 100644 --- a/tests/integ/test_query_plan_analysis.py +++ b/tests/integ/test_query_plan_analysis.py @@ -129,7 +129,7 @@ def test_join_table_function(session: Session): "select 'James' as name, 'address1 address2 address3' as addresses" ) # SelectSQL chooses num active columns as the best estimate - assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN.value: 2}) + assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN.value: 1}) split_to_table = table_function("split_to_table") @@ -141,7 +141,7 @@ def test_join_table_function(session: Session): assert_df_subtree_query_complexity( df2, { - PlanNodeCategory.COLUMN.value: 9, + PlanNodeCategory.COLUMN.value: 8, PlanNodeCategory.JOIN.value: 1, PlanNodeCategory.FUNCTION.value: 1, PlanNodeCategory.LITERAL.value: 1, @@ -158,7 +158,7 @@ def test_join_table_function(session: Session): assert_df_subtree_query_complexity( df3, { - PlanNodeCategory.COLUMN.value: 7, + PlanNodeCategory.COLUMN.value: 6, PlanNodeCategory.JOIN.value: 1, PlanNodeCategory.FUNCTION.value: 1, PlanNodeCategory.LITERAL.value: 1, From 4a901f41df05dd5b6b23edce8ff11475501df834 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 12 Jun 2024 09:50:46 -0700 Subject: [PATCH 30/37] rename to node_complexity; add setter for cumulative complexity expression --- .../_internal/analyzer/binary_plan_node.py | 6 +- .../snowpark/_internal/analyzer/expression.py | 124 +++++++++--------- .../_internal/analyzer/grouping_set.py | 8 +- .../_internal/analyzer/select_statement.py | 38 +++--- .../_internal/analyzer/snowflake_plan.py | 24 ++-- .../_internal/analyzer/snowflake_plan_node.py | 38 +++--- .../_internal/analyzer/sort_expression.py | 6 +- 
.../_internal/analyzer/table_function.py | 52 ++++---- .../analyzer/table_merge_expression.py | 46 +++---- .../_internal/analyzer/unary_expression.py | 2 +- .../_internal/analyzer/unary_plan_node.py | 40 +++--- .../_internal/analyzer/window_expression.py | 41 +++--- src/snowflake/snowpark/_internal/telemetry.py | 6 +- tests/integ/test_query_plan_analysis.py | 30 ++--- tests/integ/test_telemetry.py | 14 +- tests/unit/test_query_plan_analysis.py | 50 +++---- 16 files changed, 252 insertions(+), 273 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index 17d8f34464c..a3d256233c1 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -196,7 +196,7 @@ def sql(self) -> str: return self.join_type.sql @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # SELECT * FROM (left) AS left_alias join_type_sql JOIN (right) AS right_alias match_cond, using_cond, join_cond stat = Counter({PlanNodeCategory.JOIN.value: 1}) if isinstance(self.join_type, UsingJoin) and self.join_type.using_columns: @@ -204,12 +204,12 @@ def individual_complexity_stat(self) -> Counter[str]: {PlanNodeCategory.COLUMN.value: len(self.join_type.using_columns)} ) stat += ( - self.join_condition.cumulative_complexity_stat + self.join_condition.cumulative_node_complexity if self.join_condition else Counter() ) stat += ( - self.match_condition.cumulative_complexity_stat + self.match_condition.cumulative_node_complexity if self.match_condition else Counter() ) diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 6a19a978782..b674406dc7e 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -4,7 +4,6 @@ import copy import uuid -from functools import cached_property from typing import TYPE_CHECKING, AbstractSet, Any, List, Optional, Tuple import snowflake.snowpark._internal.utils @@ -63,6 +62,7 @@ def __init__(self, child: Optional["Expression"] = None) -> None: self.nullable = True self.children = [child] if child else None self.datatype: Optional[DataType] = None + self._cumulative_node_complexity: Optional[Counter] = None def dependent_column_names(self) -> Optional[AbstractSet[str]]: # TODO: consider adding it to __init__ or use cached_property. @@ -85,30 +85,41 @@ def sql(self) -> str: ) return f"{self.pretty_name}({children_sql})" + def __str__(self) -> str: + return self.pretty_name + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.OTHERS @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: """Returns the individual contribution of the expression node towards the overall compilation complexity of the generated sql. """ return Counter({self.plan_node_category.value: 1}) - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: - """Returns the aggregate sum complexity statistic from the subtree rooted at this - expression node. Statistic of current node is included in the final aggregate. 
- """ + def calculate_cumulative_node_complexity(self): children = self.children or [] return sum( - (child.cumulative_complexity_stat for child in children), - self.individual_complexity_stat, + (child.cumulative_node_complexity for child in children), + self.individual_node_complexity, ) - def __str__(self) -> str: - return self.pretty_name + @property + def cumulative_node_complexity(self) -> Counter[str]: + """Returns the aggregate sum complexity statistic from the subtree rooted at this + expression node. Statistic of current node is included in the final aggregate. + """ + if self._cumulative_node_complexity is None: + self._cumulative_node_complexity = ( + self.calculate_cumulative_node_complexity() + ) + return self._cumulative_node_complexity + + @cumulative_node_complexity.setter + def cumulative_node_complexity(self, value: Counter[str]): + self._cumulative_node_complexity = value class NamedExpression: @@ -135,9 +146,8 @@ def __init__(self, plan: "SnowflakePlan") -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: - return self.plan.cumulative_complexity_stat + def calculate_cumulative_node_complexity(self) -> Counter[str]: + return self.plan.cumulative_node_complexity class MultipleExpression(Expression): @@ -148,10 +158,9 @@ def __init__(self, expressions: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: return sum( - (expr.cumulative_complexity_stat for expr in self.expressions), + (expr.cumulative_node_complexity for expr in self.expressions), Counter(), ) @@ -169,13 +178,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.IN - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: return ( - self.columns.cumulative_complexity_stat - + self.individual_complexity_stat + self.columns.cumulative_node_complexity + + self.individual_node_complexity + sum( - (expr.cumulative_complexity_stat for expr in self.values), + (expr.cumulative_node_complexity for expr in self.values), Counter(), ) ) @@ -225,16 +233,15 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: if self.expressions: return Counter() # if there are no expressions, we assign column value = 1 to Star return Counter({PlanNodeCategory.COLUMN.value: 1}) - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: - return self.individual_complexity_stat + sum( - (child.individual_complexity_stat for child in self.expressions), + def calculate_cumulative_node_complexity(self) -> Counter[str]: + return self.individual_node_complexity + sum( + (child.individual_node_complexity for child in self.expressions), Counter(), ) @@ -370,12 +377,11 @@ def plan_node_category(self) -> PlanNodeCategory: # expr LIKE pattern return PlanNodeCategory.LOW_IMPACT - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: return ( - 
self.expr.cumulative_complexity_stat - + self.pattern.cumulative_complexity_stat - + self.individual_complexity_stat + self.expr.cumulative_node_complexity + + self.pattern.cumulative_node_complexity + + self.individual_node_complexity ) @@ -393,12 +399,11 @@ def plan_node_category(self) -> PlanNodeCategory: # expr REG_EXP pattern return PlanNodeCategory.LOW_IMPACT - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: return ( - self.expr.cumulative_complexity_stat - + self.pattern.cumulative_complexity_stat - + self.individual_complexity_stat + self.expr.cumulative_node_complexity + + self.pattern.cumulative_node_complexity + + self.individual_node_complexity ) @@ -416,9 +421,8 @@ def plan_node_category(self) -> PlanNodeCategory: # expr COLLATE collate_spec return PlanNodeCategory.LOW_IMPACT - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: - return self.expr.cumulative_complexity_stat + self.individual_complexity_stat + def calculate_cumulative_node_complexity(self) -> Counter[str]: + return self.expr.cumulative_node_complexity + self.individual_node_complexity class SubfieldString(Expression): @@ -435,10 +439,9 @@ def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field return PlanNodeCategory.LITERAL - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: # self.expr ( self.field ) - return self.expr.cumulative_complexity_stat + self.individual_complexity_stat + return self.expr.cumulative_node_complexity + self.individual_node_complexity class SubfieldInt(Expression): @@ -455,10 +458,9 @@ def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field return PlanNodeCategory.LITERAL - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: # self.expr ( self.field ) - return self.expr.cumulative_complexity_stat + self.individual_complexity_stat + return self.expr.cumulative_node_complexity + self.individual_node_complexity class FunctionExpression(Expression): @@ -512,15 +514,14 @@ def plan_node_category(self) -> PlanNodeCategory: # expr WITHIN GROUP (ORDER BY cols) return PlanNodeCategory.ORDER_BY - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: return ( sum( - (col.cumulative_complexity_stat for col in self.order_by_cols), + (col.cumulative_node_complexity for col in self.order_by_cols), Counter(), ) - + self.individual_complexity_stat - + self.expr.cumulative_complexity_stat + + self.individual_node_complexity + + self.expr.cumulative_node_complexity ) @@ -546,17 +547,16 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.CASE_WHEN - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: - stat = self.individual_complexity_stat + sum( + def calculate_cumulative_node_complexity(self) -> Counter[str]: + stat = self.individual_node_complexity + sum( ( - condition.cumulative_complexity_stat + value.cumulative_complexity_stat + condition.cumulative_node_complexity + value.cumulative_node_complexity for condition, value in self.branches ), Counter(), ) stat += ( - self.else_value.cumulative_complexity_stat if 
self.else_value else Counter() + self.else_value.cumulative_node_complexity if self.else_value else Counter() ) return stat @@ -584,11 +584,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: return sum( - (expr.cumulative_complexity_stat for expr in self.children), - self.individual_complexity_stat, + (expr.cumulative_node_complexity for expr in self.children), + self.individual_node_complexity, ) @@ -606,6 +605,5 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: - return self.col.cumulative_complexity_stat + self.individual_complexity_stat + def calculate_cumulative_node_complexity(self) -> Counter[str]: + return self.col.cumulative_node_complexity + self.individual_node_complexity diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index c3e4e7f026c..6b566e0ff0d 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -2,7 +2,6 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from functools import cached_property from typing import AbstractSet, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( @@ -46,14 +45,13 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: flattened_args = [exp for sublist in self.args for exp in sublist] return derive_dependent_columns(*flattened_args) - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: return sum( ( - sum((expr.cumulative_complexity_stat for expr in arg), Counter()) + sum((expr.cumulative_node_complexity for expr in arg), Counter()) for arg in self.args ), - self.individual_complexity_stat, + self.individual_node_complexity, ) @property diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 6606262c9e6..866ed20d9d8 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -203,7 +203,7 @@ def __init__( str, Dict[str, str] ] = defaultdict(dict) self._api_calls = api_calls.copy() if api_calls is not None else None - self._cumulative_complexity_stat: Optional[Counter[str]] = None + self._cumulative_node_complexity: Optional[Counter[str]] = None def __eq__(self, other: "Selectable") -> bool: if self._id is not None and other._id is not None: @@ -295,17 +295,17 @@ def num_duplicate_nodes(self) -> int: return self.snowflake_plan.num_duplicate_nodes @property - def cumulative_complexity_stat(self) -> Counter[str]: - if self._cumulative_complexity_stat is None: - stat = self.individual_complexity_stat + def cumulative_node_complexity(self) -> Counter[str]: + if self._cumulative_node_complexity is None: + stat = self.individual_node_complexity for node in self.children_plan_nodes: - stat += node.cumulative_complexity_stat - self._cumulative_complexity_stat = stat - return self._cumulative_complexity_stat + stat += node.cumulative_node_complexity + 
self._cumulative_node_complexity = stat + return self._cumulative_node_complexity - @cumulative_complexity_stat.setter - def cumulative_complexity_stat(self, value: Counter[str]): - self._cumulative_complexity_stat = value + @cumulative_node_complexity.setter + def cumulative_node_complexity(self, value: Counter[str]): + self._cumulative_node_complexity = value @property def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: @@ -492,8 +492,8 @@ def query_params(self) -> Optional[Sequence[Any]]: return self._query_params @property - def individual_complexity_stat(self) -> Counter[str]: - return self.snowflake_plan.individual_complexity_stat + def individual_node_complexity(self) -> Counter[str]: + return self.snowflake_plan.individual_node_complexity class SelectStatement(Selectable): @@ -689,7 +689,7 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return [self.from_] @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: stat = Counter() # projection component stat += ( @@ -697,7 +697,7 @@ def individual_complexity_stat(self) -> Counter[str]: ( getattr( expr, - "cumulative_complexity_stat", + "cumulative_node_complexity", Counter({PlanNodeCategory.COLUMN.value: 1}), ) # type: ignore for expr in self.projection @@ -711,7 +711,7 @@ def individual_complexity_stat(self) -> Counter[str]: # filter component - add +1 for WHERE clause and sum of expression complexity for where expression stat += ( Counter({PlanNodeCategory.FILTER.value: 1}) - + self.where.cumulative_complexity_stat + + self.where.cumulative_node_complexity if self.where else Counter() ) @@ -719,7 +719,7 @@ def individual_complexity_stat(self) -> Counter[str]: # order by component - add complexity for each sort expression stat += ( sum( - (expr.cumulative_complexity_stat for expr in self.order_by), + (expr.cumulative_node_complexity for expr in self.order_by), Counter({PlanNodeCategory.ORDER_BY.value: 1}), ) if self.order_by @@ -1042,8 +1042,8 @@ def query_params(self) -> Optional[Sequence[Any]]: return self.snowflake_plan.queries[-1].params @property - def individual_complexity_stat(self) -> Counter[str]: - return self.snowflake_plan.individual_complexity_stat + def individual_node_complexity(self) -> Counter[str]: + return self.snowflake_plan.individual_node_complexity class SetOperand: @@ -1125,7 +1125,7 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return self._nodes @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # we add #set_operands - 1 additional operators in sql query return Counter( {PlanNodeCategory.SET_OPERATION.value: len(self.set_operands) - 1} diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index a93f3d68e41..0023f9aec99 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -233,7 +233,7 @@ def __init__( self.placeholder_query = placeholder_query # encode an id for CTE optimization self._id = encode_id(queries[-1].sql, queries[-1].params) - self._cumulative_complexity_stat: Optional[Counter[str]] = None + self._cumulative_node_complexity: Optional[Counter[str]] = None def __eq__(self, other: "SnowflakePlan") -> bool: if self._id is not None and other._id is not None: @@ -353,23 +353,23 @@ def num_duplicate_nodes(self) -> int: 
return len(find_duplicate_subtrees(self)) @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: if self.source_plan: - return self.source_plan.individual_complexity_stat + return self.source_plan.individual_node_complexity return Counter() @property - def cumulative_complexity_stat(self) -> Counter[str]: - if self._cumulative_complexity_stat is None: - stat = self.individual_complexity_stat + def cumulative_node_complexity(self) -> Counter[str]: + if self._cumulative_node_complexity is None: + stat = self.individual_node_complexity for node in self.children_plan_nodes: - stat += node.cumulative_complexity_stat - self._cumulative_complexity_stat = stat - return self._cumulative_complexity_stat + stat += node.cumulative_node_complexity + self._cumulative_node_complexity = stat + return self._cumulative_node_complexity - @cumulative_complexity_stat.setter - def cumulative_complexity_stat(self, value: Counter[str]): - self._cumulative_complexity_stat = value + @cumulative_node_complexity.setter + def cumulative_node_complexity(self, value: Counter[str]): + self._cumulative_node_complexity = value def __copy__(self) -> "SnowflakePlan": if self.session._cte_optimization_enabled: diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 0fcc6a3d6d6..a1f99926e49 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -27,35 +27,35 @@ class LogicalPlan: def __init__(self) -> None: self.children = [] - self._cumulative_complexity_stat: Optional[Counter[str]] = None + self._cumulative_node_complexity: Optional[Counter[str]] = None @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.OTHERS @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: """Returns the individual contribution of the logical plan node towards the overall compilation complexity of the generated sql. """ return Counter({self.plan_node_category.value: 1}) @property - def cumulative_complexity_stat(self) -> Counter[str]: + def cumulative_node_complexity(self) -> Counter[str]: """Returns the aggregate sum complexity statistic from the subtree rooted at this logical plan node. Statistic of current node is included in the final aggregate. 
""" - if self._cumulative_complexity_stat is None: - stat = self.individual_complexity_stat + if self._cumulative_node_complexity is None: + stat = self.individual_node_complexity for node in self.children: - stat += node.cumulative_complexity_stat + stat += node.cumulative_node_complexity - self._cumulative_complexity_stat = stat - return self._cumulative_complexity_stat + self._cumulative_node_complexity = stat + return self._cumulative_node_complexity - @cumulative_complexity_stat.setter - def cumulative_complexity_stat(self, value: Counter[str]): - self._cumulative_complexity_stat = value + @cumulative_node_complexity.setter + def cumulative_node_complexity(self, value: Counter[str]): + self._cumulative_node_complexity = value class LeafNode(LogicalPlan): @@ -73,7 +73,7 @@ def __init__(self, start: int, end: int, step: int, num_slices: int = 1) -> None self.num_slices = num_slices @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # SELECT ( ROW_NUMBER() OVER ( ORDER BY SEQ8() ) - 1 ) * (step) + (start) AS id FROM ( TABLE (GENERATOR(ROWCOUNT => count))) return Counter( { @@ -92,7 +92,7 @@ def __init__(self, name: str) -> None: self.name = name @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # SELECT * FROM name return Counter({PlanNodeCategory.COLUMN.value: 1}) @@ -110,7 +110,7 @@ def __init__( self.schema_query = schema_query @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm) # TODO: use ARRAY_BIND_THRESHOLD return Counter( @@ -151,7 +151,7 @@ def __init__( self.comment = comment @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # CREATE OR REPLACE table_type TABLE table_name (col definition) clustering_expr AS SELECT * FROM (query) stat = Counter({PlanNodeCategory.COLUMN.value: 1}) stat += ( @@ -161,7 +161,7 @@ def individual_complexity_stat(self) -> Counter[str]: ) stat += ( sum( - (expr.cumulative_complexity_stat for expr in self.clustering_exprs), + (expr.cumulative_node_complexity for expr in self.clustering_exprs), Counter(), ) if self.clustering_exprs @@ -181,12 +181,12 @@ def __init__( self.children.append(child) @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # for limit and offset return ( Counter({PlanNodeCategory.LOW_IMPACT.value: 2}) - + self.limit_expr.cumulative_complexity_stat - + self.offset_expr.cumulative_complexity_stat + + self.limit_expr.cumulative_node_complexity + + self.offset_expr.cumulative_node_complexity ) diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index e766d1ed027..aeb20880530 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -2,7 +2,6 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# -from functools import cached_property from typing import AbstractSet, Optional, Type from snowflake.snowpark._internal.analyzer.expression import ( @@ -58,6 +57,5 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: - return self.child.cumulative_complexity_stat + self.individual_complexity_stat + def calculate_cumulative_node_complexity(self) -> Counter[str]: + return self.child.cumulative_node_complexity + self.individual_node_complexity diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index 24f47900f4b..19e34ae33bd 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -3,7 +3,6 @@ # import sys -from functools import cached_property from typing import Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import Expression @@ -35,14 +34,13 @@ def __init__( self.partition_spec = partition_spec self.order_spec = order_spec - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: if not self.over: return Counter() stat = Counter({PlanNodeCategory.WINDOW.value: 1}) stat += ( sum( - (expr.cumulative_complexity_stat for expr in self.partition_spec), + (expr.cumulative_node_complexity for expr in self.partition_spec), Counter({PlanNodeCategory.PARTITION_BY.value: 1}), ) if self.partition_spec @@ -50,7 +48,7 @@ def cumulative_complexity_stat(self) -> Counter[str]: ) stat += ( sum( - (expr.cumulative_complexity_stat for expr in self.order_spec), + (expr.cumulative_node_complexity for expr in self.order_spec), Counter({PlanNodeCategory.ORDER_BY.value: 1}), ) if self.order_spec @@ -89,9 +87,8 @@ def __init__( self.recursive = recursive self.mode = mode - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: - return self.individual_complexity_stat + self.input.cumulative_complexity_stat + def calculate_cumulative_node_complexity(self) -> Counter[str]: + return self.individual_node_complexity + self.input.cumulative_node_complexity class PosArgumentsTableFunction(TableFunctionExpression): @@ -104,14 +101,13 @@ def __init__( super().__init__(func_name, partition_spec) self.args = args - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: stat = sum( - (arg.cumulative_complexity_stat for arg in self.args), - self.individual_complexity_stat, + (arg.cumulative_node_complexity for arg in self.args), + self.individual_node_complexity, ) stat += ( - self.partition_spec.cumulative_complexity_stat + self.partition_spec.cumulative_node_complexity if self.partition_spec else Counter() ) @@ -128,14 +124,13 @@ def __init__( super().__init__(func_name, partition_spec) self.args = args - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: stat = sum( - (arg.cumulative_complexity_stat for arg in self.args.values()), - self.individual_complexity_stat, + (arg.cumulative_node_complexity for arg in self.args.values()), + self.individual_node_complexity, ) stat += ( - self.partition_spec.cumulative_complexity_stat + self.partition_spec.cumulative_node_complexity if self.partition_spec else Counter() ) @@ 
-148,14 +143,13 @@ def __init__(self, args: Dict[str, Expression], operators: List[str]) -> None: self.args = args self.operators = operators - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: stat = sum( - (arg.cumulative_complexity_stat for arg in self.args.values()), - self.individual_complexity_stat, + (arg.cumulative_node_complexity for arg in self.args.values()), + self.individual_node_complexity, ) stat += ( - self.partition_spec.cumulative_complexity_stat + self.partition_spec.cumulative_node_complexity if self.partition_spec else Counter() ) @@ -169,9 +163,9 @@ def __init__(self, table_function: TableFunctionExpression) -> None: self.table_function = table_function @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # SELECT * FROM table_function - return self.table_function.cumulative_complexity_stat + return self.table_function.cumulative_node_complexity class TableFunctionJoin(LogicalPlan): @@ -189,7 +183,7 @@ def __init__( self.right_cols = right_cols if right_cols is not None else ["*"] @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # SELECT left_cols, right_cols FROM child as left_alias JOIN table(func(...)) as right_alias return ( Counter( @@ -199,7 +193,7 @@ def individual_complexity_stat(self) -> Counter[str]: PlanNodeCategory.JOIN.value: 1, } ) - + self.table_function.cumulative_complexity_stat + + self.table_function.cumulative_node_complexity ) @@ -212,9 +206,9 @@ def __init__( self.table_function = table_function @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # SELECT * FROM (child), LATERAL table_func_expression return ( Counter({PlanNodeCategory.COLUMN.value: 1}) - + self.table_function.cumulative_complexity_stat + + self.table_function.cumulative_node_complexity ) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index 94adc5c0159..c362abdd0ad 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -2,7 +2,6 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# -from functools import cached_property from typing import Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import Expression @@ -25,12 +24,11 @@ def __init__(self, condition: Optional[Expression]) -> None: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: # WHEN MATCHED [AND condition] THEN DEL - stat = self.individual_complexity_stat + stat = self.individual_node_complexity stat += ( - self.condition.cumulative_complexity_stat if self.condition else Counter() + self.condition.cumulative_node_complexity if self.condition else Counter() ) return stat @@ -42,17 +40,16 @@ def __init__( super().__init__(condition) self.assignments = assignments - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: # WHEN MATCHED [AND condition] THEN UPDATE SET COMMA.join(k=v for k,v in assignments) - stat = self.individual_complexity_stat + stat = self.individual_node_complexity stat += ( - self.condition.cumulative_complexity_stat if self.condition else Counter() + self.condition.cumulative_node_complexity if self.condition else Counter() ) stat += sum( ( - key_expr.cumulative_complexity_stat - + val_expr.cumulative_complexity_stat + key_expr.cumulative_node_complexity + + val_expr.cumulative_node_complexity for key_expr, val_expr in self.assignments.items() ), Counter(), @@ -75,15 +72,14 @@ def __init__( self.keys = keys self.values = values - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: # WHEN NOT MATCHED [AND cond] THEN INSERT [(COMMA.join(key))] VALUES (COMMA.join(values)) - stat = self.individual_complexity_stat + stat = self.individual_node_complexity stat += ( - self.condition.cumulative_complexity_stat if self.condition else Counter() + self.condition.cumulative_node_complexity if self.condition else Counter() ) - stat += sum((key.cumulative_complexity_stat for key in self.keys), Counter()) - stat += sum((val.cumulative_complexity_stat for val in self.values), Counter()) + stat += sum((key.cumulative_node_complexity for key in self.keys), Counter()) + stat += sum((val.cumulative_node_complexity for val in self.values), Counter()) return stat @@ -103,17 +99,17 @@ def __init__( self.children = [source_data] if source_data else [] @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # UPDATE table_name SET COMMA.join(k, v in assignments) [source_data] [WHERE condition] stat = sum( ( - k.cumulative_complexity_stat + v.cumulative_complexity_stat + k.cumulative_node_complexity + v.cumulative_node_complexity for k, v in self.assignments.items() ), Counter(), ) stat += ( - self.condition.cumulative_complexity_stat if self.condition else Counter() + self.condition.cumulative_node_complexity if self.condition else Counter() ) return stat @@ -132,10 +128,10 @@ def __init__( self.children = [source_data] if source_data else [] @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # DELETE FROM table_name [USING source_data] [WHERE condition] return ( - self.condition.cumulative_complexity_stat if self.condition else Counter() + self.condition.cumulative_node_complexity if self.condition else Counter() 
) @@ -155,8 +151,8 @@ def __init__( self.children = [source] if source else [] @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # MERGE INTO table_name USING (source) ON join_expr clauses - return self.join_expr.cumulative_complexity_stat + sum( - (clause.cumulative_complexity_stat for clause in self.clauses), Counter() + return self.join_expr.cumulative_node_complexity + sum( + (clause.cumulative_node_complexity for clause in self.clauses), Counter() ) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index 23d0379e1b9..e6399df43ba 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -103,6 +103,6 @@ def __init__(self, child: Expression) -> None: self.name = child.sql @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # this is a wrapper around child return Counter() diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index fb0016d1918..47e7f40ea6d 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -38,7 +38,7 @@ def __init__( self.seed = seed @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # SELECT * FROM (child) SAMPLE (probability) -- if probability is provided # SELECT * FROM (child) SAMPLE (row_count ROWS) -- if not probability but row count is provided return Counter( @@ -56,10 +56,10 @@ def __init__(self, order: List[SortOrder], child: LogicalPlan) -> None: self.order = order @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # child ORDER BY COMMA.join(order) return Counter({PlanNodeCategory.ORDER_BY.value: 1}) + sum( - (col.cumulative_complexity_stat for col in self.order), Counter() + (col.cumulative_node_complexity for col in self.order), Counter() ) @@ -75,12 +75,12 @@ def __init__( self.aggregate_expressions = aggregate_expressions @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: stat = Counter() if self.grouping_expressions: # GROUP BY grouping_exprs stat += Counter({PlanNodeCategory.GROUP_BY.value: 1}) + sum( - (expr.cumulative_complexity_stat for expr in self.grouping_expressions), + (expr.cumulative_node_complexity for expr in self.grouping_expressions), Counter(), ) else: @@ -91,7 +91,7 @@ def individual_complexity_stat(self) -> Counter[str]: ( getattr( expr, - "cumulative_complexity_stat", + "cumulative_node_complexity", Counter({PlanNodeCategory.COLUMN.value: 1}), ) # type: ignore for expr in self.aggregate_expressions @@ -119,24 +119,24 @@ def __init__( self.default_on_null = default_on_null @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: stat = Counter() # child stat adjustment if grouping cols if self.grouping_columns and self.aggregates and self.aggregates[0].children: # for additional projecting cols when grouping cols is not empty stat += sum( - (col.cumulative_complexity_stat for col in self.grouping_columns), + (col.cumulative_node_complexity for col in self.grouping_columns), 
Counter(), ) - stat += self.pivot_column.cumulative_complexity_stat - stat += self.aggregates[0].children[0].cumulative_complexity_stat + stat += self.pivot_column.cumulative_node_complexity + stat += self.aggregates[0].children[0].cumulative_node_complexity # pivot col if isinstance(self.pivot_values, ScalarSubquery): - stat += self.pivot_values.cumulative_complexity_stat + stat += self.pivot_values.cumulative_node_complexity elif isinstance(self.pivot_values, List): stat += sum( - (val.cumulative_complexity_stat for val in self.pivot_values), Counter() + (val.cumulative_node_complexity for val in self.pivot_values), Counter() ) else: # if pivot values is None, then we add OTHERS for ANY @@ -144,7 +144,7 @@ def individual_complexity_stat(self) -> Counter[str]: # aggregate stat stat += sum( - (expr.cumulative_complexity_stat for expr in self.aggregates), Counter() + (expr.cumulative_node_complexity for expr in self.aggregates), Counter() ) # SELECT * FROM (child) PIVOT (aggregate FOR pivot_col in values) @@ -168,13 +168,13 @@ def __init__( self.column_list = column_list @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # SELECT * FROM (child) UNPIVOT (value_column FOR name_column IN (COMMA.join(column_list))) stat = Counter( {PlanNodeCategory.UNPIVOT.value: 1, PlanNodeCategory.COLUMN.value: 3} ) stat += sum( - (expr.cumulative_complexity_stat for expr in self.column_list), Counter() + (expr.cumulative_node_complexity for expr in self.column_list), Counter() ) return stat @@ -189,7 +189,7 @@ def __init__( self.column_map = column_map @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # SELECT * RENAME (before AS after, ...) FROM child return Counter( { @@ -205,11 +205,11 @@ def __init__(self, condition: Expression, child: LogicalPlan) -> None: self.condition = condition @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # child WHERE condition return ( Counter({PlanNodeCategory.FILTER.value: 1}) - + self.condition.cumulative_complexity_stat + + self.condition.cumulative_node_complexity ) @@ -219,7 +219,7 @@ def __init__(self, project_list: List[NamedExpression], child: LogicalPlan) -> N self.project_list = project_list @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: if not self.project_list: return Counter({PlanNodeCategory.COLUMN.value: 1}) @@ -227,7 +227,7 @@ def individual_complexity_stat(self) -> Counter[str]: ( getattr( col, - "cumulative_complexity_stat", + "cumulative_node_complexity", Counter({PlanNodeCategory.COLUMN.value: 1}), ) # type: ignore for col in self.project_list diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index d58a40eaaa0..a2051355fd7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -2,7 +2,6 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
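# The window_expression.py hunks below build one Counter per clause of an OVER
# expression and merge them with `+`. A toy, runnable illustration of that
# merge (the clause counts here are hypothetical, not taken from this diff):
from collections import Counter

frame = Counter({"low_impact": 1})                # frame_type BETWEEN lower AND upper
order_by = Counter({"order_by": 1, "column": 2})  # ORDER BY on two columns
partition_by = Counter({"partition_by": 1, "column": 1})

# Counter addition merges key-wise, which is essentially how a window spec
# accumulates the complexity of its partition, order and frame components
spec = frame + order_by + partition_by
assert spec == Counter(
    {"column": 3, "low_impact": 1, "order_by": 1, "partition_by": 1}
)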
# -from functools import cached_property from typing import AbstractSet, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( @@ -76,13 +75,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: # frame_type BETWEEN lower AND upper return ( - self.individual_complexity_stat - + self.lower.cumulative_complexity_stat - + self.upper.cumulative_complexity_stat + self.individual_node_complexity + + self.lower.cumulative_node_complexity + + self.upper.cumulative_node_complexity ) @@ -104,7 +102,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: ) @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: stat = Counter() stat += ( Counter({PlanNodeCategory.PARTITION_BY.value: 1}) @@ -118,19 +116,18 @@ def individual_complexity_stat(self) -> Counter[str]: ) return stat - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: # partition_spec order_by_spec frame_spec return ( - self.individual_complexity_stat + self.individual_node_complexity + sum( - (expr.cumulative_complexity_stat for expr in self.partition_spec), + (expr.cumulative_node_complexity for expr in self.partition_spec), Counter(), ) + sum( - (expr.cumulative_complexity_stat for expr in self.order_spec), Counter() + (expr.cumulative_node_complexity for expr in self.order_spec), Counter() ) - + self.frame_spec.cumulative_complexity_stat + + self.frame_spec.cumulative_node_complexity ) @@ -149,13 +146,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.WINDOW - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: # window_function OVER ( window_spec ) return ( - self.window_function.cumulative_complexity_stat - + self.window_spec.cumulative_complexity_stat - + self.individual_complexity_stat + self.window_function.cumulative_node_complexity + + self.window_spec.cumulative_node_complexity + + self.individual_node_complexity ) @@ -179,7 +175,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) @property - def individual_complexity_stat(self) -> Counter[str]: + def individual_node_complexity(self) -> Counter[str]: # for func_name stat = Counter({PlanNodeCategory.FUNCTION.value: 1}) # for offset @@ -194,11 +190,10 @@ def individual_complexity_stat(self) -> Counter[str]: ) return stat - @cached_property - def cumulative_complexity_stat(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Counter[str]: # func_name (expr [, offset] [, default]) [IGNORE NULLS] - stat = self.individual_complexity_stat + self.expr.cumulative_complexity_stat - stat += self.default.cumulative_complexity_stat if self.default else Counter() + stat = self.individual_node_complexity + self.expr.cumulative_node_complexity + stat += self.default.cumulative_node_complexity if self.default else Counter() return stat diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 48cc53db23f..d18ae0e6c55 100644 --- 
a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -68,7 +68,7 @@ class TelemetryField(Enum): # dataframe query stats QUERY_PLAN_HEIGHT = "query_plan_height" QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes" - QUERY_PLAN_STAT = "query_plan_stat" + QUERY_PLAN_COMPLEXITY = "query_plan_complexity" # These DataFrame APIs call other DataFrame APIs @@ -161,8 +161,8 @@ def wrap(*args, **kwargs): api_calls[0][ TelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value ] = plan.num_duplicate_nodes - api_calls[0][TelemetryField.QUERY_PLAN_STAT.value] = dict( - plan.cumulative_complexity_stat + api_calls[0][TelemetryField.QUERY_PLAN_COMPLEXITY.value] = dict( + plan.cumulative_node_complexity ) except Exception: pass diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py index e4c9b4d266f..dae7b5338a1 100644 --- a/tests/integ/test_query_plan_analysis.py +++ b/tests/integ/test_query_plan_analysis.py @@ -51,13 +51,13 @@ def sample_table(session): Utils.drop_table(session, table_name) -def get_cumulative_complexity_stat(df: DataFrame) -> Counter[str]: - return df._plan.cumulative_complexity_stat +def get_cumulative_node_complexity(df: DataFrame) -> Counter[str]: + return df._plan.cumulative_node_complexity def assert_df_subtree_query_complexity(df: DataFrame, estimate: Counter[str]): assert ( - get_cumulative_complexity_stat(df) == estimate + get_cumulative_node_complexity(df) == estimate ), f"query = {df.queries['queries'][-1]}" @@ -113,7 +113,7 @@ def test_generator_table_function(session: Session): # adds SELECT * from () ORDER BY seq ASC NULLS FIRST assert_df_subtree_query_complexity( df2, - get_cumulative_complexity_stat(df1) + get_cumulative_node_complexity(df1) + Counter( { PlanNodeCategory.ORDER_BY.value: 1, @@ -277,7 +277,7 @@ def test_window_function(session: Session): df2 = df1.select(avg("value").over(window2).as_("window2")) assert_df_subtree_query_complexity( df2, - get_cumulative_complexity_stat(df1) + get_cumulative_node_complexity(df1) + Counter( { PlanNodeCategory.ORDER_BY.value: 1, @@ -321,7 +321,7 @@ def test_join_statement(session: Session, sample_table: str): # SELECT * FROM ((ch1) AS SNOWPARK_LEFT INNER JOIN ( ch2) AS SNOWPARK_RIGHT ON (("l_k7b8_A" = "r_e09m_A") AND ("l_k7b8_B" = "r_e09m_B"))) assert_df_subtree_query_complexity( df4, - get_cumulative_complexity_stat(df3) + get_cumulative_node_complexity(df3) + Counter( {PlanNodeCategory.COLUMN.value: 4, PlanNodeCategory.LOW_IMPACT.value: 3} ), @@ -331,7 +331,7 @@ def test_join_statement(session: Session, sample_table: str): # SELECT * FROM ( (ch1) AS SNOWPARK_LEFT INNER JOIN (ch2) AS SNOWPARK_RIGHT USING (a, b)) assert_df_subtree_query_complexity( df5, - get_cumulative_complexity_stat(df3) + get_cumulative_node_complexity(df3) + Counter({PlanNodeCategory.COLUMN.value: 2}), ) @@ -453,7 +453,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df4 = df3.sort(col("b").asc()) assert_df_subtree_query_complexity( df4, - get_cumulative_complexity_stat(df3) + get_cumulative_node_complexity(df3) + Counter( { PlanNodeCategory.COLUMN.value: 1, @@ -467,7 +467,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df5 = df4.sort(col("c").desc()) assert_df_subtree_query_complexity( df5, - get_cumulative_complexity_stat(df4) + get_cumulative_node_complexity(df4) + Counter({PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.OTHERS.value: 1}), ) @@ -476,7 +476,7 @@ def 
test_select_statement_with_multiple_operations(session: Session, sample_tabl df6 = df5.filter(col("b") > 2) assert_df_subtree_query_complexity( df6, - get_cumulative_complexity_stat(df5) + get_cumulative_node_complexity(df5) + Counter( { PlanNodeCategory.FILTER.value: 1, @@ -491,7 +491,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df7 = df6.filter(col("c") > 3) assert_df_subtree_query_complexity( df7, - get_cumulative_complexity_stat(df6) + get_cumulative_node_complexity(df6) + Counter( { PlanNodeCategory.COLUMN.value: 1, @@ -506,7 +506,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl assert_df_subtree_query_complexity( df8, sum( - (get_cumulative_complexity_stat(df) for df in [df3, df4, df5]), + (get_cumulative_node_complexity(df) for df in [df3, df4, df5]), Counter({PlanNodeCategory.SET_OPERATION.value: 2}), ), ) @@ -516,7 +516,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl assert_df_subtree_query_complexity( df9, sum( - (get_cumulative_complexity_stat(df) for df in [df6, df7, df8]), + (get_cumulative_node_complexity(df) for df in [df6, df7, df8]), Counter({PlanNodeCategory.SET_OPERATION.value: 2}), ), ) @@ -525,7 +525,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df10 = df9.limit(2) assert_df_subtree_query_complexity( df10, - get_cumulative_complexity_stat(df9) + get_cumulative_node_complexity(df9) + Counter({PlanNodeCategory.LOW_IMPACT.value: 1}), ) @@ -533,6 +533,6 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df11 = df9.limit(3, offset=1) assert_df_subtree_query_complexity( df11, - get_cumulative_complexity_stat(df9) + get_cumulative_node_complexity(df9) + Counter({PlanNodeCategory.LOW_IMPACT.value: 2}), ) diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py index 71ce5084144..40ec07d9de7 100644 --- a/tests/integ/test_telemetry.py +++ b/tests/integ/test_telemetry.py @@ -593,7 +593,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_stat": { + "query_plan_complexity": { "filter": 1, "low_impact": 5, "column": 3, @@ -614,7 +614,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_stat": { + "query_plan_complexity": { "filter": 1, "low_impact": 5, "column": 3, @@ -635,7 +635,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_stat": { + "query_plan_complexity": { "filter": 1, "low_impact": 5, "column": 3, @@ -656,7 +656,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_stat": { + "query_plan_complexity": { "filter": 1, "low_impact": 5, "column": 3, @@ -677,7 +677,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, - "query_plan_stat": { + 
"query_plan_complexity": { "filter": 1, "low_impact": 5, "column": 3, @@ -821,7 +821,7 @@ def test_dataframe_stat_functions_api_calls(session): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": 4, "query_plan_num_duplicate_nodes": 0, - "query_plan_stat": {"group_by": 1, "column": 6, "literal": 48}, + "query_plan_complexity": {"group_by": 1, "column": 6, "literal": 48}, }, { "name": "DataFrameStatFunctions.crosstab", @@ -839,7 +839,7 @@ def test_dataframe_stat_functions_api_calls(session): "sql_simplifier_enabled": session.sql_simplifier_enabled, "query_plan_height": 4, "query_plan_num_duplicate_nodes": 0, - "query_plan_stat": {"group_by": 1, "column": 6, "literal": 48}, + "query_plan_complexity": {"group_by": 1, "column": 6, "literal": 48}, } ] diff --git a/tests/unit/test_query_plan_analysis.py b/tests/unit/test_query_plan_analysis.py index c8e92ca4da5..73603b06cd9 100644 --- a/tests/unit/test_query_plan_analysis.py +++ b/tests/unit/test_query_plan_analysis.py @@ -39,7 +39,7 @@ @pytest.mark.parametrize("node_type", [LogicalPlan, SnowflakePlan, Selectable]) -def test_assign_custom_cumulative_complexity_stat( +def test_assign_custom_cumulative_node_complexity( mock_session, mock_analyzer, mock_query, node_type ): def get_node_for_type(node_type): @@ -83,32 +83,32 @@ def set_children(node, node_type, children): set_children(nodes[5], node_type, []) set_children(nodes[6], node_type, []) - assert nodes[0].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 7} - assert nodes[1].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 5} - assert nodes[2].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 1} - assert nodes[3].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 1} - assert nodes[4].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 2} - assert nodes[5].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 1} - assert nodes[6].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 1} + assert nodes[0].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 7} + assert nodes[1].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 5} + assert nodes[2].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 1} + assert nodes[3].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 1} + assert nodes[4].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 2} + assert nodes[5].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 1} + assert nodes[6].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 1} - nodes[1].cumulative_complexity_stat = Counter({PlanNodeCategory.COLUMN.value: 1}) + nodes[1].cumulative_node_complexity = Counter({PlanNodeCategory.COLUMN.value: 1}) # assert that only value that is reset is changed - assert nodes[0].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 7} - assert nodes[1].cumulative_complexity_stat == {PlanNodeCategory.COLUMN.value: 1} - assert nodes[2].cumulative_complexity_stat == {PlanNodeCategory.OTHERS.value: 1} + assert nodes[0].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 7} + assert nodes[1].cumulative_node_complexity == {PlanNodeCategory.COLUMN.value: 1} + assert nodes[2].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 1} -def test_selectable_entity_individual_complexity_stat(mock_analyzer): +def test_selectable_entity_individual_node_complexity(mock_analyzer): plan_node = SelectableEntity(entity_name="dummy entity", analyzer=mock_analyzer) - assert 
plan_node.individual_complexity_stat == {PlanNodeCategory.COLUMN.value: 1} + assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN.value: 1} -def test_select_sql_individual_complexity_stat(mock_session, mock_analyzer): +def test_select_sql_individual_node_complexity(mock_session, mock_analyzer): plan_node = SelectSQL( "non-select statement", convert_to_select=True, analyzer=mock_analyzer ) - assert plan_node.individual_complexity_stat == {PlanNodeCategory.COLUMN.value: 1} + assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN.value: 1} def mocked_get_result_attributes(sql): return [Attribute("A", IntegerType()), Attribute("B", IntegerType())] @@ -123,12 +123,12 @@ def mocked_analyze( ): with mock.patch.object(mock_analyzer, "analyze", side_effect=mocked_analyze): plan_node = SelectSQL("select 1 as A, 2 as B", analyzer=mock_analyzer) - assert plan_node.individual_complexity_stat == { + assert plan_node.individual_node_complexity == { PlanNodeCategory.COLUMN.value: 2 } -def test_select_snowflake_plan_individual_complexity_stat( +def test_select_snowflake_plan_individual_node_complexity( mock_session, mock_analyzer, mock_query ): source_plan = Project([NamedExpression(), NamedExpression()], LogicalPlan()) @@ -136,7 +136,7 @@ def test_select_snowflake_plan_individual_complexity_stat( [mock_query], "", source_plan=source_plan, session=mock_session ) plan_node = SelectSnowflakePlan(snowflake_plan, analyzer=mock_analyzer) - assert plan_node.individual_complexity_stat == {PlanNodeCategory.COLUMN.value: 2} + assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN.value: 2} @pytest.mark.parametrize( @@ -158,7 +158,7 @@ def test_select_snowflake_plan_individual_complexity_stat( ("offset", 2, {PlanNodeCategory.LOW_IMPACT.value: 1}), ], ) -def test_select_statement_individual_complexity_stat( +def test_select_statement_individual_node_complexity( mock_analyzer, attribute, value, expected_stat ): from_ = mock.create_autospec(Selectable) @@ -169,10 +169,10 @@ def test_select_statement_individual_complexity_stat( plan_node = SelectStatement(from_=from_, analyzer=mock_analyzer) setattr(plan_node, attribute, value) - assert plan_node.individual_complexity_stat == expected_stat + assert plan_node.individual_node_complexity == expected_stat -def test_select_table_function_individual_complexity_stat( +def test_select_table_function_individual_node_complexity( mock_analyzer, mock_session, mock_query ): func_expr = mock.create_autospec(TableFunctionExpression) @@ -186,13 +186,13 @@ def mocked_resolve(*args, **kwargs): with mock.patch.object(mock_analyzer, "resolve", side_effect=mocked_resolve): plan_node = SelectTableFunction(func_expr, analyzer=mock_analyzer) - assert plan_node.individual_complexity_stat == { + assert plan_node.individual_node_complexity == { PlanNodeCategory.COLUMN.value: 2 } @pytest.mark.parametrize("set_operator", [UNION, UNION_ALL, INTERSECT, EXCEPT]) -def test_set_statement_individual_complexity_stat(mock_analyzer, set_operator): +def test_set_statement_individual_node_complexity(mock_analyzer, set_operator): mock_selectable = mock.create_autospec(Selectable) mock_selectable.pre_actions = None mock_selectable.post_actions = None @@ -204,6 +204,6 @@ def test_set_statement_individual_complexity_stat(mock_analyzer, set_operator): ] plan_node = SetStatement(*set_operands, analyzer=mock_analyzer) - assert plan_node.individual_complexity_stat == { + assert plan_node.individual_node_complexity == { PlanNodeCategory.SET_OPERATION.value: 1 } From 
a555e1ddacf1cc22c85a0f24a8dbe9a087485c08 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 12 Jun 2024 14:47:19 -0700 Subject: [PATCH 31/37] use Dict type hint instead of Counter --- .../_internal/analyzer/binary_plan_node.py | 34 +++-- .../snowpark/_internal/analyzer/expression.py | 138 +++++++++--------- .../_internal/analyzer/grouping_set.py | 14 +- .../analyzer/query_plan_analysis_utils.py | 23 ++- .../_internal/analyzer/select_statement.py | 78 +++++----- .../_internal/analyzer/snowflake_plan.py | 22 +-- .../_internal/analyzer/snowflake_plan_node.py | 89 ++++++----- .../_internal/analyzer/sort_expression.py | 12 +- .../_internal/analyzer/table_function.py | 112 +++++++------- .../analyzer/table_merge_expression.py | 89 ++++++----- .../_internal/analyzer/unary_expression.py | 7 +- .../_internal/analyzer/unary_plan_node.py | 127 ++++++++-------- .../_internal/analyzer/window_expression.py | 92 ++++++------ tests/integ/test_query_plan_analysis.py | 99 +++++++------ tests/unit/test_query_plan_analysis.py | 28 +--- 15 files changed, 490 insertions(+), 474 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index a3d256233c1..a5ad0f4ab10 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -2,12 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import List, Optional +from typing import Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - Counter, PlanNodeCategory, + add_node_complexities, ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages @@ -196,21 +196,29 @@ def sql(self) -> str: return self.join_type.sql @property - def individual_node_complexity(self) -> Counter[str]: + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.JOIN + + @property + def individual_node_complexity(self) -> Dict[str, int]: # SELECT * FROM (left) AS left_alias join_type_sql JOIN (right) AS right_alias match_cond, using_cond, join_cond - stat = Counter({PlanNodeCategory.JOIN.value: 1}) + score = {self.plan_node_category.value: 1} if isinstance(self.join_type, UsingJoin) and self.join_type.using_columns: - stat += Counter( - {PlanNodeCategory.COLUMN.value: len(self.join_type.using_columns)} + score = add_node_complexities( + score, + {PlanNodeCategory.COLUMN.value: len(self.join_type.using_columns)}, ) - stat += ( - self.join_condition.cumulative_node_complexity + score = ( + add_node_complexities(score, self.join_condition.cumulative_node_complexity) if self.join_condition - else Counter() + else score ) - stat += ( - self.match_condition.cumulative_node_complexity + + score = ( + add_node_complexities( + score, self.match_condition.cumulative_node_complexity + ) if self.match_condition - else Counter() + else score ) - return stat + return score diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index b674406dc7e..179bb59a9d3 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -4,12 +4,12 @@ import copy import uuid -from typing import TYPE_CHECKING, AbstractSet, 
Any, List, Optional, Tuple +from typing import TYPE_CHECKING, AbstractSet, Any, Dict, List, Optional, Tuple import snowflake.snowpark._internal.utils from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - Counter, PlanNodeCategory, + add_node_complexities, ) if TYPE_CHECKING: @@ -62,7 +62,7 @@ def __init__(self, child: Optional["Expression"] = None) -> None: self.nullable = True self.children = [child] if child else None self.datatype: Optional[DataType] = None - self._cumulative_node_complexity: Optional[Counter] = None + self._cumulative_node_complexity: Optional[Dict[str, int]] = None def dependent_column_names(self) -> Optional[AbstractSet[str]]: # TODO: consider adding it to __init__ or use cached_property. @@ -93,21 +93,21 @@ def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.OTHERS @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: """Returns the individual contribution of the expression node towards the overall compilation complexity of the generated sql. """ - return Counter({self.plan_node_category.value: 1}) + return {self.plan_node_category.value: 1} def calculate_cumulative_node_complexity(self): children = self.children or [] - return sum( - (child.cumulative_node_complexity for child in children), + return add_node_complexities( self.individual_node_complexity, + *(child.cumulative_node_complexity for child in children), ) @property - def cumulative_node_complexity(self) -> Counter[str]: + def cumulative_node_complexity(self) -> Dict[str, int]: """Returns the aggregate sum complexity statistic from the subtree rooted at this expression node. Statistic of current node is included in the final aggregate. """ @@ -118,7 +118,7 @@ def cumulative_node_complexity(self) -> Counter[str]: return self._cumulative_node_complexity @cumulative_node_complexity.setter - def cumulative_node_complexity(self, value: Counter[str]): + def cumulative_node_complexity(self, value: Dict[str, int]): self._cumulative_node_complexity = value @@ -146,7 +146,7 @@ def __init__(self, plan: "SnowflakePlan") -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR - def calculate_cumulative_node_complexity(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: return self.plan.cumulative_node_complexity @@ -158,10 +158,9 @@ def __init__(self, expressions: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) - def calculate_cumulative_node_complexity(self) -> Counter[str]: - return sum( - (expr.cumulative_node_complexity for expr in self.expressions), - Counter(), + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + return add_node_complexities( + *(expr.cumulative_node_complexity for expr in self.expressions), ) @@ -178,14 +177,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.IN - def calculate_cumulative_node_complexity(self) -> Counter[str]: - return ( - self.columns.cumulative_node_complexity - + self.individual_node_complexity - + sum( - (expr.cumulative_node_complexity for expr in self.values), - Counter(), - ) + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + return add_node_complexities( + self.individual_node_complexity, + self.columns.cumulative_node_complexity, + 
*(expr.cumulative_node_complexity for expr in self.values), ) @@ -233,16 +229,16 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: if self.expressions: - return Counter() + return {} # if there are no expressions, we assign column value = 1 to Star - return Counter({PlanNodeCategory.COLUMN.value: 1}) + return {PlanNodeCategory.COLUMN.value: 1} - def calculate_cumulative_node_complexity(self) -> Counter[str]: - return self.individual_node_complexity + sum( - (child.individual_node_complexity for child in self.expressions), - Counter(), + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + return add_node_complexities( + self.individual_node_complexity, + *(child.individual_node_complexity for child in self.expressions), ) @@ -377,11 +373,11 @@ def plan_node_category(self) -> PlanNodeCategory: # expr LIKE pattern return PlanNodeCategory.LOW_IMPACT - def calculate_cumulative_node_complexity(self) -> Counter[str]: - return ( - self.expr.cumulative_node_complexity - + self.pattern.cumulative_node_complexity - + self.individual_node_complexity + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + return add_node_complexities( + self.individual_node_complexity, + self.expr.cumulative_node_complexity, + self.pattern.cumulative_node_complexity, ) @@ -399,11 +395,11 @@ def plan_node_category(self) -> PlanNodeCategory: # expr REG_EXP pattern return PlanNodeCategory.LOW_IMPACT - def calculate_cumulative_node_complexity(self) -> Counter[str]: - return ( - self.expr.cumulative_node_complexity - + self.pattern.cumulative_node_complexity - + self.individual_node_complexity + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + return add_node_complexities( + self.individual_node_complexity, + self.expr.cumulative_node_complexity, + self.pattern.cumulative_node_complexity, ) @@ -421,8 +417,10 @@ def plan_node_category(self) -> PlanNodeCategory: # expr COLLATE collate_spec return PlanNodeCategory.LOW_IMPACT - def calculate_cumulative_node_complexity(self) -> Counter[str]: - return self.expr.cumulative_node_complexity + self.individual_node_complexity + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + return add_node_complexities( + self.expr.cumulative_node_complexity, self.individual_node_complexity + ) class SubfieldString(Expression): @@ -439,9 +437,11 @@ def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field return PlanNodeCategory.LITERAL - def calculate_cumulative_node_complexity(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # self.expr ( self.field ) - return self.expr.cumulative_node_complexity + self.individual_node_complexity + return add_node_complexities( + self.expr.cumulative_node_complexity, self.individual_node_complexity + ) class SubfieldInt(Expression): @@ -458,9 +458,11 @@ def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field return PlanNodeCategory.LITERAL - def calculate_cumulative_node_complexity(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # self.expr ( self.field ) - return self.expr.cumulative_node_complexity + self.individual_node_complexity + return add_node_complexities( + self.expr.cumulative_node_complexity, 
self.individual_node_complexity + ) class FunctionExpression(Expression): @@ -514,14 +516,11 @@ def plan_node_category(self) -> PlanNodeCategory: # expr WITHIN GROUP (ORDER BY cols) return PlanNodeCategory.ORDER_BY - def calculate_cumulative_node_complexity(self) -> Counter[str]: - return ( - sum( - (col.cumulative_node_complexity for col in self.order_by_cols), - Counter(), - ) - + self.individual_node_complexity - + self.expr.cumulative_node_complexity + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + return add_node_complexities( + self.individual_node_complexity, + self.expr.cumulative_node_complexity, + *(col.cumulative_node_complexity for col in self.order_by_cols), ) @@ -547,18 +546,23 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.CASE_WHEN - def calculate_cumulative_node_complexity(self) -> Counter[str]: - stat = self.individual_node_complexity + sum( - ( - condition.cumulative_node_complexity + value.cumulative_node_complexity + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + score = add_node_complexities( + self.individual_node_complexity, + *( + add_node_complexities( + condition.cumulative_node_complexity, + value.cumulative_node_complexity, + ) for condition, value in self.branches ), - Counter(), ) - stat += ( - self.else_value.cumulative_node_complexity if self.else_value else Counter() + score = ( + add_node_complexities(score, self.else_value.cumulative_node_complexity) + if self.else_value + else score ) - return stat + return score class SnowflakeUDF(Expression): @@ -584,10 +588,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION - def calculate_cumulative_node_complexity(self) -> Counter[str]: - return sum( - (expr.cumulative_node_complexity for expr in self.children), + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + return add_node_complexities( self.individual_node_complexity, + *(expr.cumulative_node_complexity for expr in self.children), ) @@ -605,5 +609,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION - def calculate_cumulative_node_complexity(self) -> Counter[str]: - return self.col.cumulative_node_complexity + self.individual_node_complexity + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + return add_node_complexities( + self.col.cumulative_node_complexity, self.individual_node_complexity + ) diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 6b566e0ff0d..3b306378135 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -2,15 +2,15 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
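# From this commit on, Counter arithmetic is replaced by a small helper,
# add_node_complexities, defined in query_plan_analysis_utils.py later in this
# patch. A self-contained sketch of its contract, mirroring that definition:
from collections import Counter
from typing import Dict


def add_node_complexities(*node_complexities: Dict[str, int]) -> Dict[str, int]:
    # key-wise sum of the per-category counts; a plain dict is returned so
    # public type hints no longer have to mention collections.Counter
    counter_sum = sum(
        (Counter(complexity) for complexity in node_complexities), Counter()
    )
    return dict(counter_sum)


# merging is associative and order-independent, so callers can fold any number
# of child aggregates into a single call:
assert add_node_complexities({"column": 2}, {"column": 1, "filter": 1}) == {
    "column": 3,
    "filter": 1,
}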
# -from typing import AbstractSet, List, Optional +from typing import AbstractSet, Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - Counter, PlanNodeCategory, + add_node_complexities, ) @@ -45,10 +45,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: flattened_args = [exp for sublist in self.args for exp in sublist] return derive_dependent_columns(*flattened_args) - def calculate_cumulative_node_complexity(self) -> Counter[str]: - return sum( - ( - sum((expr.cumulative_node_complexity for expr in arg), Counter()) + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + return add_node_complexities( + *( + add_node_complexities( + *(expr.cumulative_node_complexity for expr in arg) + ) for arg in self.args ), self.individual_node_complexity, diff --git a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py index 55fafb53ef7..2a62ea483ab 100644 --- a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py @@ -2,21 +2,9 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -import sys +from collections import Counter from enum import Enum - -# collections.Counter does not pass type checker. Changes with appropriate type hints were made in 3.9+ -if sys.version_info < (3, 9): - import collections - import typing - - KT = typing.TypeVar("KT") - - class Counter(collections.Counter, typing.Counter[KT]): - pass - -else: - from collections import Counter # noqa +from typing import Dict class PlanNodeCategory(Enum): @@ -47,3 +35,10 @@ class PlanNodeCategory(Enum): IN = "in" LOW_IMPACT = "low_impact" OTHERS = "others" + + +def add_node_complexities(*node_complexities: Dict[str, int]) -> Dict[str, int]: + counter_sum = sum( + (Counter(complexity) for complexity in node_complexities), Counter() + ) + return dict(counter_sum) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 866ed20d9d8..4906902887e 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -23,8 +23,8 @@ import snowflake.snowpark._internal.utils from snowflake.snowpark._internal.analyzer.cte_utils import encode_id from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - Counter, PlanNodeCategory, + add_node_complexities, ) from snowflake.snowpark._internal.analyzer.table_function import ( TableFunctionExpression, @@ -203,7 +203,7 @@ def __init__( str, Dict[str, str] ] = defaultdict(dict) self._api_calls = api_calls.copy() if api_calls is not None else None - self._cumulative_node_complexity: Optional[Counter[str]] = None + self._cumulative_node_complexity: Optional[Dict[str, int]] = None def __eq__(self, other: "Selectable") -> bool: if self._id is not None and other._id is not None: @@ -295,16 +295,16 @@ def num_duplicate_nodes(self) -> int: return self.snowflake_plan.num_duplicate_nodes @property - def cumulative_node_complexity(self) -> Counter[str]: + def cumulative_node_complexity(self) -> Dict[str, int]: if self._cumulative_node_complexity is None: - stat = self.individual_node_complexity - for node in self.children_plan_nodes: - stat += 
node.cumulative_node_complexity - self._cumulative_node_complexity = stat + self._cumulative_node_complexity = add_node_complexities( + self.individual_node_complexity, + *(node.cumulative_node_complexity for node in self.children_plan_nodes), + ) return self._cumulative_node_complexity @cumulative_node_complexity.setter - def cumulative_node_complexity(self, value: Counter[str]): + def cumulative_node_complexity(self, value: Dict[str, int]): self._cumulative_node_complexity = value @property @@ -492,7 +492,7 @@ def query_params(self) -> Optional[Sequence[Any]]: return self._query_params @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: return self.snowflake_plan.individual_node_complexity @@ -689,55 +689,59 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return [self.from_] @property - def individual_node_complexity(self) -> Counter[str]: - stat = Counter() + def individual_node_complexity(self) -> Dict[str, int]: + score = {} # projection component - stat += ( - sum( - ( + score = ( + add_node_complexities( + score, + *( getattr( expr, "cumulative_node_complexity", - Counter({PlanNodeCategory.COLUMN.value: 1}), + {PlanNodeCategory.COLUMN.value: 1}, ) # type: ignore for expr in self.projection ), - Counter(), ) if self.projection - else Counter() + else score ) # filter component - add +1 for WHERE clause and sum of expression complexity for where expression - stat += ( - Counter({PlanNodeCategory.FILTER.value: 1}) - + self.where.cumulative_node_complexity + score = ( + add_node_complexities( + score, + {PlanNodeCategory.FILTER.value: 1}, + self.where.cumulative_node_complexity, + ) if self.where - else Counter() + else score ) # order by component - add complexity for each sort expression - stat += ( - sum( - (expr.cumulative_node_complexity for expr in self.order_by), - Counter({PlanNodeCategory.ORDER_BY.value: 1}), + score = ( + add_node_complexities( + score, + *(expr.cumulative_node_complexity for expr in self.order_by), + {PlanNodeCategory.ORDER_BY.value: 1}, ) if self.order_by - else Counter() + else score ) # limit/offset component - stat += ( - Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) + score = ( + add_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) if self.limit_ - else Counter() + else score ) - stat += ( - Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) + score = ( + add_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) if self.offset - else Counter() + else score ) - return stat + return score def to_subqueryable(self) -> "Selectable": """When this SelectStatement's subquery is not subqueryable (can't be used in `from` clause of the sql), @@ -1042,7 +1046,7 @@ def query_params(self) -> Optional[Sequence[Any]]: return self.snowflake_plan.queries[-1].params @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: return self.snowflake_plan.individual_node_complexity @@ -1125,11 +1129,9 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return self._nodes @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # we add #set_operands - 1 additional operators in sql query - return Counter( - {PlanNodeCategory.SET_OPERATION.value: len(self.set_operands) - 1} - ) + return {PlanNodeCategory.SET_OPERATION.value: len(self.set_operands) - 1} class DeriveColumnDependencyError(Exception): 
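The SelectStatement/SetStatement changes above and the SnowflakePlan changes below converge on the same lazy-caching contract for cumulative_node_complexity. A condensed sketch of that shared shape, assembled from the surrounding hunks; the standalone class framing is illustrative only, since in the patch the property lives on Selectable, SnowflakePlan, and LogicalPlan directly, and the import reuses the helper added by this patch:

from typing import Dict, List, Optional

from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
    add_node_complexities,
)


class _PlanNodeSketch:
    def __init__(self, children: Optional[List["_PlanNodeSketch"]] = None) -> None:
        self.children_plan_nodes = children or []
        self._cumulative_node_complexity: Optional[Dict[str, int]] = None

    @property
    def individual_node_complexity(self) -> Dict[str, int]:
        return {}  # overridden per node type in the real classes

    @property
    def cumulative_node_complexity(self) -> Dict[str, int]:
        # computed once per node, bottom-up over the plan tree
        if self._cumulative_node_complexity is None:
            self._cumulative_node_complexity = add_node_complexities(
                self.individual_node_complexity,
                *(c.cumulative_node_complexity for c in self.children_plan_nodes),
            )
        return self._cumulative_node_complexity

    @cumulative_node_complexity.setter
    def cumulative_node_complexity(self, value: Dict[str, int]) -> None:
        # overwrites only this node's cached aggregate; parents that were
        # already computed keep their old totals, which is exactly what the
        # unit test earlier in this series asserts
        self._cumulative_node_complexity = value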
diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 0023f9aec99..34bae83277c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -21,7 +21,9 @@ Union, ) -from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import Counter +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + add_node_complexities, +) from snowflake.snowpark._internal.analyzer.table_function import ( GeneratorTableFunction, TableFunctionRelation, @@ -233,7 +235,7 @@ def __init__( self.placeholder_query = placeholder_query # encode an id for CTE optimization self._id = encode_id(queries[-1].sql, queries[-1].params) - self._cumulative_node_complexity: Optional[Counter[str]] = None + self._cumulative_node_complexity: Optional[Dict[str, int]] = None def __eq__(self, other: "SnowflakePlan") -> bool: if self._id is not None and other._id is not None: @@ -353,22 +355,22 @@ def num_duplicate_nodes(self) -> int: return len(find_duplicate_subtrees(self)) @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: if self.source_plan: return self.source_plan.individual_node_complexity - return Counter() + return {} @property - def cumulative_node_complexity(self) -> Counter[str]: + def cumulative_node_complexity(self) -> Dict[str, int]: if self._cumulative_node_complexity is None: - stat = self.individual_node_complexity - for node in self.children_plan_nodes: - stat += node.cumulative_node_complexity - self._cumulative_node_complexity = stat + self._cumulative_node_complexity = add_node_complexities( + self.individual_node_complexity, + *(node.cumulative_node_complexity for node in self.children_plan_nodes), + ) return self._cumulative_node_complexity @cumulative_node_complexity.setter - def cumulative_node_complexity(self, value: Counter[str]): + def cumulative_node_complexity(self, value: Dict[str, int]): self._cumulative_node_complexity = value def __copy__(self) -> "SnowflakePlan": diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index a1f99926e49..7b028564471 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -9,8 +9,8 @@ from snowflake.snowpark._internal.analyzer.expression import Attribute, Expression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - Counter, PlanNodeCategory, + add_node_complexities, ) from snowflake.snowpark.row import Row from snowflake.snowpark.types import StructType @@ -27,34 +27,33 @@ class LogicalPlan: def __init__(self) -> None: self.children = [] - self._cumulative_node_complexity: Optional[Counter[str]] = None + self._cumulative_node_complexity: Optional[Dict[str, int]] = None @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.OTHERS @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: """Returns the individual contribution of the logical plan node towards the overall compilation complexity of the generated sql. 
""" - return Counter({self.plan_node_category.value: 1}) + return {self.plan_node_category.value: 1} @property - def cumulative_node_complexity(self) -> Counter[str]: + def cumulative_node_complexity(self) -> Dict[str, int]: """Returns the aggregate sum complexity statistic from the subtree rooted at this logical plan node. Statistic of current node is included in the final aggregate. """ if self._cumulative_node_complexity is None: - stat = self.individual_node_complexity - for node in self.children: - stat += node.cumulative_node_complexity - - self._cumulative_node_complexity = stat + self._cumulative_node_complexity = add_node_complexities( + self.individual_node_complexity, + *(node.cumulative_node_complexity for node in self.children), + ) return self._cumulative_node_complexity @cumulative_node_complexity.setter - def cumulative_node_complexity(self, value: Counter[str]): + def cumulative_node_complexity(self, value: Dict[str, int]): self._cumulative_node_complexity = value @@ -73,17 +72,15 @@ def __init__(self, start: int, end: int, step: int, num_slices: int = 1) -> None self.num_slices = num_slices @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # SELECT ( ROW_NUMBER() OVER ( ORDER BY SEQ8() ) - 1 ) * (step) + (start) AS id FROM ( TABLE (GENERATOR(ROWCOUNT => count))) - return Counter( - { - PlanNodeCategory.WINDOW.value: 1, - PlanNodeCategory.ORDER_BY.value: 1, - PlanNodeCategory.LITERAL.value: 3, # step, start, count - PlanNodeCategory.COLUMN.value: 1, # id column - PlanNodeCategory.LOW_IMPACT.value: 2, # ROW_NUMBER, GENERATOR - } - ) + return { + PlanNodeCategory.WINDOW.value: 1, + PlanNodeCategory.ORDER_BY.value: 1, + PlanNodeCategory.LITERAL.value: 3, # step, start, count + PlanNodeCategory.COLUMN.value: 1, # id column + PlanNodeCategory.LOW_IMPACT.value: 2, # ROW_NUMBER, GENERATOR + } class UnresolvedRelation(LeafNode): @@ -92,9 +89,9 @@ def __init__(self, name: str) -> None: self.name = name @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # SELECT * FROM name - return Counter({PlanNodeCategory.COLUMN.value: 1}) + return {PlanNodeCategory.COLUMN.value: 1} class SnowflakeValues(LeafNode): @@ -110,15 +107,13 @@ def __init__( self.schema_query = schema_query @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm) # TODO: use ARRAY_BIND_THRESHOLD - return Counter( - { - PlanNodeCategory.COLUMN.value: len(self.output), - PlanNodeCategory.LITERAL.value: len(self.data) * len(self.output), - } - ) + return { + PlanNodeCategory.COLUMN.value: len(self.output), + PlanNodeCategory.LITERAL.value: len(self.data) * len(self.output), + } class SaveMode(Enum): @@ -151,23 +146,25 @@ def __init__( self.comment = comment @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # CREATE OR REPLACE table_type TABLE table_name (col definition) clustering_expr AS SELECT * FROM (query) - stat = Counter({PlanNodeCategory.COLUMN.value: 1}) - stat += ( - Counter({PlanNodeCategory.COLUMN.value: len(self.column_names)}) + score = {PlanNodeCategory.COLUMN.value: 1} + score = ( + add_node_complexities( + score, {PlanNodeCategory.COLUMN.value: len(self.column_names)} + ) if self.column_names - else Counter() + else score ) - stat += ( - sum( - 
(expr.cumulative_node_complexity for expr in self.clustering_exprs), - Counter(), + score = ( + add_node_complexities( + score, + *(expr.cumulative_node_complexity for expr in self.clustering_exprs), ) if self.clustering_exprs - else Counter() + else score ) - return stat + return score class Limit(LogicalPlan): @@ -181,12 +178,12 @@ def __init__( self.children.append(child) @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # for limit and offset - return ( - Counter({PlanNodeCategory.LOW_IMPACT.value: 2}) - + self.limit_expr.cumulative_node_complexity - + self.offset_expr.cumulative_node_complexity + return add_node_complexities( + {PlanNodeCategory.LOW_IMPACT.value: 2}, + self.limit_expr.cumulative_node_complexity, + self.offset_expr.cumulative_node_complexity, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index aeb20880530..fa7357ffeb4 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -2,13 +2,15 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional, Type +from typing import AbstractSet, Dict, Optional, Type from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, ) -from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import Counter +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + add_node_complexities, +) class NullOrdering: @@ -57,5 +59,7 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) - def calculate_cumulative_node_complexity(self) -> Counter[str]: - return self.child.cumulative_node_complexity + self.individual_node_complexity + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + return add_node_complexities( + self.child.cumulative_node_complexity, self.individual_node_complexity + ) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index 19e34ae33bd..14704fb65e3 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -7,8 +7,8 @@ from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - Counter, PlanNodeCategory, + add_node_complexities, ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -34,27 +34,29 @@ def __init__( self.partition_spec = partition_spec self.order_spec = order_spec - def calculate_cumulative_node_complexity(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: if not self.over: - return Counter() - stat = Counter({PlanNodeCategory.WINDOW.value: 1}) - stat += ( - sum( - (expr.cumulative_node_complexity for expr in self.partition_spec), - Counter({PlanNodeCategory.PARTITION_BY.value: 1}), + return {} + score = {PlanNodeCategory.WINDOW.value: 1} + score = ( + add_node_complexities( + score, + *(expr.cumulative_node_complexity for expr in self.partition_spec), + {PlanNodeCategory.PARTITION_BY.value: 1}, ) if self.partition_spec - else Counter() + else score 
) - stat += ( - sum( - (expr.cumulative_node_complexity for expr in self.order_spec), - Counter({PlanNodeCategory.ORDER_BY.value: 1}), + score = ( + add_node_complexities( + score, + *(expr.cumulative_node_complexity for expr in self.order_spec), + {PlanNodeCategory.ORDER_BY.value: 1}, ) if self.order_spec - else Counter() + else score ) - return stat + return score class TableFunctionExpression(Expression): @@ -87,8 +89,10 @@ def __init__( self.recursive = recursive self.mode = mode - def calculate_cumulative_node_complexity(self) -> Counter[str]: - return self.individual_node_complexity + self.input.cumulative_node_complexity + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + return add_node_complexities( + self.individual_node_complexity, self.input.cumulative_node_complexity + ) class PosArgumentsTableFunction(TableFunctionExpression): @@ -101,17 +105,17 @@ def __init__( super().__init__(func_name, partition_spec) self.args = args - def calculate_cumulative_node_complexity(self) -> Counter[str]: - stat = sum( - (arg.cumulative_node_complexity for arg in self.args), + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + score = add_node_complexities( + *(arg.cumulative_node_complexity for arg in self.args), self.individual_node_complexity, ) - stat += ( - self.partition_spec.cumulative_node_complexity + score = ( + add_node_complexities(score, self.partition_spec.cumulative_node_complexity) if self.partition_spec - else Counter() + else score ) - return stat + return score class NamedArgumentsTableFunction(TableFunctionExpression): @@ -124,17 +128,17 @@ def __init__( super().__init__(func_name, partition_spec) self.args = args - def calculate_cumulative_node_complexity(self) -> Counter[str]: - stat = sum( - (arg.cumulative_node_complexity for arg in self.args.values()), + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + score = add_node_complexities( + *(arg.cumulative_node_complexity for arg in self.args.values()), self.individual_node_complexity, ) - stat += ( - self.partition_spec.cumulative_node_complexity + score = ( + add_node_complexities(score, self.partition_spec.cumulative_node_complexity) if self.partition_spec - else Counter() + else score ) - return stat + return score class GeneratorTableFunction(TableFunctionExpression): @@ -143,18 +147,20 @@ def __init__(self, args: Dict[str, Expression], operators: List[str]) -> None: self.args = args self.operators = operators - def calculate_cumulative_node_complexity(self) -> Counter[str]: - stat = sum( - (arg.cumulative_node_complexity for arg in self.args.values()), + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + score = add_node_complexities( + *(arg.cumulative_node_complexity for arg in self.args.values()), self.individual_node_complexity, ) - stat += ( - self.partition_spec.cumulative_node_complexity + score = ( + add_node_complexities(score, self.partition_spec.cumulative_node_complexity) if self.partition_spec - else Counter() + else score + ) + score = add_node_complexities( + score, {PlanNodeCategory.COLUMN.value: len(self.operators)} ) - stat += Counter({PlanNodeCategory.COLUMN.value: len(self.operators)}) - return stat + return score class TableFunctionRelation(LogicalPlan): @@ -163,7 +169,7 @@ def __init__(self, table_function: TableFunctionExpression) -> None: self.table_function = table_function @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # SELECT * FROM 
table_function return self.table_function.cumulative_node_complexity @@ -183,17 +189,15 @@ def __init__( self.right_cols = right_cols if right_cols is not None else ["*"] @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # SELECT left_cols, right_cols FROM child as left_alias JOIN table(func(...)) as right_alias - return ( - Counter( - { - PlanNodeCategory.COLUMN.value: len(self.left_cols) - + len(self.right_cols), - PlanNodeCategory.JOIN.value: 1, - } - ) - + self.table_function.cumulative_node_complexity + return add_node_complexities( + { + PlanNodeCategory.COLUMN.value: len(self.left_cols) + + len(self.right_cols), + PlanNodeCategory.JOIN.value: 1, + }, + self.table_function.cumulative_node_complexity, ) @@ -206,9 +210,9 @@ def __init__( self.table_function = table_function @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # SELECT * FROM (child), LATERAL table_func_expression - return ( - Counter({PlanNodeCategory.COLUMN.value: 1}) - + self.table_function.cumulative_node_complexity + return add_node_complexities( + {PlanNodeCategory.COLUMN.value: 1}, + self.table_function.cumulative_node_complexity, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index c362abdd0ad..418e6e5d22f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -6,8 +6,8 @@ from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - Counter, PlanNodeCategory, + add_node_complexities, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import ( LogicalPlan, @@ -24,13 +24,15 @@ def __init__(self, condition: Optional[Expression]) -> None: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT - def calculate_cumulative_node_complexity(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # WHEN MATCHED [AND condition] THEN DEL - stat = self.individual_node_complexity - stat += ( - self.condition.cumulative_node_complexity if self.condition else Counter() + score = self.individual_node_complexity + score = ( + add_node_complexities(score, self.condition.cumulative_node_complexity) + if self.condition + else score ) - return stat + return score class UpdateMergeExpression(MergeExpression): @@ -40,21 +42,24 @@ def __init__( super().__init__(condition) self.assignments = assignments - def calculate_cumulative_node_complexity(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # WHEN MATCHED [AND condition] THEN UPDATE SET COMMA.join(k=v for k,v in assignments) - stat = self.individual_node_complexity - stat += ( - self.condition.cumulative_node_complexity if self.condition else Counter() - ) - stat += sum( - ( - key_expr.cumulative_node_complexity - + val_expr.cumulative_node_complexity + score = add_node_complexities( + self.individual_node_complexity, + *( + add_node_complexities( + key_expr.cumulative_node_complexity, + val_expr.cumulative_node_complexity, + ) for key_expr, val_expr in self.assignments.items() ), - Counter(), ) - return stat + score = ( + add_node_complexities(score, self.condition.cumulative_node_complexity) + if self.condition + else score + ) + 
return score class DeleteMergeExpression(MergeExpression): @@ -72,15 +77,19 @@ def __init__( self.keys = keys self.values = values - def calculate_cumulative_node_complexity(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # WHEN NOT MATCHED [AND cond] THEN INSERT [(COMMA.join(key))] VALUES (COMMA.join(values)) - stat = self.individual_node_complexity - stat += ( - self.condition.cumulative_node_complexity if self.condition else Counter() + score = add_node_complexities( + self.individual_node_complexity, + *(key.cumulative_node_complexity for key in self.keys), + *(val.cumulative_node_complexity for val in self.values), ) - stat += sum((key.cumulative_node_complexity for key in self.keys), Counter()) - stat += sum((val.cumulative_node_complexity for val in self.values), Counter()) - return stat + score = ( + add_node_complexities(score, self.condition.cumulative_node_complexity) + if self.condition + else score + ) + return score class TableUpdate(LogicalPlan): @@ -99,19 +108,22 @@ def __init__( self.children = [source_data] if source_data else [] @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # UPDATE table_name SET COMMA.join(k, v in assignments) [source_data] [WHERE condition] - stat = sum( - ( - k.cumulative_node_complexity + v.cumulative_node_complexity + score = add_node_complexities( + *( + add_node_complexities( + k.cumulative_node_complexity, v.cumulative_node_complexity + ) for k, v in self.assignments.items() ), - Counter(), ) - stat += ( - self.condition.cumulative_node_complexity if self.condition else Counter() + score = ( + add_node_complexities(score, self.condition.cumulative_node_complexity) + if self.condition + else score ) - return stat + return score class TableDelete(LogicalPlan): @@ -128,11 +140,9 @@ def __init__( self.children = [source_data] if source_data else [] @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # DELETE FROM table_name [USING source_data] [WHERE condition] - return ( - self.condition.cumulative_node_complexity if self.condition else Counter() - ) + return self.condition.cumulative_node_complexity if self.condition else {} class TableMerge(LogicalPlan): @@ -151,8 +161,9 @@ def __init__( self.children = [source] if source else [] @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # MERGE INTO table_name USING (source) ON join_expr clauses - return self.join_expr.cumulative_node_complexity + sum( - (clause.cumulative_node_complexity for clause in self.clauses), Counter() + return add_node_complexities( + self.join_expr.cumulative_node_complexity, + *(clause.cumulative_node_complexity for clause in self.clauses), ) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index e6399df43ba..afc176a0f41 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
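A recurring pattern in the merge-expression hunks above: optional clauses (an AND condition on a WHEN branch, an optional partition spec) contribute to the score only when present. Reduced to its core, with add_node_complexities as sketched earlier and a hypothetical optional clause:

from typing import Dict, Optional

def fold_optional_clause(
    score: Dict[str, int], clause: Optional[Dict[str, int]]
) -> Dict[str, int]:
    # Fold the optional clause's complexity into the running score;
    # when the clause is absent, the score passes through unchanged.
    return add_node_complexities(score, clause) if clause else score

assert fold_optional_clause({"filter": 1}, None) == {"filter": 1}
assert fold_optional_clause({"filter": 1}, {"column": 2}) == {"filter": 1, "column": 2}

Because the helper builds a fresh dict rather than mutating its inputs, each node's cached cumulative complexity stays independent of its parents'.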
# -from typing import AbstractSet, Optional +from typing import AbstractSet, Dict, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, @@ -10,7 +10,6 @@ derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - Counter, PlanNodeCategory, ) from snowflake.snowpark.types import DataType @@ -103,6 +102,6 @@ def __init__(self, child: Expression) -> None: self.name = child.sql @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # this is a wrapper around child - return Counter() + return {} diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 47e7f40ea6d..335dafb45f2 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -10,8 +10,8 @@ ScalarSubquery, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - Counter, PlanNodeCategory, + add_node_complexities, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -38,16 +38,14 @@ def __init__( self.seed = seed @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # SELECT * FROM (child) SAMPLE (probability) -- if probability is provided # SELECT * FROM (child) SAMPLE (row_count ROWS) -- if not probability but row count is provided - return Counter( - { - PlanNodeCategory.SAMPLE.value: 1, - PlanNodeCategory.LITERAL.value: 1, - PlanNodeCategory.COLUMN.value: 1, - } - ) + return { + PlanNodeCategory.SAMPLE.value: 1, + PlanNodeCategory.LITERAL.value: 1, + PlanNodeCategory.COLUMN.value: 1, + } class Sort(UnaryNode): @@ -56,10 +54,11 @@ def __init__(self, order: List[SortOrder], child: LogicalPlan) -> None: self.order = order @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # child ORDER BY COMMA.join(order) - return Counter({PlanNodeCategory.ORDER_BY.value: 1}) + sum( - (col.cumulative_node_complexity for col in self.order), Counter() + return add_node_complexities( + {PlanNodeCategory.ORDER_BY.value: 1}, + *(col.cumulative_node_complexity for col in self.order), ) @@ -75,30 +74,32 @@ def __init__( self.aggregate_expressions = aggregate_expressions @property - def individual_node_complexity(self) -> Counter[str]: - stat = Counter() + def individual_node_complexity(self) -> Dict[str, int]: if self.grouping_expressions: # GROUP BY grouping_exprs - stat += Counter({PlanNodeCategory.GROUP_BY.value: 1}) + sum( - (expr.cumulative_node_complexity for expr in self.grouping_expressions), - Counter(), + score = add_node_complexities( + {PlanNodeCategory.GROUP_BY.value: 1}, + *( + expr.cumulative_node_complexity + for expr in self.grouping_expressions + ), ) else: # LIMIT 1 - stat += Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) + score = {PlanNodeCategory.LOW_IMPACT.value: 1} - stat += sum( - ( + score = add_node_complexities( + score, + *( getattr( expr, "cumulative_node_complexity", - Counter({PlanNodeCategory.COLUMN.value: 1}), + {PlanNodeCategory.COLUMN.value: 1}, ) # type: ignore for expr in self.aggregate_expressions ), - Counter(), ) - return stat + return score class Pivot(UnaryNode): @@ -119,39 +120,41 @@ def __init__( self.default_on_null = 
default_on_null @property - def individual_node_complexity(self) -> Counter[str]: - stat = Counter() - # child stat adjustment if grouping cols + def individual_node_complexity(self) -> Dict[str, int]: + score = {} + # child score adjustment if grouping cols if self.grouping_columns and self.aggregates and self.aggregates[0].children: # for additional projecting cols when grouping cols is not empty - stat += sum( - (col.cumulative_node_complexity for col in self.grouping_columns), - Counter(), + score = add_node_complexities( + self.pivot_column.cumulative_node_complexity, + self.aggregates[0].children[0].cumulative_node_complexity, + *(col.cumulative_node_complexity for col in self.grouping_columns), ) - stat += self.pivot_column.cumulative_node_complexity - stat += self.aggregates[0].children[0].cumulative_node_complexity # pivot col if isinstance(self.pivot_values, ScalarSubquery): - stat += self.pivot_values.cumulative_node_complexity + score = add_node_complexities( + score, self.pivot_values.cumulative_node_complexity + ) elif isinstance(self.pivot_values, List): - stat += sum( - (val.cumulative_node_complexity for val in self.pivot_values), Counter() + score = add_node_complexities( + score, *(val.cumulative_node_complexity for val in self.pivot_values) ) else: # if pivot values is None, then we add OTHERS for ANY - stat += Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) + score = add_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) - # aggregate stat - stat += sum( - (expr.cumulative_node_complexity for expr in self.aggregates), Counter() + # aggregate score + score = add_node_complexities( + score, + *(expr.cumulative_node_complexity for expr in self.aggregates), ) # SELECT * FROM (child) PIVOT (aggregate FOR pivot_col in values) - stat += Counter( - {PlanNodeCategory.COLUMN.value: 2, PlanNodeCategory.PIVOT.value: 1} + score = add_node_complexities( + score, {PlanNodeCategory.COLUMN.value: 2, PlanNodeCategory.PIVOT.value: 1} ) - return stat + return score class Unpivot(UnaryNode): @@ -168,15 +171,12 @@ def __init__( self.column_list = column_list @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # SELECT * FROM (child) UNPIVOT (value_column FOR name_column IN (COMMA.join(column_list))) - stat = Counter( - {PlanNodeCategory.UNPIVOT.value: 1, PlanNodeCategory.COLUMN.value: 3} + return add_node_complexities( + {PlanNodeCategory.UNPIVOT.value: 1, PlanNodeCategory.COLUMN.value: 3}, + *(expr.cumulative_node_complexity for expr in self.column_list), ) - stat += sum( - (expr.cumulative_node_complexity for expr in self.column_list), Counter() - ) - return stat class Rename(UnaryNode): @@ -189,14 +189,12 @@ def __init__( self.column_map = column_map @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # SELECT * RENAME (before AS after, ...) 
FROM child - return Counter( - { - PlanNodeCategory.COLUMN.value: 1 + 2 * len(self.column_map), - PlanNodeCategory.LOW_IMPACT.value: 1 + len(self.column_map), - } - ) + return { + PlanNodeCategory.COLUMN.value: 1 + 2 * len(self.column_map), + PlanNodeCategory.LOW_IMPACT.value: 1 + len(self.column_map), + } class Filter(UnaryNode): @@ -205,11 +203,11 @@ def __init__(self, condition: Expression, child: LogicalPlan) -> None: self.condition = condition @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # child WHERE condition - return ( - Counter({PlanNodeCategory.FILTER.value: 1}) - + self.condition.cumulative_node_complexity + return add_node_complexities( + {PlanNodeCategory.FILTER.value: 1}, + self.condition.cumulative_node_complexity, ) @@ -219,20 +217,19 @@ def __init__(self, project_list: List[NamedExpression], child: LogicalPlan) -> N self.project_list = project_list @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: if not self.project_list: - return Counter({PlanNodeCategory.COLUMN.value: 1}) + return {PlanNodeCategory.COLUMN.value: 1} - return sum( - ( + return add_node_complexities( + *( getattr( col, "cumulative_node_complexity", - Counter({PlanNodeCategory.COLUMN.value: 1}), + {PlanNodeCategory.COLUMN.value: 1}, ) # type: ignore for col in self.project_list ), - Counter(), ) diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index a2051355fd7..a148632d1d5 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -2,15 +2,15 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
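The window-expression changes that follow rely on the same two-level contract used across the analyzer: individual_node_complexity scores one node in isolation, while calculate_cumulative_node_complexity folds in the children. A toy model of that contract (hypothetical class, not from this patch):

from typing import Dict, List, Optional

class ToyExpression:
    def __init__(
        self, category: str, children: Optional[List["ToyExpression"]] = None
    ) -> None:
        self.category = category
        self.children = children or []

    @property
    def individual_node_complexity(self) -> Dict[str, int]:
        # Complexity contributed by this node alone.
        return {self.category: 1}

    def calculate_cumulative_node_complexity(self) -> Dict[str, int]:
        # This node's own score plus everything beneath it.
        return add_node_complexities(
            self.individual_node_complexity,
            *(c.calculate_cumulative_node_complexity() for c in self.children),
        )

root = ToyExpression("function", [ToyExpression("column"), ToyExpression("literal")])
assert root.calculate_cumulative_node_complexity() == {
    "function": 1,
    "column": 1,
    "literal": 1,
}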
# -from typing import AbstractSet, List, Optional +from typing import AbstractSet, Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - Counter, PlanNodeCategory, + add_node_complexities, ) from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -75,12 +75,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT - def calculate_cumulative_node_complexity(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # frame_type BETWEEN lower AND upper - return ( - self.individual_node_complexity - + self.lower.cumulative_node_complexity - + self.upper.cumulative_node_complexity + return add_node_complexities( + self.individual_node_complexity, + self.lower.cumulative_node_complexity, + self.upper.cumulative_node_complexity, ) @@ -102,32 +102,27 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: ) @property - def individual_node_complexity(self) -> Counter[str]: - stat = Counter() - stat += ( - Counter({PlanNodeCategory.PARTITION_BY.value: 1}) + def individual_node_complexity(self) -> Dict[str, int]: + score = {} + score = ( + add_node_complexities(score, {PlanNodeCategory.PARTITION_BY.value: 1}) if self.partition_spec - else Counter() + else score ) - stat += ( - Counter({PlanNodeCategory.ORDER_BY.value: 1}) + score = ( + add_node_complexities(score, {PlanNodeCategory.ORDER_BY.value: 1}) if self.order_spec - else Counter() + else score ) - return stat + return score - def calculate_cumulative_node_complexity(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # partition_spec order_by_spec frame_spec - return ( - self.individual_node_complexity - + sum( - (expr.cumulative_node_complexity for expr in self.partition_spec), - Counter(), - ) - + sum( - (expr.cumulative_node_complexity for expr in self.order_spec), Counter() - ) - + self.frame_spec.cumulative_node_complexity + return add_node_complexities( + self.individual_node_complexity, + self.frame_spec.cumulative_node_complexity, + *(expr.cumulative_node_complexity for expr in self.partition_spec), + *(expr.cumulative_node_complexity for expr in self.order_spec), ) @@ -146,12 +141,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.WINDOW - def calculate_cumulative_node_complexity(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # window_function OVER ( window_spec ) - return ( - self.window_function.cumulative_node_complexity - + self.window_spec.cumulative_node_complexity - + self.individual_node_complexity + return add_node_complexities( + self.window_function.cumulative_node_complexity, + self.window_spec.cumulative_node_complexity, + self.individual_node_complexity, ) @@ -175,26 +170,35 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) @property - def individual_node_complexity(self) -> Counter[str]: + def individual_node_complexity(self) -> Dict[str, int]: # for func_name - stat = Counter({PlanNodeCategory.FUNCTION.value: 1}) + score = {PlanNodeCategory.FUNCTION.value: 1} # for offset - stat += ( - Counter({PlanNodeCategory.LITERAL.value: 1}) if self.offset else Counter() + score 
= ( + add_node_complexities(score, {PlanNodeCategory.LITERAL.value: 1}) + if self.offset + else score ) + # for ignore nulls - stat += ( - Counter({PlanNodeCategory.LOW_IMPACT.value: 1}) + score = ( + add_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) if self.ignore_nulls - else Counter() + else score ) - return stat + return score - def calculate_cumulative_node_complexity(self) -> Counter[str]: + def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # func_name (expr [, offset] [, default]) [IGNORE NULLS] - stat = self.individual_node_complexity + self.expr.cumulative_node_complexity - stat += self.default.cumulative_node_complexity if self.default else Counter() - return stat + score = add_node_complexities( + self.individual_node_complexity, self.expr.cumulative_node_complexity + ) + score = ( + add_node_complexities(score, self.default.cumulative_node_complexity) + if self.default + else score + ) + return score class Lag(RankRelatedFunctionExpression): diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py index dae7b5338a1..133ea5c5917 100644 --- a/tests/integ/test_query_plan_analysis.py +++ b/tests/integ/test_query_plan_analysis.py @@ -3,11 +3,13 @@ # +from typing import Dict + import pytest from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - Counter, PlanNodeCategory, + add_node_complexities, ) from snowflake.snowpark._internal.analyzer.select_statement import ( SET_EXCEPT, @@ -51,11 +53,11 @@ def sample_table(session): Utils.drop_table(session, table_name) -def get_cumulative_node_complexity(df: DataFrame) -> Counter[str]: +def get_cumulative_node_complexity(df: DataFrame) -> Dict[str, int]: return df._plan.cumulative_node_complexity -def assert_df_subtree_query_complexity(df: DataFrame, estimate: Counter[str]): +def assert_df_subtree_query_complexity(df: DataFrame, estimate: Dict[str, int]): assert ( get_cumulative_node_complexity(df) == estimate ), f"query = {df.queries['queries'][-1]}" @@ -113,13 +115,13 @@ def test_generator_table_function(session: Session): # adds SELECT * from () ORDER BY seq ASC NULLS FIRST assert_df_subtree_query_complexity( df2, - get_cumulative_node_complexity(df1) - + Counter( + add_node_complexities( + get_cumulative_node_complexity(df1), { PlanNodeCategory.ORDER_BY.value: 1, PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.OTHERS.value: 1, - } + }, ), ) @@ -258,18 +260,16 @@ def test_window_function(session: Session): # SELECT avg("VALUE") OVER (PARTITION BY "VALUE" ORDER BY "KEY" ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING ) AS "WINDOW1" FROM table_name assert_df_subtree_query_complexity( df1, - Counter( - { - PlanNodeCategory.PARTITION_BY.value: 1, - PlanNodeCategory.ORDER_BY.value: 1, - PlanNodeCategory.WINDOW.value: 1, - PlanNodeCategory.FUNCTION.value: 1, - PlanNodeCategory.COLUMN.value: 5, - PlanNodeCategory.LITERAL.value: 1, - PlanNodeCategory.LOW_IMPACT.value: 2, - PlanNodeCategory.OTHERS.value: 1, - } - ), + { + PlanNodeCategory.PARTITION_BY.value: 1, + PlanNodeCategory.ORDER_BY.value: 1, + PlanNodeCategory.WINDOW.value: 1, + PlanNodeCategory.FUNCTION.value: 1, + PlanNodeCategory.COLUMN.value: 5, + PlanNodeCategory.LITERAL.value: 1, + PlanNodeCategory.LOW_IMPACT.value: 2, + PlanNodeCategory.OTHERS.value: 1, + }, ) # SELECT avg("VALUE") OVER ( ORDER BY "KEY" DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS "WINDOW2" FROM ( @@ -277,8 +277,8 @@ def test_window_function(session: Session): df2 = 
df1.select(avg("value").over(window2).as_("window2")) assert_df_subtree_query_complexity( df2, - get_cumulative_node_complexity(df1) - + Counter( + add_node_complexities( + get_cumulative_node_complexity(df1), { PlanNodeCategory.ORDER_BY.value: 1, PlanNodeCategory.WINDOW.value: 1, @@ -286,7 +286,7 @@ def test_window_function(session: Session): PlanNodeCategory.COLUMN.value: 3, PlanNodeCategory.LOW_IMPACT.value: 3, PlanNodeCategory.OTHERS.value: 1, - } + }, ), ) finally: @@ -321,9 +321,9 @@ def test_join_statement(session: Session, sample_table: str): # SELECT * FROM ((ch1) AS SNOWPARK_LEFT INNER JOIN ( ch2) AS SNOWPARK_RIGHT ON (("l_k7b8_A" = "r_e09m_A") AND ("l_k7b8_B" = "r_e09m_B"))) assert_df_subtree_query_complexity( df4, - get_cumulative_node_complexity(df3) - + Counter( - {PlanNodeCategory.COLUMN.value: 4, PlanNodeCategory.LOW_IMPACT.value: 3} + add_node_complexities( + get_cumulative_node_complexity(df3), + {PlanNodeCategory.COLUMN.value: 4, PlanNodeCategory.LOW_IMPACT.value: 3}, ), ) @@ -331,8 +331,9 @@ def test_join_statement(session: Session, sample_table: str): # SELECT * FROM ( (ch1) AS SNOWPARK_LEFT INNER JOIN (ch2) AS SNOWPARK_RIGHT USING (a, b)) assert_df_subtree_query_complexity( df5, - get_cumulative_node_complexity(df3) - + Counter({PlanNodeCategory.COLUMN.value: 2}), + add_node_complexities( + get_cumulative_node_complexity(df3), {PlanNodeCategory.COLUMN.value: 2} + ), ) @@ -453,13 +454,13 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df4 = df3.sort(col("b").asc()) assert_df_subtree_query_complexity( df4, - get_cumulative_node_complexity(df3) - + Counter( + add_node_complexities( + get_cumulative_node_complexity(df3), { PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.ORDER_BY.value: 1, PlanNodeCategory.OTHERS.value: 1, - } + }, ), ) @@ -467,8 +468,10 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df5 = df4.sort(col("c").desc()) assert_df_subtree_query_complexity( df5, - get_cumulative_node_complexity(df4) - + Counter({PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.OTHERS.value: 1}), + add_node_complexities( + get_cumulative_node_complexity(df4), + {PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.OTHERS.value: 1}, + ), ) # add filter @@ -476,14 +479,14 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df6 = df5.filter(col("b") > 2) assert_df_subtree_query_complexity( df6, - get_cumulative_node_complexity(df5) - + Counter( + add_node_complexities( + get_cumulative_node_complexity(df5), { PlanNodeCategory.FILTER.value: 1, PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.LITERAL.value: 1, PlanNodeCategory.LOW_IMPACT.value: 1, - } + }, ), ) @@ -491,13 +494,13 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df7 = df6.filter(col("c") > 3) assert_df_subtree_query_complexity( df7, - get_cumulative_node_complexity(df6) - + Counter( + add_node_complexities( + get_cumulative_node_complexity(df6), { PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.LITERAL.value: 1, PlanNodeCategory.LOW_IMPACT.value: 2, - } + }, ), ) @@ -505,9 +508,9 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df8 = df3.union_all(df4).union_all(df5) assert_df_subtree_query_complexity( df8, - sum( - (get_cumulative_node_complexity(df) for df in [df3, df4, df5]), - Counter({PlanNodeCategory.SET_OPERATION.value: 2}), + add_node_complexities( + *(get_cumulative_node_complexity(df) for df in [df3, df4, df5]), + 
{PlanNodeCategory.SET_OPERATION.value: 2}, ), ) @@ -515,9 +518,9 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df9 = df8.union_all(df6).union_all(df7) assert_df_subtree_query_complexity( df9, - sum( - (get_cumulative_node_complexity(df) for df in [df6, df7, df8]), - Counter({PlanNodeCategory.SET_OPERATION.value: 2}), + add_node_complexities( + *(get_cumulative_node_complexity(df) for df in [df6, df7, df8]), + {PlanNodeCategory.SET_OPERATION.value: 2}, ), ) @@ -525,14 +528,16 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df10 = df9.limit(2) assert_df_subtree_query_complexity( df10, - get_cumulative_node_complexity(df9) - + Counter({PlanNodeCategory.LOW_IMPACT.value: 1}), + add_node_complexities( + get_cumulative_node_complexity(df9), {PlanNodeCategory.LOW_IMPACT.value: 1} + ), ) # for offset df11 = df9.limit(3, offset=1) assert_df_subtree_query_complexity( df11, - get_cumulative_node_complexity(df9) - + Counter({PlanNodeCategory.LOW_IMPACT.value: 2}), + add_node_complexities( + get_cumulative_node_complexity(df9), {PlanNodeCategory.LOW_IMPACT.value: 2} + ), ) diff --git a/tests/unit/test_query_plan_analysis.py b/tests/unit/test_query_plan_analysis.py index 73603b06cd9..cfd3e94c3c2 100644 --- a/tests/unit/test_query_plan_analysis.py +++ b/tests/unit/test_query_plan_analysis.py @@ -2,7 +2,6 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from collections import Counter from unittest import mock import pytest @@ -13,11 +12,7 @@ UNION, UNION_ALL, ) -from snowflake.snowpark._internal.analyzer.expression import ( - Attribute, - Expression, - NamedExpression, -) +from snowflake.snowpark._internal.analyzer.expression import Expression, NamedExpression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, ) @@ -35,7 +30,6 @@ from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.analyzer.table_function import TableFunctionExpression from snowflake.snowpark._internal.analyzer.unary_plan_node import Project -from snowflake.snowpark.types import IntegerType @pytest.mark.parametrize("node_type", [LogicalPlan, SnowflakePlan, Selectable]) @@ -91,7 +85,7 @@ def set_children(node, node_type, children): assert nodes[5].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 1} assert nodes[6].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 1} - nodes[1].cumulative_node_complexity = Counter({PlanNodeCategory.COLUMN.value: 1}) + nodes[1].cumulative_node_complexity = {PlanNodeCategory.COLUMN.value: 1} # assert that only value that is reset is changed assert nodes[0].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 7} @@ -110,22 +104,8 @@ def test_select_sql_individual_node_complexity(mock_session, mock_analyzer): ) assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN.value: 1} - def mocked_get_result_attributes(sql): - return [Attribute("A", IntegerType()), Attribute("B", IntegerType())] - - def mocked_analyze( - attr: Attribute, df_aliased_col_name_to_real_col_name, parse_local_name - ): - return attr.name - - with mock.patch.object( - mock_session, "_get_result_attributes", side_effect=mocked_get_result_attributes - ): - with mock.patch.object(mock_analyzer, "analyze", side_effect=mocked_analyze): - plan_node = SelectSQL("select 1 as A, 2 as B", analyzer=mock_analyzer) - assert plan_node.individual_node_complexity == { - PlanNodeCategory.COLUMN.value: 2 
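The integration-test updates here all follow one recipe: the expected complexity of a derived DataFrame is the parent plan's cumulative complexity plus a small per-operation delta. For instance, with a hypothetical df and the helpers defined at the top of this test file:

# LIMIT contributes one additional low-impact clause on top of df's subtree.
assert_df_subtree_query_complexity(
    df.limit(2),
    add_node_complexities(
        get_cumulative_node_complexity(df),
        {PlanNodeCategory.LOW_IMPACT.value: 1},
    ),
)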
- } + plan_node = SelectSQL("select 1 as A, 2 as B", analyzer=mock_analyzer) + assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN.value: 1} def test_select_snowflake_plan_individual_node_complexity( From 30088783d94d78a5b923897c3777f493e27767a1 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 12 Jun 2024 14:49:02 -0700 Subject: [PATCH 32/37] rename dict add function --- .../_internal/analyzer/binary_plan_node.py | 8 ++--- .../snowpark/_internal/analyzer/expression.py | 32 +++++++++---------- .../_internal/analyzer/grouping_set.py | 6 ++-- .../analyzer/query_plan_analysis_utils.py | 2 +- .../_internal/analyzer/select_statement.py | 14 ++++---- .../_internal/analyzer/snowflake_plan.py | 4 +-- .../_internal/analyzer/snowflake_plan_node.py | 10 +++--- .../_internal/analyzer/sort_expression.py | 4 +-- .../_internal/analyzer/table_function.py | 26 +++++++-------- .../analyzer/table_merge_expression.py | 22 ++++++------- .../_internal/analyzer/unary_plan_node.py | 26 +++++++-------- .../_internal/analyzer/window_expression.py | 20 ++++++------ tests/integ/test_query_plan_analysis.py | 26 +++++++-------- 13 files changed, 100 insertions(+), 100 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index a5ad0f4ab10..a25a99a7804 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -7,7 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - add_node_complexities, + sum_node_complexities, ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages @@ -204,18 +204,18 @@ def individual_node_complexity(self) -> Dict[str, int]: # SELECT * FROM (left) AS left_alias join_type_sql JOIN (right) AS right_alias match_cond, using_cond, join_cond score = {self.plan_node_category.value: 1} if isinstance(self.join_type, UsingJoin) and self.join_type.using_columns: - score = add_node_complexities( + score = sum_node_complexities( score, {PlanNodeCategory.COLUMN.value: len(self.join_type.using_columns)}, ) score = ( - add_node_complexities(score, self.join_condition.cumulative_node_complexity) + sum_node_complexities(score, self.join_condition.cumulative_node_complexity) if self.join_condition else score ) score = ( - add_node_complexities( + sum_node_complexities( score, self.match_condition.cumulative_node_complexity ) if self.match_condition diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 179bb59a9d3..9dc7fd657c4 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -9,7 +9,7 @@ import snowflake.snowpark._internal.utils from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - add_node_complexities, + sum_node_complexities, ) if TYPE_CHECKING: @@ -101,7 +101,7 @@ def individual_node_complexity(self) -> Dict[str, int]: def calculate_cumulative_node_complexity(self): children = self.children or [] - return add_node_complexities( + return sum_node_complexities( self.individual_node_complexity, *(child.cumulative_node_complexity for child in children), ) @@ -159,7 +159,7 @@ 
def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - return add_node_complexities( + return sum_node_complexities( *(expr.cumulative_node_complexity for expr in self.expressions), ) @@ -178,7 +178,7 @@ def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.IN def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - return add_node_complexities( + return sum_node_complexities( self.individual_node_complexity, self.columns.cumulative_node_complexity, *(expr.cumulative_node_complexity for expr in self.values), @@ -236,7 +236,7 @@ def individual_node_complexity(self) -> Dict[str, int]: return {PlanNodeCategory.COLUMN.value: 1} def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - return add_node_complexities( + return sum_node_complexities( self.individual_node_complexity, *(child.individual_node_complexity for child in self.expressions), ) @@ -374,7 +374,7 @@ def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - return add_node_complexities( + return sum_node_complexities( self.individual_node_complexity, self.expr.cumulative_node_complexity, self.pattern.cumulative_node_complexity, @@ -396,7 +396,7 @@ def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - return add_node_complexities( + return sum_node_complexities( self.individual_node_complexity, self.expr.cumulative_node_complexity, self.pattern.cumulative_node_complexity, @@ -418,7 +418,7 @@ def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - return add_node_complexities( + return sum_node_complexities( self.expr.cumulative_node_complexity, self.individual_node_complexity ) @@ -439,7 +439,7 @@ def plan_node_category(self) -> PlanNodeCategory: def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # self.expr ( self.field ) - return add_node_complexities( + return sum_node_complexities( self.expr.cumulative_node_complexity, self.individual_node_complexity ) @@ -460,7 +460,7 @@ def plan_node_category(self) -> PlanNodeCategory: def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # self.expr ( self.field ) - return add_node_complexities( + return sum_node_complexities( self.expr.cumulative_node_complexity, self.individual_node_complexity ) @@ -517,7 +517,7 @@ def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.ORDER_BY def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - return add_node_complexities( + return sum_node_complexities( self.individual_node_complexity, self.expr.cumulative_node_complexity, *(col.cumulative_node_complexity for col in self.order_by_cols), @@ -547,10 +547,10 @@ def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.CASE_WHEN def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - score = add_node_complexities( + score = sum_node_complexities( self.individual_node_complexity, *( - add_node_complexities( + sum_node_complexities( condition.cumulative_node_complexity, value.cumulative_node_complexity, ) @@ -558,7 +558,7 @@ def calculate_cumulative_node_complexity(self) -> Dict[str, int]: ), ) score = ( - add_node_complexities(score, 
self.else_value.cumulative_node_complexity) + sum_node_complexities(score, self.else_value.cumulative_node_complexity) if self.else_value else score ) @@ -589,7 +589,7 @@ def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - return add_node_complexities( + return sum_node_complexities( self.individual_node_complexity, *(expr.cumulative_node_complexity for expr in self.children), ) @@ -610,6 +610,6 @@ def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - return add_node_complexities( + return sum_node_complexities( self.col.cumulative_node_complexity, self.individual_node_complexity ) diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 3b306378135..91b170a5dbd 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -10,7 +10,7 @@ ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - add_node_complexities, + sum_node_complexities, ) @@ -46,9 +46,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*flattened_args) def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - return add_node_complexities( + return sum_node_complexities( *( - add_node_complexities( + sum_node_complexities( *(expr.cumulative_node_complexity for expr in arg) ) for arg in self.args diff --git a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py index 2a62ea483ab..69a39586c2b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py @@ -37,7 +37,7 @@ class PlanNodeCategory(Enum): OTHERS = "others" -def add_node_complexities(*node_complexities: Dict[str, int]) -> Dict[str, int]: +def sum_node_complexities(*node_complexities: Dict[str, int]) -> Dict[str, int]: counter_sum = sum( (Counter(complexity) for complexity in node_complexities), Counter() ) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 4906902887e..f4ad87f737c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -24,7 +24,7 @@ from snowflake.snowpark._internal.analyzer.cte_utils import encode_id from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - add_node_complexities, + sum_node_complexities, ) from snowflake.snowpark._internal.analyzer.table_function import ( TableFunctionExpression, @@ -297,7 +297,7 @@ def num_duplicate_nodes(self) -> int: @property def cumulative_node_complexity(self) -> Dict[str, int]: if self._cumulative_node_complexity is None: - self._cumulative_node_complexity = add_node_complexities( + self._cumulative_node_complexity = sum_node_complexities( self.individual_node_complexity, *(node.cumulative_node_complexity for node in self.children_plan_nodes), ) @@ -693,7 +693,7 @@ def individual_node_complexity(self) -> Dict[str, int]: score = {} # projection component score = ( - add_node_complexities( + sum_node_complexities( score, *( getattr( @@ -710,7 +710,7 @@ 
def individual_node_complexity(self) -> Dict[str, int]: # filter component - add +1 for WHERE clause and sum of expression complexity for where expression score = ( - add_node_complexities( + sum_node_complexities( score, {PlanNodeCategory.FILTER.value: 1}, self.where.cumulative_node_complexity, @@ -721,7 +721,7 @@ def individual_node_complexity(self) -> Dict[str, int]: # order by component - add complexity for each sort expression score = ( - add_node_complexities( + sum_node_complexities( score, *(expr.cumulative_node_complexity for expr in self.order_by), {PlanNodeCategory.ORDER_BY.value: 1}, @@ -732,12 +732,12 @@ def individual_node_complexity(self) -> Dict[str, int]: # limit/offset component score = ( - add_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) + sum_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) if self.limit_ else score ) score = ( - add_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) + sum_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) if self.offset else score ) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 34bae83277c..968d64e6208 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -22,7 +22,7 @@ ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - add_node_complexities, + sum_node_complexities, ) from snowflake.snowpark._internal.analyzer.table_function import ( GeneratorTableFunction, @@ -363,7 +363,7 @@ def individual_node_complexity(self) -> Dict[str, int]: @property def cumulative_node_complexity(self) -> Dict[str, int]: if self._cumulative_node_complexity is None: - self._cumulative_node_complexity = add_node_complexities( + self._cumulative_node_complexity = sum_node_complexities( self.individual_node_complexity, *(node.cumulative_node_complexity for node in self.children_plan_nodes), ) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 7b028564471..d52b0c9fe54 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -10,7 +10,7 @@ from snowflake.snowpark._internal.analyzer.expression import Attribute, Expression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - add_node_complexities, + sum_node_complexities, ) from snowflake.snowpark.row import Row from snowflake.snowpark.types import StructType @@ -46,7 +46,7 @@ def cumulative_node_complexity(self) -> Dict[str, int]: logical plan node. Statistic of current node is included in the final aggregate. 
""" if self._cumulative_node_complexity is None: - self._cumulative_node_complexity = add_node_complexities( + self._cumulative_node_complexity = sum_node_complexities( self.individual_node_complexity, *(node.cumulative_node_complexity for node in self.children), ) @@ -150,14 +150,14 @@ def individual_node_complexity(self) -> Dict[str, int]: # CREATE OR REPLACE table_type TABLE table_name (col definition) clustering_expr AS SELECT * FROM (query) score = {PlanNodeCategory.COLUMN.value: 1} score = ( - add_node_complexities( + sum_node_complexities( score, {PlanNodeCategory.COLUMN.value: len(self.column_names)} ) if self.column_names else score ) score = ( - add_node_complexities( + sum_node_complexities( score, *(expr.cumulative_node_complexity for expr in self.clustering_exprs), ) @@ -180,7 +180,7 @@ def __init__( @property def individual_node_complexity(self) -> Dict[str, int]: # for limit and offset - return add_node_complexities( + return sum_node_complexities( {PlanNodeCategory.LOW_IMPACT.value: 2}, self.limit_expr.cumulative_node_complexity, self.offset_expr.cumulative_node_complexity, diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index fa7357ffeb4..8e8989e5e1d 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -9,7 +9,7 @@ derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( - add_node_complexities, + sum_node_complexities, ) @@ -60,6 +60,6 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - return add_node_complexities( + return sum_node_complexities( self.child.cumulative_node_complexity, self.individual_node_complexity ) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index 14704fb65e3..4144eed4989 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -8,7 +8,7 @@ from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - add_node_complexities, + sum_node_complexities, ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -39,7 +39,7 @@ def calculate_cumulative_node_complexity(self) -> Dict[str, int]: return {} score = {PlanNodeCategory.WINDOW.value: 1} score = ( - add_node_complexities( + sum_node_complexities( score, *(expr.cumulative_node_complexity for expr in self.partition_spec), {PlanNodeCategory.PARTITION_BY.value: 1}, @@ -48,7 +48,7 @@ def calculate_cumulative_node_complexity(self) -> Dict[str, int]: else score ) score = ( - add_node_complexities( + sum_node_complexities( score, *(expr.cumulative_node_complexity for expr in self.order_spec), {PlanNodeCategory.ORDER_BY.value: 1}, @@ -90,7 +90,7 @@ def __init__( self.mode = mode def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - return add_node_complexities( + return sum_node_complexities( self.individual_node_complexity, self.input.cumulative_node_complexity ) @@ -106,12 +106,12 @@ def __init__( self.args = args def calculate_cumulative_node_complexity(self) -> 
Dict[str, int]: - score = add_node_complexities( + score = sum_node_complexities( *(arg.cumulative_node_complexity for arg in self.args), self.individual_node_complexity, ) score = ( - add_node_complexities(score, self.partition_spec.cumulative_node_complexity) + sum_node_complexities(score, self.partition_spec.cumulative_node_complexity) if self.partition_spec else score ) @@ -129,12 +129,12 @@ def __init__( self.args = args def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - score = add_node_complexities( + score = sum_node_complexities( *(arg.cumulative_node_complexity for arg in self.args.values()), self.individual_node_complexity, ) score = ( - add_node_complexities(score, self.partition_spec.cumulative_node_complexity) + sum_node_complexities(score, self.partition_spec.cumulative_node_complexity) if self.partition_spec else score ) @@ -148,16 +148,16 @@ def __init__(self, args: Dict[str, Expression], operators: List[str]) -> None: self.operators = operators def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - score = add_node_complexities( + score = sum_node_complexities( *(arg.cumulative_node_complexity for arg in self.args.values()), self.individual_node_complexity, ) score = ( - add_node_complexities(score, self.partition_spec.cumulative_node_complexity) + sum_node_complexities(score, self.partition_spec.cumulative_node_complexity) if self.partition_spec else score ) - score = add_node_complexities( + score = sum_node_complexities( score, {PlanNodeCategory.COLUMN.value: len(self.operators)} ) return score @@ -191,7 +191,7 @@ def __init__( @property def individual_node_complexity(self) -> Dict[str, int]: # SELECT left_cols, right_cols FROM child as left_alias JOIN table(func(...)) as right_alias - return add_node_complexities( + return sum_node_complexities( { PlanNodeCategory.COLUMN.value: len(self.left_cols) + len(self.right_cols), @@ -212,7 +212,7 @@ def __init__( @property def individual_node_complexity(self) -> Dict[str, int]: # SELECT * FROM (child), LATERAL table_func_expression - return add_node_complexities( + return sum_node_complexities( {PlanNodeCategory.COLUMN.value: 1}, self.table_function.cumulative_node_complexity, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index 418e6e5d22f..304b10044dc 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -7,7 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import Expression from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - add_node_complexities, + sum_node_complexities, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import ( LogicalPlan, @@ -28,7 +28,7 @@ def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # WHEN MATCHED [AND condition] THEN DEL score = self.individual_node_complexity score = ( - add_node_complexities(score, self.condition.cumulative_node_complexity) + sum_node_complexities(score, self.condition.cumulative_node_complexity) if self.condition else score ) @@ -44,10 +44,10 @@ def __init__( def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # WHEN MATCHED [AND condition] THEN UPDATE SET COMMA.join(k=v for k,v in assignments) - score = add_node_complexities( + score = sum_node_complexities( self.individual_node_complexity, *( - add_node_complexities( + sum_node_complexities( 
key_expr.cumulative_node_complexity, val_expr.cumulative_node_complexity, ) @@ -55,7 +55,7 @@ def calculate_cumulative_node_complexity(self) -> Dict[str, int]: ), ) score = ( - add_node_complexities(score, self.condition.cumulative_node_complexity) + sum_node_complexities(score, self.condition.cumulative_node_complexity) if self.condition else score ) @@ -79,13 +79,13 @@ def __init__( def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # WHEN NOT MATCHED [AND cond] THEN INSERT [(COMMA.join(key))] VALUES (COMMA.join(values)) - score = add_node_complexities( + score = sum_node_complexities( self.individual_node_complexity, *(key.cumulative_node_complexity for key in self.keys), *(val.cumulative_node_complexity for val in self.values), ) score = ( - add_node_complexities(score, self.condition.cumulative_node_complexity) + sum_node_complexities(score, self.condition.cumulative_node_complexity) if self.condition else score ) @@ -110,16 +110,16 @@ def __init__( @property def individual_node_complexity(self) -> Dict[str, int]: # UPDATE table_name SET COMMA.join(k, v in assignments) [source_data] [WHERE condition] - score = add_node_complexities( + score = sum_node_complexities( *( - add_node_complexities( + sum_node_complexities( k.cumulative_node_complexity, v.cumulative_node_complexity ) for k, v in self.assignments.items() ), ) score = ( - add_node_complexities(score, self.condition.cumulative_node_complexity) + sum_node_complexities(score, self.condition.cumulative_node_complexity) if self.condition else score ) @@ -163,7 +163,7 @@ def __init__( @property def individual_node_complexity(self) -> Dict[str, int]: # MERGE INTO table_name USING (source) ON join_expr clauses - return add_node_complexities( + return sum_node_complexities( self.join_expr.cumulative_node_complexity, *(clause.cumulative_node_complexity for clause in self.clauses), ) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 335dafb45f2..7dcecd86475 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -11,7 +11,7 @@ ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - add_node_complexities, + sum_node_complexities, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -56,7 +56,7 @@ def __init__(self, order: List[SortOrder], child: LogicalPlan) -> None: @property def individual_node_complexity(self) -> Dict[str, int]: # child ORDER BY COMMA.join(order) - return add_node_complexities( + return sum_node_complexities( {PlanNodeCategory.ORDER_BY.value: 1}, *(col.cumulative_node_complexity for col in self.order), ) @@ -77,7 +77,7 @@ def __init__( def individual_node_complexity(self) -> Dict[str, int]: if self.grouping_expressions: # GROUP BY grouping_exprs - score = add_node_complexities( + score = sum_node_complexities( {PlanNodeCategory.GROUP_BY.value: 1}, *( expr.cumulative_node_complexity @@ -88,7 +88,7 @@ def individual_node_complexity(self) -> Dict[str, int]: # LIMIT 1 score = {PlanNodeCategory.LOW_IMPACT.value: 1} - score = add_node_complexities( + score = sum_node_complexities( score, *( getattr( @@ -125,7 +125,7 @@ def individual_node_complexity(self) -> Dict[str, int]: # child score adjustment if grouping cols if self.grouping_columns and self.aggregates and 
self.aggregates[0].children: # for additional projecting cols when grouping cols is not empty - score = add_node_complexities( + score = sum_node_complexities( self.pivot_column.cumulative_node_complexity, self.aggregates[0].children[0].cumulative_node_complexity, *(col.cumulative_node_complexity for col in self.grouping_columns), @@ -133,25 +133,25 @@ def individual_node_complexity(self) -> Dict[str, int]: # pivot col if isinstance(self.pivot_values, ScalarSubquery): - score = add_node_complexities( + score = sum_node_complexities( score, self.pivot_values.cumulative_node_complexity ) elif isinstance(self.pivot_values, List): - score = add_node_complexities( + score = sum_node_complexities( score, *(val.cumulative_node_complexity for val in self.pivot_values) ) else: # if pivot values is None, then we add OTHERS for ANY - score = add_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) + score = sum_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) # aggregate score - score = add_node_complexities( + score = sum_node_complexities( score, *(expr.cumulative_node_complexity for expr in self.aggregates), ) # SELECT * FROM (child) PIVOT (aggregate FOR pivot_col in values) - score = add_node_complexities( + score = sum_node_complexities( score, {PlanNodeCategory.COLUMN.value: 2, PlanNodeCategory.PIVOT.value: 1} ) return score @@ -173,7 +173,7 @@ def __init__( @property def individual_node_complexity(self) -> Dict[str, int]: # SELECT * FROM (child) UNPIVOT (value_column FOR name_column IN (COMMA.join(column_list))) - return add_node_complexities( + return sum_node_complexities( {PlanNodeCategory.UNPIVOT.value: 1, PlanNodeCategory.COLUMN.value: 3}, *(expr.cumulative_node_complexity for expr in self.column_list), ) @@ -205,7 +205,7 @@ def __init__(self, condition: Expression, child: LogicalPlan) -> None: @property def individual_node_complexity(self) -> Dict[str, int]: # child WHERE condition - return add_node_complexities( + return sum_node_complexities( {PlanNodeCategory.FILTER.value: 1}, self.condition.cumulative_node_complexity, ) @@ -221,7 +221,7 @@ def individual_node_complexity(self) -> Dict[str, int]: if not self.project_list: return {PlanNodeCategory.COLUMN.value: 1} - return add_node_complexities( + return sum_node_complexities( *( getattr( col, diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index a148632d1d5..9fdf40ff9ed 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -10,7 +10,7 @@ ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - add_node_complexities, + sum_node_complexities, ) from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -77,7 +77,7 @@ def plan_node_category(self) -> PlanNodeCategory: def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # frame_type BETWEEN lower AND upper - return add_node_complexities( + return sum_node_complexities( self.individual_node_complexity, self.lower.cumulative_node_complexity, self.upper.cumulative_node_complexity, @@ -105,12 +105,12 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def individual_node_complexity(self) -> Dict[str, int]: score = {} score = ( - add_node_complexities(score, {PlanNodeCategory.PARTITION_BY.value: 1}) + sum_node_complexities(score, {PlanNodeCategory.PARTITION_BY.value: 1}) if self.partition_spec 
else score ) score = ( - add_node_complexities(score, {PlanNodeCategory.ORDER_BY.value: 1}) + sum_node_complexities(score, {PlanNodeCategory.ORDER_BY.value: 1}) if self.order_spec else score ) @@ -118,7 +118,7 @@ def individual_node_complexity(self) -> Dict[str, int]: def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # partition_spec order_by_spec frame_spec - return add_node_complexities( + return sum_node_complexities( self.individual_node_complexity, self.frame_spec.cumulative_node_complexity, *(expr.cumulative_node_complexity for expr in self.partition_spec), @@ -143,7 +143,7 @@ def plan_node_category(self) -> PlanNodeCategory: def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # window_function OVER ( window_spec ) - return add_node_complexities( + return sum_node_complexities( self.window_function.cumulative_node_complexity, self.window_spec.cumulative_node_complexity, self.individual_node_complexity, @@ -175,14 +175,14 @@ def individual_node_complexity(self) -> Dict[str, int]: score = {PlanNodeCategory.FUNCTION.value: 1} # for offset score = ( - add_node_complexities(score, {PlanNodeCategory.LITERAL.value: 1}) + sum_node_complexities(score, {PlanNodeCategory.LITERAL.value: 1}) if self.offset else score ) # for ignore nulls score = ( - add_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) + sum_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) if self.ignore_nulls else score ) @@ -190,11 +190,11 @@ def individual_node_complexity(self) -> Dict[str, int]: def calculate_cumulative_node_complexity(self) -> Dict[str, int]: # func_name (expr [, offset] [, default]) [IGNORE NULLS] - score = add_node_complexities( + score = sum_node_complexities( self.individual_node_complexity, self.expr.cumulative_node_complexity ) score = ( - add_node_complexities(score, self.default.cumulative_node_complexity) + sum_node_complexities(score, self.default.cumulative_node_complexity) if self.default else score ) diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py index 133ea5c5917..29cf7623049 100644 --- a/tests/integ/test_query_plan_analysis.py +++ b/tests/integ/test_query_plan_analysis.py @@ -9,7 +9,7 @@ from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, - add_node_complexities, + sum_node_complexities, ) from snowflake.snowpark._internal.analyzer.select_statement import ( SET_EXCEPT, @@ -115,7 +115,7 @@ def test_generator_table_function(session: Session): # adds SELECT * from () ORDER BY seq ASC NULLS FIRST assert_df_subtree_query_complexity( df2, - add_node_complexities( + sum_node_complexities( get_cumulative_node_complexity(df1), { PlanNodeCategory.ORDER_BY.value: 1, @@ -277,7 +277,7 @@ def test_window_function(session: Session): df2 = df1.select(avg("value").over(window2).as_("window2")) assert_df_subtree_query_complexity( df2, - add_node_complexities( + sum_node_complexities( get_cumulative_node_complexity(df1), { PlanNodeCategory.ORDER_BY.value: 1, @@ -321,7 +321,7 @@ def test_join_statement(session: Session, sample_table: str): # SELECT * FROM ((ch1) AS SNOWPARK_LEFT INNER JOIN ( ch2) AS SNOWPARK_RIGHT ON (("l_k7b8_A" = "r_e09m_A") AND ("l_k7b8_B" = "r_e09m_B"))) assert_df_subtree_query_complexity( df4, - add_node_complexities( + sum_node_complexities( get_cumulative_node_complexity(df3), {PlanNodeCategory.COLUMN.value: 4, PlanNodeCategory.LOW_IMPACT.value: 3}, ), @@ -331,7 +331,7 @@ def test_join_statement(session: Session, sample_table: str): # 
SELECT * FROM ( (ch1) AS SNOWPARK_LEFT INNER JOIN (ch2) AS SNOWPARK_RIGHT USING (a, b)) assert_df_subtree_query_complexity( df5, - add_node_complexities( + sum_node_complexities( get_cumulative_node_complexity(df3), {PlanNodeCategory.COLUMN.value: 2} ), ) @@ -454,7 +454,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df4 = df3.sort(col("b").asc()) assert_df_subtree_query_complexity( df4, - add_node_complexities( + sum_node_complexities( get_cumulative_node_complexity(df3), { PlanNodeCategory.COLUMN.value: 1, @@ -468,7 +468,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df5 = df4.sort(col("c").desc()) assert_df_subtree_query_complexity( df5, - add_node_complexities( + sum_node_complexities( get_cumulative_node_complexity(df4), {PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.OTHERS.value: 1}, ), @@ -479,7 +479,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df6 = df5.filter(col("b") > 2) assert_df_subtree_query_complexity( df6, - add_node_complexities( + sum_node_complexities( get_cumulative_node_complexity(df5), { PlanNodeCategory.FILTER.value: 1, @@ -494,7 +494,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df7 = df6.filter(col("c") > 3) assert_df_subtree_query_complexity( df7, - add_node_complexities( + sum_node_complexities( get_cumulative_node_complexity(df6), { PlanNodeCategory.COLUMN.value: 1, @@ -508,7 +508,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df8 = df3.union_all(df4).union_all(df5) assert_df_subtree_query_complexity( df8, - add_node_complexities( + sum_node_complexities( *(get_cumulative_node_complexity(df) for df in [df3, df4, df5]), {PlanNodeCategory.SET_OPERATION.value: 2}, ), @@ -518,7 +518,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df9 = df8.union_all(df6).union_all(df7) assert_df_subtree_query_complexity( df9, - add_node_complexities( + sum_node_complexities( *(get_cumulative_node_complexity(df) for df in [df6, df7, df8]), {PlanNodeCategory.SET_OPERATION.value: 2}, ), @@ -528,7 +528,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df10 = df9.limit(2) assert_df_subtree_query_complexity( df10, - add_node_complexities( + sum_node_complexities( get_cumulative_node_complexity(df9), {PlanNodeCategory.LOW_IMPACT.value: 1} ), ) @@ -537,7 +537,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl df11 = df9.limit(3, offset=1) assert_df_subtree_query_complexity( df11, - add_node_complexities( + sum_node_complexities( get_cumulative_node_complexity(df9), {PlanNodeCategory.LOW_IMPACT.value: 2} ), ) From 4bd8ce671378bbfd46962988652100a7856c449a Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 13 Jun 2024 11:07:29 -0700 Subject: [PATCH 33/37] fix type hints using enums and do not count alias twice --- .../_internal/analyzer/binary_plan_node.py | 26 +-- .../snowpark/_internal/analyzer/expression.py | 52 ++--- .../_internal/analyzer/grouping_set.py | 2 +- .../analyzer/query_plan_analysis_utils.py | 9 +- .../_internal/analyzer/select_statement.py | 56 ++--- .../_internal/analyzer/snowflake_plan.py | 9 +- .../_internal/analyzer/snowflake_plan_node.py | 54 ++--- .../_internal/analyzer/sort_expression.py | 3 +- .../_internal/analyzer/table_function.py | 83 ++++---- .../analyzer/table_merge_expression.py | 52 ++--- .../_internal/analyzer/unary_expression.py | 8 +- 
.../_internal/analyzer/unary_plan_node.py | 81 ++++---- .../_internal/analyzer/window_expression.py | 54 ++--- tests/integ/test_query_plan_analysis.py | 192 +++++++++--------- tests/unit/test_query_plan_analysis.py | 52 +++-- 15 files changed, 375 insertions(+), 358 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index a25a99a7804..f7d37c4ae14 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -200,25 +200,27 @@ def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.JOIN @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # SELECT * FROM (left) AS left_alias join_type_sql JOIN (right) AS right_alias match_cond, using_cond, join_cond - score = {self.plan_node_category.value: 1} + complexity = {self.plan_node_category: 1} if isinstance(self.join_type, UsingJoin) and self.join_type.using_columns: - score = sum_node_complexities( - score, - {PlanNodeCategory.COLUMN.value: len(self.join_type.using_columns)}, + complexity = sum_node_complexities( + complexity, + {PlanNodeCategory.COLUMN: len(self.join_type.using_columns)}, + ) + complexity = ( + sum_node_complexities( + complexity, self.join_condition.cumulative_node_complexity ) - score = ( - sum_node_complexities(score, self.join_condition.cumulative_node_complexity) if self.join_condition - else score + else complexity ) - score = ( + complexity = ( sum_node_complexities( - score, self.match_condition.cumulative_node_complexity + complexity, self.match_condition.cumulative_node_complexity ) if self.match_condition - else score + else complexity ) - return score + return complexity diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 9dc7fd657c4..b0fa97fe8e3 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -62,7 +62,7 @@ def __init__(self, child: Optional["Expression"] = None) -> None: self.nullable = True self.children = [child] if child else None self.datatype: Optional[DataType] = None - self._cumulative_node_complexity: Optional[Dict[str, int]] = None + self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None def dependent_column_names(self) -> Optional[AbstractSet[str]]: # TODO: consider adding it to __init__ or use cached_property. @@ -93,11 +93,11 @@ def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.OTHERS @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: """Returns the individual contribution of the expression node towards the overall compilation complexity of the generated sql. """ - return {self.plan_node_category.value: 1} + return {self.plan_node_category: 1} def calculate_cumulative_node_complexity(self): children = self.children or [] @@ -107,7 +107,7 @@ def calculate_cumulative_node_complexity(self): ) @property - def cumulative_node_complexity(self) -> Dict[str, int]: + def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: """Returns the aggregate sum complexity statistic from the subtree rooted at this expression node. Statistic of current node is included in the final aggregate. 
""" @@ -118,7 +118,7 @@ def cumulative_node_complexity(self) -> Dict[str, int]: return self._cumulative_node_complexity @cumulative_node_complexity.setter - def cumulative_node_complexity(self, value: Dict[str, int]): + def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]): self._cumulative_node_complexity = value @@ -146,7 +146,7 @@ def __init__(self, plan: "SnowflakePlan") -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return self.plan.cumulative_node_complexity @@ -158,7 +158,7 @@ def __init__(self, expressions: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( *(expr.cumulative_node_complexity for expr in self.expressions), ) @@ -177,7 +177,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.IN - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( self.individual_node_complexity, self.columns.cumulative_node_complexity, @@ -229,13 +229,13 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: if self.expressions: return {} # if there are no expressions, we assign column value = 1 to Star - return {PlanNodeCategory.COLUMN.value: 1} + return {PlanNodeCategory.COLUMN: 1} - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( self.individual_node_complexity, *(child.individual_node_complexity for child in self.expressions), @@ -373,7 +373,7 @@ def plan_node_category(self) -> PlanNodeCategory: # expr LIKE pattern return PlanNodeCategory.LOW_IMPACT - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( self.individual_node_complexity, self.expr.cumulative_node_complexity, @@ -395,7 +395,7 @@ def plan_node_category(self) -> PlanNodeCategory: # expr REG_EXP pattern return PlanNodeCategory.LOW_IMPACT - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( self.individual_node_complexity, self.expr.cumulative_node_complexity, @@ -417,7 +417,7 @@ def plan_node_category(self) -> PlanNodeCategory: # expr COLLATE collate_spec return PlanNodeCategory.LOW_IMPACT - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( self.expr.cumulative_node_complexity, self.individual_node_complexity ) @@ -437,7 +437,7 @@ def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field 
return PlanNodeCategory.LITERAL - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: # self.expr ( self.field ) return sum_node_complexities( self.expr.cumulative_node_complexity, self.individual_node_complexity @@ -458,7 +458,7 @@ def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field return PlanNodeCategory.LITERAL - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: # self.expr ( self.field ) return sum_node_complexities( self.expr.cumulative_node_complexity, self.individual_node_complexity @@ -516,7 +516,7 @@ def plan_node_category(self) -> PlanNodeCategory: # expr WITHIN GROUP (ORDER BY cols) return PlanNodeCategory.ORDER_BY - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( self.individual_node_complexity, self.expr.cumulative_node_complexity, @@ -546,8 +546,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.CASE_WHEN - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - score = sum_node_complexities( + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = sum_node_complexities( self.individual_node_complexity, *( sum_node_complexities( @@ -557,12 +557,14 @@ def calculate_cumulative_node_complexity(self) -> Dict[str, int]: for condition, value in self.branches ), ) - score = ( - sum_node_complexities(score, self.else_value.cumulative_node_complexity) + complexity = ( + sum_node_complexities( + complexity, self.else_value.cumulative_node_complexity + ) if self.else_value - else score + else complexity ) - return score + return complexity class SnowflakeUDF(Expression): @@ -588,7 +590,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( self.individual_node_complexity, *(expr.cumulative_node_complexity for expr in self.children), @@ -609,7 +611,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( self.col.cumulative_node_complexity, self.individual_node_complexity ) diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 91b170a5dbd..56428cfaf05 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -45,7 +45,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: flattened_args = [exp for sublist in self.args for exp in sublist] return derive_dependent_columns(*flattened_args) - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return 
sum_node_complexities( *( sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py index 69a39586c2b..f3b4980b698 100644 --- a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py @@ -36,8 +36,15 @@ class PlanNodeCategory(Enum): LOW_IMPACT = "low_impact" OTHERS = "others" + def __repr__(self): + return self.value -def sum_node_complexities(*node_complexities: Dict[str, int]) -> Dict[str, int]: + +def sum_node_complexities( + *node_complexities: Dict[PlanNodeCategory, int] +) -> Dict[PlanNodeCategory, int]: + """This is a helper function to sum complexity values from all complexity dictionaries. A node + complexity is a dictionary of node category to node count mapping""" counter_sum = sum( (Counter(complexity) for complexity in node_complexities), Counter() ) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index f4ad87f737c..a64f3d084cb 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -203,7 +203,7 @@ def __init__( str, Dict[str, str] ] = defaultdict(dict) self._api_calls = api_calls.copy() if api_calls is not None else None - self._cumulative_node_complexity: Optional[Dict[str, int]] = None + self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None def __eq__(self, other: "Selectable") -> bool: if self._id is not None and other._id is not None: @@ -295,7 +295,7 @@ def num_duplicate_nodes(self) -> int: return self.snowflake_plan.num_duplicate_nodes @property - def cumulative_node_complexity(self) -> Dict[str, int]: + def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: if self._cumulative_node_complexity is None: self._cumulative_node_complexity = sum_node_complexities( self.individual_node_complexity, @@ -304,7 +304,7 @@ def cumulative_node_complexity(self) -> Dict[str, int]: return self._cumulative_node_complexity @cumulative_node_complexity.setter - def cumulative_node_complexity(self, value: Dict[str, int]): + def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]): self._cumulative_node_complexity = value @property @@ -492,7 +492,7 @@ def query_params(self) -> Optional[Sequence[Any]]: return self._query_params @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return self.snowflake_plan.individual_node_complexity @@ -689,59 +689,59 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return [self.from_] @property - def individual_node_complexity(self) -> Dict[str, int]: - score = {} + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = {} # projection component - score = ( + complexity = ( sum_node_complexities( - score, + complexity, *( getattr( expr, "cumulative_node_complexity", - {PlanNodeCategory.COLUMN.value: 1}, + {PlanNodeCategory.COLUMN: 1}, ) # type: ignore for expr in self.projection ), ) if self.projection - else score + else complexity ) # filter component - add +1 for WHERE clause and sum of expression complexity for where expression - score = ( + complexity = ( sum_node_complexities( - score, - {PlanNodeCategory.FILTER.value: 1}, + complexity, + {PlanNodeCategory.FILTER: 1}, 
self.where.cumulative_node_complexity, ) if self.where - else score + else complexity ) # order by component - add complexity for each sort expression - score = ( + complexity = ( sum_node_complexities( - score, + complexity, *(expr.cumulative_node_complexity for expr in self.order_by), - {PlanNodeCategory.ORDER_BY.value: 1}, + {PlanNodeCategory.ORDER_BY: 1}, ) if self.order_by - else score + else complexity ) # limit/offset component - score = ( - sum_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) + complexity = ( + sum_node_complexities(complexity, {PlanNodeCategory.LOW_IMPACT: 1}) if self.limit_ - else score + else complexity ) - score = ( - sum_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) + complexity = ( + sum_node_complexities(complexity, {PlanNodeCategory.LOW_IMPACT: 1}) if self.offset - else score + else complexity ) - return score + return complexity def to_subqueryable(self) -> "Selectable": """When this SelectStatement's subquery is not subqueryable (can't be used in `from` clause of the sql), @@ -1046,7 +1046,7 @@ def query_params(self) -> Optional[Sequence[Any]]: return self.snowflake_plan.queries[-1].params @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return self.snowflake_plan.individual_node_complexity @@ -1129,9 +1129,9 @@ def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return self._nodes @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # we add #set_operands - 1 additional operators in sql query - return {PlanNodeCategory.SET_OPERATION.value: len(self.set_operands) - 1} + return {PlanNodeCategory.SET_OPERATION: len(self.set_operands) - 1} class DeriveColumnDependencyError(Exception): diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 968d64e6208..7d8c5ab7e7f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -22,6 +22,7 @@ ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, sum_node_complexities, ) from snowflake.snowpark._internal.analyzer.table_function import ( @@ -235,7 +236,7 @@ def __init__( self.placeholder_query = placeholder_query # encode an id for CTE optimization self._id = encode_id(queries[-1].sql, queries[-1].params) - self._cumulative_node_complexity: Optional[Dict[str, int]] = None + self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None def __eq__(self, other: "SnowflakePlan") -> bool: if self._id is not None and other._id is not None: @@ -355,13 +356,13 @@ def num_duplicate_nodes(self) -> int: return len(find_duplicate_subtrees(self)) @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: if self.source_plan: return self.source_plan.individual_node_complexity return {} @property - def cumulative_node_complexity(self) -> Dict[str, int]: + def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: if self._cumulative_node_complexity is None: self._cumulative_node_complexity = sum_node_complexities( self.individual_node_complexity, @@ -370,7 +371,7 @@ def cumulative_node_complexity(self) -> Dict[str, int]: return self._cumulative_node_complexity 
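     # (illustrative aside, not part of the original change) The property above
     # memoizes the subtree aggregate: the first read folds each child plan's
     # cumulative complexity into this node's individual complexity via
     # sum_node_complexities and caches the result, and the setter below lets
     # plan-rewriting passes overwrite the cached value without recomputing the
     # whole subtree.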
@cumulative_node_complexity.setter - def cumulative_node_complexity(self, value: Dict[str, int]): + def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]): self._cumulative_node_complexity = value def __copy__(self) -> "SnowflakePlan": diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index d52b0c9fe54..5add46ba5eb 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -27,21 +27,21 @@ class LogicalPlan: def __init__(self) -> None: self.children = [] - self._cumulative_node_complexity: Optional[Dict[str, int]] = None + self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.OTHERS @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: """Returns the individual contribution of the logical plan node towards the overall compilation complexity of the generated sql. """ - return {self.plan_node_category.value: 1} + return {self.plan_node_category: 1} @property - def cumulative_node_complexity(self) -> Dict[str, int]: + def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: """Returns the aggregate sum complexity statistic from the subtree rooted at this logical plan node. Statistic of current node is included in the final aggregate. """ @@ -53,7 +53,7 @@ def cumulative_node_complexity(self) -> Dict[str, int]: return self._cumulative_node_complexity @cumulative_node_complexity.setter - def cumulative_node_complexity(self, value: Dict[str, int]): + def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]): self._cumulative_node_complexity = value @@ -72,14 +72,14 @@ def __init__(self, start: int, end: int, step: int, num_slices: int = 1) -> None self.num_slices = num_slices @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # SELECT ( ROW_NUMBER() OVER ( ORDER BY SEQ8() ) - 1 ) * (step) + (start) AS id FROM ( TABLE (GENERATOR(ROWCOUNT => count))) return { - PlanNodeCategory.WINDOW.value: 1, - PlanNodeCategory.ORDER_BY.value: 1, - PlanNodeCategory.LITERAL.value: 3, # step, start, count - PlanNodeCategory.COLUMN.value: 1, # id column - PlanNodeCategory.LOW_IMPACT.value: 2, # ROW_NUMBER, GENERATOR + PlanNodeCategory.WINDOW: 1, + PlanNodeCategory.ORDER_BY: 1, + PlanNodeCategory.LITERAL: 3, # step, start, count + PlanNodeCategory.COLUMN: 1, # id column + PlanNodeCategory.LOW_IMPACT: 2, # ROW_NUMBER, GENERATOR } @@ -89,9 +89,9 @@ def __init__(self, name: str) -> None: self.name = name @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # SELECT * FROM name - return {PlanNodeCategory.COLUMN.value: 1} + return {PlanNodeCategory.COLUMN: 1} class SnowflakeValues(LeafNode): @@ -107,12 +107,12 @@ def __init__( self.schema_query = schema_query @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm) # TODO: use ARRAY_BIND_THRESHOLD return { - PlanNodeCategory.COLUMN.value: len(self.output), - PlanNodeCategory.LITERAL.value: len(self.data) * len(self.output), + 
PlanNodeCategory.COLUMN: len(self.output), + PlanNodeCategory.LITERAL: len(self.data) * len(self.output), } @@ -146,25 +146,25 @@ def __init__( self.comment = comment @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # CREATE OR REPLACE table_type TABLE table_name (col definition) clustering_expr AS SELECT * FROM (query) - score = {PlanNodeCategory.COLUMN.value: 1} - score = ( + complexity = {PlanNodeCategory.COLUMN: 1} + complexity = ( sum_node_complexities( - score, {PlanNodeCategory.COLUMN.value: len(self.column_names)} + complexity, {PlanNodeCategory.COLUMN: len(self.column_names)} ) if self.column_names - else score + else complexity ) - score = ( + complexity = ( sum_node_complexities( - score, + complexity, *(expr.cumulative_node_complexity for expr in self.clustering_exprs), ) if self.clustering_exprs - else score + else complexity ) - return score + return complexity class Limit(LogicalPlan): @@ -178,10 +178,10 @@ def __init__( self.children.append(child) @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # for limit and offset return sum_node_complexities( - {PlanNodeCategory.LOW_IMPACT.value: 2}, + {PlanNodeCategory.LOW_IMPACT: 2}, self.limit_expr.cumulative_node_complexity, self.offset_expr.cumulative_node_complexity, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index 8e8989e5e1d..44f352a6813 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -9,6 +9,7 @@ derive_dependent_columns, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, sum_node_complexities, ) @@ -59,7 +60,7 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( self.child.cumulative_node_complexity, self.individual_node_complexity ) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index 4144eed4989..540aba83b6f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -34,29 +34,29 @@ def __init__( self.partition_spec = partition_spec self.order_spec = order_spec - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: if not self.over: return {} - score = {PlanNodeCategory.WINDOW.value: 1} - score = ( + complexity = {PlanNodeCategory.WINDOW: 1} + complexity = ( sum_node_complexities( - score, + complexity, *(expr.cumulative_node_complexity for expr in self.partition_spec), - {PlanNodeCategory.PARTITION_BY.value: 1}, + {PlanNodeCategory.PARTITION_BY: 1}, ) if self.partition_spec - else score + else complexity ) - score = ( + complexity = ( sum_node_complexities( - score, + complexity, *(expr.cumulative_node_complexity for expr in self.order_spec), - {PlanNodeCategory.ORDER_BY.value: 1}, + {PlanNodeCategory.ORDER_BY: 1}, ) if self.order_spec - else score + else complexity ) - return score + return complexity 
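For reference, the `sum_node_complexities` helper that every `calculate_cumulative_node_complexity` override above leans on is a key-wise sum of category-to-count dictionaries. A minimal runnable sketch of that behavior follows; the enum members are a subset chosen for illustration, and the final `return dict(counter_sum)` is inferred, since the query_plan_analysis_utils.py hunk above truncates before it:

    from collections import Counter
    from enum import Enum
    from typing import Dict


    class PlanNodeCategory(Enum):
        # Illustrative subset of the real enum.
        WINDOW = "window"
        PARTITION_BY = "partition_by"
        ORDER_BY = "order_by"


    def sum_node_complexities(
        *node_complexities: Dict[PlanNodeCategory, int]
    ) -> Dict[PlanNodeCategory, int]:
        # Key-wise addition over any number of complexity dictionaries.
        counter_sum = sum(
            (Counter(complexity) for complexity in node_complexities), Counter()
        )
        return dict(counter_sum)


    # A window spec carrying both PARTITION BY and ORDER BY merges like this:
    merged = sum_node_complexities(
        {PlanNodeCategory.WINDOW: 1},
        {PlanNodeCategory.PARTITION_BY: 1, PlanNodeCategory.ORDER_BY: 1},
    )
    assert merged == {
        PlanNodeCategory.WINDOW: 1,
        PlanNodeCategory.PARTITION_BY: 1,
        PlanNodeCategory.ORDER_BY: 1,
    }

This matches the WINDOW/PARTITION_BY/ORDER_BY skeleton produced by the partition-spec method above when both specs are present; in the real code the partitioning and ordering expressions add their own cumulative counts on top.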
class TableFunctionExpression(Expression): @@ -89,7 +89,7 @@ def __init__( self.recursive = recursive self.mode = mode - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( self.individual_node_complexity, self.input.cumulative_node_complexity ) @@ -105,17 +105,19 @@ def __init__( super().__init__(func_name, partition_spec) self.args = args - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - score = sum_node_complexities( + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = sum_node_complexities( *(arg.cumulative_node_complexity for arg in self.args), self.individual_node_complexity, ) - score = ( - sum_node_complexities(score, self.partition_spec.cumulative_node_complexity) + complexity = ( + sum_node_complexities( + complexity, self.partition_spec.cumulative_node_complexity + ) if self.partition_spec - else score + else complexity ) - return score + return complexity class NamedArgumentsTableFunction(TableFunctionExpression): @@ -128,17 +130,19 @@ def __init__( super().__init__(func_name, partition_spec) self.args = args - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - score = sum_node_complexities( + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = sum_node_complexities( *(arg.cumulative_node_complexity for arg in self.args.values()), self.individual_node_complexity, ) - score = ( - sum_node_complexities(score, self.partition_spec.cumulative_node_complexity) + complexity = ( + sum_node_complexities( + complexity, self.partition_spec.cumulative_node_complexity + ) if self.partition_spec - else score + else complexity ) - return score + return complexity class GeneratorTableFunction(TableFunctionExpression): @@ -147,20 +151,22 @@ def __init__(self, args: Dict[str, Expression], operators: List[str]) -> None: self.args = args self.operators = operators - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: - score = sum_node_complexities( + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = sum_node_complexities( *(arg.cumulative_node_complexity for arg in self.args.values()), self.individual_node_complexity, ) - score = ( - sum_node_complexities(score, self.partition_spec.cumulative_node_complexity) + complexity = ( + sum_node_complexities( + complexity, self.partition_spec.cumulative_node_complexity + ) if self.partition_spec - else score + else complexity ) - score = sum_node_complexities( - score, {PlanNodeCategory.COLUMN.value: len(self.operators)} + complexity = sum_node_complexities( + complexity, {PlanNodeCategory.COLUMN: len(self.operators)} ) - return score + return complexity class TableFunctionRelation(LogicalPlan): @@ -169,7 +175,7 @@ def __init__(self, table_function: TableFunctionExpression) -> None: self.table_function = table_function @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # SELECT * FROM table_function return self.table_function.cumulative_node_complexity @@ -189,13 +195,12 @@ def __init__( self.right_cols = right_cols if right_cols is not None else ["*"] @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # SELECT left_cols, right_cols FROM child as left_alias JOIN table(func(...)) as 
right_alias return sum_node_complexities( { - PlanNodeCategory.COLUMN.value: len(self.left_cols) - + len(self.right_cols), - PlanNodeCategory.JOIN.value: 1, + PlanNodeCategory.COLUMN: len(self.left_cols) + len(self.right_cols), + PlanNodeCategory.JOIN: 1, }, self.table_function.cumulative_node_complexity, ) @@ -210,9 +215,9 @@ def __init__( self.table_function = table_function @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # SELECT * FROM (child), LATERAL table_func_expression return sum_node_complexities( - {PlanNodeCategory.COLUMN.value: 1}, + {PlanNodeCategory.COLUMN: 1}, self.table_function.cumulative_node_complexity, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index 304b10044dc..d30bf8d0ccd 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -24,15 +24,15 @@ def __init__(self, condition: Optional[Expression]) -> None: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: # WHEN MATCHED [AND condition] THEN DEL - score = self.individual_node_complexity - score = ( - sum_node_complexities(score, self.condition.cumulative_node_complexity) + complexity = self.individual_node_complexity + complexity = ( + sum_node_complexities(complexity, self.condition.cumulative_node_complexity) if self.condition - else score + else complexity ) - return score + return complexity class UpdateMergeExpression(MergeExpression): @@ -42,9 +42,9 @@ def __init__( super().__init__(condition) self.assignments = assignments - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: # WHEN MATCHED [AND condition] THEN UPDATE SET COMMA.join(k=v for k,v in assignments) - score = sum_node_complexities( + complexity = sum_node_complexities( self.individual_node_complexity, *( sum_node_complexities( @@ -54,12 +54,12 @@ def calculate_cumulative_node_complexity(self) -> Dict[str, int]: for key_expr, val_expr in self.assignments.items() ), ) - score = ( - sum_node_complexities(score, self.condition.cumulative_node_complexity) + complexity = ( + sum_node_complexities(complexity, self.condition.cumulative_node_complexity) if self.condition - else score + else complexity ) - return score + return complexity class DeleteMergeExpression(MergeExpression): @@ -77,19 +77,19 @@ def __init__( self.keys = keys self.values = values - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: # WHEN NOT MATCHED [AND cond] THEN INSERT [(COMMA.join(key))] VALUES (COMMA.join(values)) - score = sum_node_complexities( + complexity = sum_node_complexities( self.individual_node_complexity, *(key.cumulative_node_complexity for key in self.keys), *(val.cumulative_node_complexity for val in self.values), ) - score = ( - sum_node_complexities(score, self.condition.cumulative_node_complexity) + complexity = ( + sum_node_complexities(complexity, self.condition.cumulative_node_complexity) if self.condition - else score + else complexity ) - return score + return complexity class 
TableUpdate(LogicalPlan): @@ -108,9 +108,9 @@ def __init__( self.children = [source_data] if source_data else [] @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # UPDATE table_name SET COMMA.join(k, v in assignments) [source_data] [WHERE condition] - score = sum_node_complexities( + complexity = sum_node_complexities( *( sum_node_complexities( k.cumulative_node_complexity, v.cumulative_node_complexity @@ -118,12 +118,12 @@ def individual_node_complexity(self) -> Dict[str, int]: for k, v in self.assignments.items() ), ) - score = ( - sum_node_complexities(score, self.condition.cumulative_node_complexity) + complexity = ( + sum_node_complexities(complexity, self.condition.cumulative_node_complexity) if self.condition - else score + else complexity ) - return score + return complexity class TableDelete(LogicalPlan): @@ -140,7 +140,7 @@ def __init__( self.children = [source_data] if source_data else [] @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # DELETE FROM table_name [USING source_data] [WHERE condition] return self.condition.cumulative_node_complexity if self.condition else {} @@ -161,7 +161,7 @@ def __init__( self.children = [source] if source else [] @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # MERGE INTO table_name USING (source) ON join_expr clauses return sum_node_complexities( self.join_expr.cumulative_node_complexity, diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index afc176a0f41..e5886e11069 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -88,9 +88,9 @@ def __str__(self): return f"{self.child} {self.sql_operator} {self.name}" @property - def plan_node_category(self) -> PlanNodeCategory: - # child AS name - return PlanNodeCategory.COLUMN + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # do not add additional complexity for alias + return {} class UnresolvedAlias(UnaryExpression, NamedExpression): @@ -102,6 +102,6 @@ def __init__(self, child: Expression) -> None: self.name = child.sql @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # this is a wrapper around child return {} diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index 7dcecd86475..46ff69498bb 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -38,13 +38,13 @@ def __init__( self.seed = seed @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # SELECT * FROM (child) SAMPLE (probability) -- if probability is provided # SELECT * FROM (child) SAMPLE (row_count ROWS) -- if not probability but row count is provided return { - PlanNodeCategory.SAMPLE.value: 1, - PlanNodeCategory.LITERAL.value: 1, - PlanNodeCategory.COLUMN.value: 1, + PlanNodeCategory.SAMPLE: 1, + PlanNodeCategory.LITERAL: 1, + PlanNodeCategory.COLUMN: 1, } @@ -54,10 +54,10 @@ def __init__(self, order: List[SortOrder], 
child: LogicalPlan) -> None: self.order = order @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # child ORDER BY COMMA.join(order) return sum_node_complexities( - {PlanNodeCategory.ORDER_BY.value: 1}, + {PlanNodeCategory.ORDER_BY: 1}, *(col.cumulative_node_complexity for col in self.order), ) @@ -74,11 +74,11 @@ def __init__( self.aggregate_expressions = aggregate_expressions @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: if self.grouping_expressions: # GROUP BY grouping_exprs - score = sum_node_complexities( - {PlanNodeCategory.GROUP_BY.value: 1}, + complexity = sum_node_complexities( + {PlanNodeCategory.GROUP_BY: 1}, *( expr.cumulative_node_complexity for expr in self.grouping_expressions @@ -86,20 +86,20 @@ def individual_node_complexity(self) -> Dict[str, int]: ) else: # LIMIT 1 - score = {PlanNodeCategory.LOW_IMPACT.value: 1} + complexity = {PlanNodeCategory.LOW_IMPACT: 1} - score = sum_node_complexities( - score, + complexity = sum_node_complexities( + complexity, *( getattr( expr, "cumulative_node_complexity", - {PlanNodeCategory.COLUMN.value: 1}, + {PlanNodeCategory.COLUMN: 1}, ) # type: ignore for expr in self.aggregate_expressions ), ) - return score + return complexity class Pivot(UnaryNode): @@ -120,12 +120,12 @@ def __init__( self.default_on_null = default_on_null @property - def individual_node_complexity(self) -> Dict[str, int]: - score = {} - # child score adjustment if grouping cols + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = {} + # child complexity adjustment if grouping cols if self.grouping_columns and self.aggregates and self.aggregates[0].children: # for additional projecting cols when grouping cols is not empty - score = sum_node_complexities( + complexity = sum_node_complexities( self.pivot_column.cumulative_node_complexity, self.aggregates[0].children[0].cumulative_node_complexity, *(col.cumulative_node_complexity for col in self.grouping_columns), @@ -133,28 +133,31 @@ def individual_node_complexity(self) -> Dict[str, int]: # pivot col if isinstance(self.pivot_values, ScalarSubquery): - score = sum_node_complexities( - score, self.pivot_values.cumulative_node_complexity + complexity = sum_node_complexities( + complexity, self.pivot_values.cumulative_node_complexity ) elif isinstance(self.pivot_values, List): - score = sum_node_complexities( - score, *(val.cumulative_node_complexity for val in self.pivot_values) + complexity = sum_node_complexities( + complexity, + *(val.cumulative_node_complexity for val in self.pivot_values), ) else: # if pivot values is None, then we add OTHERS for ANY - score = sum_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) + complexity = sum_node_complexities( + complexity, {PlanNodeCategory.LOW_IMPACT: 1} + ) - # aggregate score - score = sum_node_complexities( - score, + # aggregate complexity + complexity = sum_node_complexities( + complexity, *(expr.cumulative_node_complexity for expr in self.aggregates), ) # SELECT * FROM (child) PIVOT (aggregate FOR pivot_col in values) - score = sum_node_complexities( - score, {PlanNodeCategory.COLUMN.value: 2, PlanNodeCategory.PIVOT.value: 1} + complexity = sum_node_complexities( + complexity, {PlanNodeCategory.COLUMN: 2, PlanNodeCategory.PIVOT: 1} ) - return score + return complexity class Unpivot(UnaryNode): @@ -171,10 +174,10 @@ def __init__( 
self.column_list = column_list @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # SELECT * FROM (child) UNPIVOT (value_column FOR name_column IN (COMMA.join(column_list))) return sum_node_complexities( - {PlanNodeCategory.UNPIVOT.value: 1, PlanNodeCategory.COLUMN.value: 3}, + {PlanNodeCategory.UNPIVOT: 1, PlanNodeCategory.COLUMN: 3}, *(expr.cumulative_node_complexity for expr in self.column_list), ) @@ -189,11 +192,11 @@ def __init__( self.column_map = column_map @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # SELECT * RENAME (before AS after, ...) FROM child return { - PlanNodeCategory.COLUMN.value: 1 + 2 * len(self.column_map), - PlanNodeCategory.LOW_IMPACT.value: 1 + len(self.column_map), + PlanNodeCategory.COLUMN: 1 + len(self.column_map), + PlanNodeCategory.LOW_IMPACT: 1 + len(self.column_map), } @@ -203,10 +206,10 @@ def __init__(self, condition: Expression, child: LogicalPlan) -> None: self.condition = condition @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # child WHERE condition return sum_node_complexities( - {PlanNodeCategory.FILTER.value: 1}, + {PlanNodeCategory.FILTER: 1}, self.condition.cumulative_node_complexity, ) @@ -217,16 +220,16 @@ def __init__(self, project_list: List[NamedExpression], child: LogicalPlan) -> N self.project_list = project_list @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: if not self.project_list: - return {PlanNodeCategory.COLUMN.value: 1} + return {PlanNodeCategory.COLUMN: 1} return sum_node_complexities( *( getattr( col, "cumulative_node_complexity", - {PlanNodeCategory.COLUMN.value: 1}, + {PlanNodeCategory.COLUMN: 1}, ) # type: ignore for col in self.project_list ), diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 9fdf40ff9ed..474ac110049 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -75,7 +75,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: # frame_type BETWEEN lower AND upper return sum_node_complexities( self.individual_node_complexity, @@ -102,21 +102,21 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: ) @property - def individual_node_complexity(self) -> Dict[str, int]: - score = {} - score = ( - sum_node_complexities(score, {PlanNodeCategory.PARTITION_BY.value: 1}) + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = {} + complexity = ( + sum_node_complexities(complexity, {PlanNodeCategory.PARTITION_BY: 1}) if self.partition_spec - else score + else complexity ) - score = ( - sum_node_complexities(score, {PlanNodeCategory.ORDER_BY.value: 1}) + complexity = ( + sum_node_complexities(complexity, {PlanNodeCategory.ORDER_BY: 1}) if self.order_spec - else score + else complexity ) - return score + return complexity - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def 
calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: # partition_spec order_by_spec frame_spec return sum_node_complexities( self.individual_node_complexity, @@ -141,7 +141,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.WINDOW - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: # window_function OVER ( window_spec ) return sum_node_complexities( self.window_function.cumulative_node_complexity, @@ -170,35 +170,35 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) @property - def individual_node_complexity(self) -> Dict[str, int]: + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # for func_name - score = {PlanNodeCategory.FUNCTION.value: 1} + complexity = {PlanNodeCategory.FUNCTION: 1} # for offset - score = ( - sum_node_complexities(score, {PlanNodeCategory.LITERAL.value: 1}) + complexity = ( + sum_node_complexities(complexity, {PlanNodeCategory.LITERAL: 1}) if self.offset - else score + else complexity ) # for ignore nulls - score = ( - sum_node_complexities(score, {PlanNodeCategory.LOW_IMPACT.value: 1}) + complexity = ( + sum_node_complexities(complexity, {PlanNodeCategory.LOW_IMPACT: 1}) if self.ignore_nulls - else score + else complexity ) - return score + return complexity - def calculate_cumulative_node_complexity(self) -> Dict[str, int]: + def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: # func_name (expr [, offset] [, default]) [IGNORE NULLS] - score = sum_node_complexities( + complexity = sum_node_complexities( self.individual_node_complexity, self.expr.cumulative_node_complexity ) - score = ( - sum_node_complexities(score, self.default.cumulative_node_complexity) + complexity = ( + sum_node_complexities(complexity, self.default.cumulative_node_complexity) if self.default - else score + else complexity ) - return score + return complexity class Lag(RankRelatedFunctionExpression): diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py index 29cf7623049..08374163383 100644 --- a/tests/integ/test_query_plan_analysis.py +++ b/tests/integ/test_query_plan_analysis.py @@ -67,20 +67,20 @@ def test_create_dataframe_from_values(session: Session): df1 = session.create_dataframe([[1], [2], [3]], schema=["a"]) # SELECT "A" FROM ( SELECT $1 AS "A" FROM VALUES (1 :: INT), (2 :: INT), (3 :: INT)) assert_df_subtree_query_complexity( - df1, {PlanNodeCategory.LITERAL.value: 3, PlanNodeCategory.COLUMN.value: 2} + df1, {PlanNodeCategory.LITERAL: 3, PlanNodeCategory.COLUMN: 2} ) df2 = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", "b"]) # SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT), (5 :: INT, 6 :: INT)) assert_df_subtree_query_complexity( - df2, {PlanNodeCategory.LITERAL.value: 6, PlanNodeCategory.COLUMN.value: 4} + df2, {PlanNodeCategory.LITERAL: 6, PlanNodeCategory.COLUMN: 4} ) def test_session_table(session: Session, sample_table: str): # select * from sample_table df = session.table(sample_table) - assert_df_subtree_query_complexity(df, {PlanNodeCategory.COLUMN.value: 1}) + assert_df_subtree_query_complexity(df, {PlanNodeCategory.COLUMN: 1}) def test_range_statement(session: Session): @@ -89,11 +89,11 @@ def test_range_statement(session: Session): 
     assert_df_subtree_query_complexity(
         df,
         {
-            PlanNodeCategory.COLUMN.value: 1,
-            PlanNodeCategory.LITERAL.value: 3,
-            PlanNodeCategory.LOW_IMPACT.value: 2,
-            PlanNodeCategory.ORDER_BY.value: 1,
-            PlanNodeCategory.WINDOW.value: 1,
+            PlanNodeCategory.COLUMN: 1,
+            PlanNodeCategory.LITERAL: 3,
+            PlanNodeCategory.LOW_IMPACT: 2,
+            PlanNodeCategory.ORDER_BY: 1,
+            PlanNodeCategory.WINDOW: 1,
         },
     )
@@ -105,9 +105,9 @@ def test_generator_table_function(session: Session):
     assert_df_subtree_query_complexity(
         df1,
         {
-            PlanNodeCategory.COLUMN.value: 2,
-            PlanNodeCategory.FUNCTION.value: 1,
-            PlanNodeCategory.LITERAL.value: 1,
+            PlanNodeCategory.COLUMN: 2,
+            PlanNodeCategory.FUNCTION: 1,
+            PlanNodeCategory.LITERAL: 1,
         },
     )
@@ -118,9 +118,9 @@ def test_generator_table_function(session: Session):
         sum_node_complexities(
             get_cumulative_node_complexity(df1),
             {
-                PlanNodeCategory.ORDER_BY.value: 1,
-                PlanNodeCategory.COLUMN.value: 1,
-                PlanNodeCategory.OTHERS.value: 1,
+                PlanNodeCategory.ORDER_BY: 1,
+                PlanNodeCategory.COLUMN: 1,
+                PlanNodeCategory.OTHERS: 1,
             },
         ),
     )
@@ -131,7 +131,7 @@ def test_join_table_function(session: Session):
         "select 'James' as name, 'address1 address2 address3' as addresses"
     )
     # SelectSQL chooses num active columns as the best estimate
-    assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN.value: 1})
+    assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN: 1})

     split_to_table = table_function("split_to_table")

@@ -143,10 +143,10 @@
     assert_df_subtree_query_complexity(
         df2,
         {
-            PlanNodeCategory.COLUMN.value: 8,
-            PlanNodeCategory.JOIN.value: 1,
-            PlanNodeCategory.FUNCTION.value: 1,
-            PlanNodeCategory.LITERAL.value: 1,
+            PlanNodeCategory.COLUMN: 8,
+            PlanNodeCategory.JOIN: 1,
+            PlanNodeCategory.FUNCTION: 1,
+            PlanNodeCategory.LITERAL: 1,
         },
     )

@@ -160,14 +160,14 @@
     assert_df_subtree_query_complexity(
         df3,
         {
-            PlanNodeCategory.COLUMN.value: 6,
-            PlanNodeCategory.JOIN.value: 1,
-            PlanNodeCategory.FUNCTION.value: 1,
-            PlanNodeCategory.LITERAL.value: 1,
-            PlanNodeCategory.PARTITION_BY.value: 1,
-            PlanNodeCategory.ORDER_BY.value: 1,
-            PlanNodeCategory.WINDOW.value: 1,
-            PlanNodeCategory.OTHERS.value: 1,
+            PlanNodeCategory.COLUMN: 6,
+            PlanNodeCategory.JOIN: 1,
+            PlanNodeCategory.FUNCTION: 1,
+            PlanNodeCategory.LITERAL: 1,
+            PlanNodeCategory.PARTITION_BY: 1,
+            PlanNodeCategory.ORDER_BY: 1,
+            PlanNodeCategory.WINDOW: 1,
+            PlanNodeCategory.OTHERS: 1,
         },
     )

@@ -189,7 +189,7 @@ def test_set_operators(session: Session, sample_table: str, set_operator: str):

     # ( SELECT * FROM SNOWPARK_TEMP_TABLE_9DJO2Y35IT) set_operator ( SELECT * FROM SNOWPARK_TEMP_TABLE_9DJO2Y35IT)
     assert_df_subtree_query_complexity(
-        df, {PlanNodeCategory.COLUMN.value: 2, PlanNodeCategory.SET_OPERATION.value: 1}
+        df, {PlanNodeCategory.COLUMN: 2, PlanNodeCategory.SET_OPERATION: 1}
     )


@@ -204,38 +204,38 @@ def test_agg(session: Session, sample_table: str):
     assert_df_subtree_query_complexity(
         df1,
         {
-            PlanNodeCategory.COLUMN.value: 3,
-            PlanNodeCategory.LOW_IMPACT.value: 1,
-            PlanNodeCategory.FUNCTION.value: 1,
+            PlanNodeCategory.COLUMN: 2,
+            PlanNodeCategory.LOW_IMPACT: 1,
+            PlanNodeCategory.FUNCTION: 1,
         },
     )
     # SELECT (avg("A") + 1 :: INT) AS "ADD(AVG(A), LITERAL())" FROM ( SELECT * FROM sample_table) LIMIT 1
     assert_df_subtree_query_complexity(
         df2,
         {
-            PlanNodeCategory.COLUMN.value: 3,
-            PlanNodeCategory.LOW_IMPACT.value: 2,
-            PlanNodeCategory.FUNCTION.value: 1,
-            PlanNodeCategory.LITERAL.value: 1,
+            PlanNodeCategory.COLUMN: 2,
+            PlanNodeCategory.LOW_IMPACT: 2,
+            PlanNodeCategory.FUNCTION: 1,
+            PlanNodeCategory.LITERAL: 1,
         },
     )
     # SELECT avg("A") AS "AVG(A)", avg(("B" + 1 :: INT)) AS "AVG_B" FROM ( SELECT * FROM sample_table) LIMIT 1
     assert_df_subtree_query_complexity(
         df3,
         {
-            PlanNodeCategory.COLUMN.value: 5,
-            PlanNodeCategory.LOW_IMPACT.value: 2,
-            PlanNodeCategory.FUNCTION.value: 2,
-            PlanNodeCategory.LITERAL.value: 1,
+            PlanNodeCategory.COLUMN: 3,
+            PlanNodeCategory.LOW_IMPACT: 2,
+            PlanNodeCategory.FUNCTION: 2,
+            PlanNodeCategory.LITERAL: 1,
         },
     )
     # SELECT "A", "B", avg("C") AS "AVG(C)" FROM ( SELECT * FROM SNOWPARK_TEMP_TABLE_EV1NO4AID6) GROUP BY "A", "B"
     assert_df_subtree_query_complexity(
         df4,
         {
-            PlanNodeCategory.COLUMN.value: 7,
-            PlanNodeCategory.GROUP_BY.value: 1,
-            PlanNodeCategory.FUNCTION.value: 1,
+            PlanNodeCategory.COLUMN: 6,
+            PlanNodeCategory.GROUP_BY: 1,
+            PlanNodeCategory.FUNCTION: 1,
         },
     )

@@ -261,14 +261,14 @@ def test_window_function(session: Session):
     assert_df_subtree_query_complexity(
         df1,
         {
-            PlanNodeCategory.PARTITION_BY.value: 1,
-            PlanNodeCategory.ORDER_BY.value: 1,
-            PlanNodeCategory.WINDOW.value: 1,
-            PlanNodeCategory.FUNCTION.value: 1,
-            PlanNodeCategory.COLUMN.value: 5,
-            PlanNodeCategory.LITERAL.value: 1,
-            PlanNodeCategory.LOW_IMPACT.value: 2,
-            PlanNodeCategory.OTHERS.value: 1,
+            PlanNodeCategory.PARTITION_BY: 1,
+            PlanNodeCategory.ORDER_BY: 1,
+            PlanNodeCategory.WINDOW: 1,
+            PlanNodeCategory.FUNCTION: 1,
+            PlanNodeCategory.COLUMN: 4,
+            PlanNodeCategory.LITERAL: 1,
+            PlanNodeCategory.LOW_IMPACT: 2,
+            PlanNodeCategory.OTHERS: 1,
         },
     )

@@ -280,12 +280,12 @@ def test_window_function(session: Session):
         sum_node_complexities(
             get_cumulative_node_complexity(df1),
             {
-                PlanNodeCategory.ORDER_BY.value: 1,
-                PlanNodeCategory.WINDOW.value: 1,
-                PlanNodeCategory.FUNCTION.value: 1,
-                PlanNodeCategory.COLUMN.value: 3,
-                PlanNodeCategory.LOW_IMPACT.value: 3,
-                PlanNodeCategory.OTHERS.value: 1,
+                PlanNodeCategory.ORDER_BY: 1,
+                PlanNodeCategory.WINDOW: 1,
+                PlanNodeCategory.FUNCTION: 1,
+                PlanNodeCategory.COLUMN: 2,
+                PlanNodeCategory.LOW_IMPACT: 3,
+                PlanNodeCategory.OTHERS: 1,
             },
         ),
     )
@@ -296,11 +296,11 @@ def test_join_statement(session: Session, sample_table: str):
     # SELECT * FROM table
     df1 = session.table(sample_table)
-    assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN.value: 1})
+    assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN: 1})
     # SELECT A, B, E FROM (SELECT $1 AS "A", $2 AS "B", $3 AS "E" FROM VALUES (1 :: INT, 2 :: INT, 5 :: INT), (3 :: INT, 4 :: INT, 9 :: INT))
     df2 = session.create_dataframe([[1, 2, 5], [3, 4, 9]], schema=["a", "b", "e"])
     assert_df_subtree_query_complexity(
-        df2, {PlanNodeCategory.COLUMN.value: 6, PlanNodeCategory.LITERAL.value: 6}
+        df2, {PlanNodeCategory.COLUMN: 6, PlanNodeCategory.LITERAL: 6}
     )

     df3 = df1.join(df2)
@@ -311,9 +311,9 @@ def test_join_statement(session: Session, sample_table: str):
     assert_df_subtree_query_complexity(
         df3,
         {
-            PlanNodeCategory.COLUMN.value: 18,
-            PlanNodeCategory.LITERAL.value: 6,
-            PlanNodeCategory.JOIN.value: 1,
+            PlanNodeCategory.COLUMN: 11,
+            PlanNodeCategory.LITERAL: 6,
+            PlanNodeCategory.JOIN: 1,
         },
     )

@@ -323,7 +323,7 @@ def test_join_statement(session: Session, sample_table: str):
         df4,
         sum_node_complexities(
             get_cumulative_node_complexity(df3),
-            {PlanNodeCategory.COLUMN.value: 4, PlanNodeCategory.LOW_IMPACT.value: 3},
+            {PlanNodeCategory.COLUMN: 4, PlanNodeCategory.LOW_IMPACT: 3},
         ),
     )

@@ -332,7 +332,7 @@ def test_join_statement(session: Session, sample_table: str):
     assert_df_subtree_query_complexity(
         df5,
         sum_node_complexities(
-            get_cumulative_node_complexity(df3), {PlanNodeCategory.COLUMN.value: 2}
+            get_cumulative_node_complexity(df3), {PlanNodeCategory.COLUMN: 2}
         ),
     )

@@ -358,10 +358,10 @@ def test_pivot(session: Session):
     assert_df_subtree_query_complexity(
         df_pivot1,
         {
-            PlanNodeCategory.PIVOT.value: 1,
-            PlanNodeCategory.COLUMN.value: 4,
-            PlanNodeCategory.LITERAL.value: 2,
-            PlanNodeCategory.FUNCTION.value: 1,
+            PlanNodeCategory.PIVOT: 1,
+            PlanNodeCategory.COLUMN: 4,
+            PlanNodeCategory.LITERAL: 2,
+            PlanNodeCategory.FUNCTION: 1,
         },
     )

@@ -374,10 +374,10 @@ def test_pivot(session: Session):
     assert_df_subtree_query_complexity(
         df_pivot2,
         {
-            PlanNodeCategory.PIVOT.value: 1,
-            PlanNodeCategory.COLUMN.value: 4,
-            PlanNodeCategory.LITERAL.value: 3,
-            PlanNodeCategory.FUNCTION.value: 1,
+            PlanNodeCategory.PIVOT: 1,
+            PlanNodeCategory.COLUMN: 4,
+            PlanNodeCategory.LITERAL: 3,
+            PlanNodeCategory.FUNCTION: 1,
         },
     )
     finally:
@@ -399,7 +399,7 @@ def test_unpivot(session: Session):
     # SELECT * FROM ( SELECT * FROM (sales_for_month)) UNPIVOT (sales FOR month IN ("JAN", "FEB"))
     assert_df_subtree_query_complexity(
         df_unpivot1,
-        {PlanNodeCategory.UNPIVOT.value: 1, PlanNodeCategory.COLUMN.value: 6},
+        {PlanNodeCategory.UNPIVOT: 1, PlanNodeCategory.COLUMN: 6},
     )
     finally:
         Utils.drop_table(session, "sales_for_month")
@@ -412,9 +412,9 @@ def test_sample(session: Session, sample_table):
     assert_df_subtree_query_complexity(
         df_sample_frac,
         {
-            PlanNodeCategory.SAMPLE.value: 1,
-            PlanNodeCategory.LITERAL.value: 1,
-            PlanNodeCategory.COLUMN.value: 2,
+            PlanNodeCategory.SAMPLE: 1,
+            PlanNodeCategory.LITERAL: 1,
+            PlanNodeCategory.COLUMN: 2,
         },
     )

@@ -423,9 +423,9 @@ def test_sample(session: Session, sample_table):
     assert_df_subtree_query_complexity(
         df_sample_rows,
         {
-            PlanNodeCategory.SAMPLE.value: 1,
-            PlanNodeCategory.LITERAL.value: 1,
-            PlanNodeCategory.COLUMN.value: 2,
+            PlanNodeCategory.SAMPLE: 1,
+            PlanNodeCategory.LITERAL: 1,
+            PlanNodeCategory.COLUMN: 2,
         },
     )

@@ -439,15 +439,15 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl
     # from select * from sample_table which is flattened out. This is a known limitation but is okay
     # since we are not off by much
     df1 = df.select(df["*"])
-    assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN.value: 5})
+    assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN: 5})

     # SELECT "A", "B", "C" FROM sample_table
     df2 = df1.select("a", "b", "c")
-    assert_df_subtree_query_complexity(df2, {PlanNodeCategory.COLUMN.value: 4})
+    assert_df_subtree_query_complexity(df2, {PlanNodeCategory.COLUMN: 4})

     # 1 less active column
     df3 = df2.select("b", "c")
-    assert_df_subtree_query_complexity(df3, {PlanNodeCategory.COLUMN.value: 3})
+    assert_df_subtree_query_complexity(df3, {PlanNodeCategory.COLUMN: 3})

     # add sort
     # for additional ORDER BY "B" ASC NULLS FIRST
@@ -457,9 +457,9 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl
         sum_node_complexities(
             get_cumulative_node_complexity(df3),
             {
-                PlanNodeCategory.COLUMN.value: 1,
-                PlanNodeCategory.ORDER_BY.value: 1,
-                PlanNodeCategory.OTHERS.value: 1,
+                PlanNodeCategory.COLUMN: 1,
+                PlanNodeCategory.ORDER_BY: 1,
+                PlanNodeCategory.OTHERS: 1,
             },
         ),
     )

@@ -470,7 +470,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl
         df5,
         sum_node_complexities(
             get_cumulative_node_complexity(df4),
-            {PlanNodeCategory.COLUMN.value: 1, PlanNodeCategory.OTHERS.value: 1},
+            {PlanNodeCategory.COLUMN: 1, PlanNodeCategory.OTHERS: 1},
         ),
     )

@@ -482,10 +482,10 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl
         sum_node_complexities(
             get_cumulative_node_complexity(df5),
             {
-                PlanNodeCategory.FILTER.value: 1,
-                PlanNodeCategory.COLUMN.value: 1,
-                PlanNodeCategory.LITERAL.value: 1,
-                PlanNodeCategory.LOW_IMPACT.value: 1,
+                PlanNodeCategory.FILTER: 1,
+                PlanNodeCategory.COLUMN: 1,
+                PlanNodeCategory.LITERAL: 1,
+                PlanNodeCategory.LOW_IMPACT: 1,
             },
         ),
     )

@@ -497,9 +497,9 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl
         sum_node_complexities(
             get_cumulative_node_complexity(df6),
             {
-                PlanNodeCategory.COLUMN.value: 1,
-                PlanNodeCategory.LITERAL.value: 1,
-                PlanNodeCategory.LOW_IMPACT.value: 2,
+                PlanNodeCategory.COLUMN: 1,
+                PlanNodeCategory.LITERAL: 1,
+                PlanNodeCategory.LOW_IMPACT: 2,
             },
         ),
     )

@@ -510,7 +510,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl
         df8,
         sum_node_complexities(
             *(get_cumulative_node_complexity(df) for df in [df3, df4, df5]),
-            {PlanNodeCategory.SET_OPERATION.value: 2},
+            {PlanNodeCategory.SET_OPERATION: 2},
         ),
     )

@@ -520,7 +520,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl
         df9,
         sum_node_complexities(
             *(get_cumulative_node_complexity(df) for df in [df6, df7, df8]),
-            {PlanNodeCategory.SET_OPERATION.value: 2},
+            {PlanNodeCategory.SET_OPERATION: 2},
         ),
     )

@@ -529,7 +529,7 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl
     assert_df_subtree_query_complexity(
         df10,
         sum_node_complexities(
-            get_cumulative_node_complexity(df9), {PlanNodeCategory.LOW_IMPACT.value: 1}
+            get_cumulative_node_complexity(df9), {PlanNodeCategory.LOW_IMPACT: 1}
         ),
     )

@@ -538,6 +538,6 @@ def test_select_statement_with_multiple_operations(session: Session, sample_tabl
     assert_df_subtree_query_complexity(
         df11,
         sum_node_complexities(
-            get_cumulative_node_complexity(df9), {PlanNodeCategory.LOW_IMPACT.value: 2}
+            get_cumulative_node_complexity(df9), {PlanNodeCategory.LOW_IMPACT: 2}
         ),
     )
diff --git a/tests/unit/test_query_plan_analysis.py b/tests/unit/test_query_plan_analysis.py
index cfd3e94c3c2..6b396692dff 100644
--- a/tests/unit/test_query_plan_analysis.py
+++ b/tests/unit/test_query_plan_analysis.py
@@ -77,35 +77,35 @@ def set_children(node, node_type, children):
     set_children(nodes[5], node_type, [])
     set_children(nodes[6], node_type, [])

-    assert nodes[0].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 7}
-    assert nodes[1].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 5}
-    assert nodes[2].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 1}
-    assert nodes[3].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 1}
-    assert nodes[4].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 2}
-    assert nodes[5].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 1}
-    assert nodes[6].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 1}
+    assert nodes[0].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 7}
+    assert nodes[1].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 5}
+    assert nodes[2].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 1}
+    assert nodes[3].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 1}
+    assert nodes[4].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 2}
+    assert nodes[5].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 1}
+    assert nodes[6].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 1}

-    nodes[1].cumulative_node_complexity = {PlanNodeCategory.COLUMN.value: 1}
+    nodes[1].cumulative_node_complexity = {PlanNodeCategory.COLUMN: 1}
     # assert that only value that is reset is changed
-    assert nodes[0].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 7}
-    assert nodes[1].cumulative_node_complexity == {PlanNodeCategory.COLUMN.value: 1}
-    assert nodes[2].cumulative_node_complexity == {PlanNodeCategory.OTHERS.value: 1}
+    assert nodes[0].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 7}
+    assert nodes[1].cumulative_node_complexity == {PlanNodeCategory.COLUMN: 1}
+    assert nodes[2].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 1}


 def test_selectable_entity_individual_node_complexity(mock_analyzer):
     plan_node = SelectableEntity(entity_name="dummy entity", analyzer=mock_analyzer)
-    assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN.value: 1}
+    assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN: 1}


-def test_select_sql_individual_node_complexity(mock_session, mock_analyzer):
+def test_select_sql_individual_node_complexity(mock_analyzer):
     plan_node = SelectSQL(
         "non-select statement", convert_to_select=True, analyzer=mock_analyzer
     )
-    assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN.value: 1}
+    assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN: 1}

     plan_node = SelectSQL("select 1 as A, 2 as B", analyzer=mock_analyzer)
-    assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN.value: 1}
+    assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN: 1}


 def test_select_snowflake_plan_individual_node_complexity(
@@ -116,26 +116,26 @@ def test_select_snowflake_plan_individual_node_complexity(
         [mock_query], "", source_plan=source_plan, session=mock_session
     )
     plan_node = SelectSnowflakePlan(snowflake_plan, analyzer=mock_analyzer)
-    assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN.value: 2}
+    assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN: 2}


 @pytest.mark.parametrize(
     "attribute,value,expected_stat",
     [
-        ("projection", [NamedExpression()], {PlanNodeCategory.COLUMN.value: 1}),
-        ("projection", [Expression()], {PlanNodeCategory.OTHERS.value: 1}),
+        ("projection", [NamedExpression()], {PlanNodeCategory.COLUMN: 1}),
+        ("projection", [Expression()], {PlanNodeCategory.OTHERS: 1}),
         (
             "order_by",
             [Expression()],
-            {PlanNodeCategory.OTHERS.value: 1, PlanNodeCategory.ORDER_BY.value: 1},
+            {PlanNodeCategory.OTHERS: 1, PlanNodeCategory.ORDER_BY: 1},
         ),
         (
             "where",
             Expression(),
-            {PlanNodeCategory.OTHERS.value: 1, PlanNodeCategory.FILTER.value: 1},
+            {PlanNodeCategory.OTHERS: 1, PlanNodeCategory.FILTER: 1},
         ),
-        ("limit_", 10, {PlanNodeCategory.LOW_IMPACT.value: 1}),
-        ("offset", 2, {PlanNodeCategory.LOW_IMPACT.value: 1}),
+        ("limit_", 10, {PlanNodeCategory.LOW_IMPACT: 1}),
+        ("offset", 2, {PlanNodeCategory.LOW_IMPACT: 1}),
     ],
 )
 def test_select_statement_individual_node_complexity(
@@ -166,9 +166,7 @@ def mocked_resolve(*args, **kwargs):
     with mock.patch.object(mock_analyzer, "resolve", side_effect=mocked_resolve):
         plan_node = SelectTableFunction(func_expr, analyzer=mock_analyzer)

-        assert plan_node.individual_node_complexity == {
-            PlanNodeCategory.COLUMN.value: 2
-        }
+        assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN: 2}


 @pytest.mark.parametrize("set_operator", [UNION, UNION_ALL, INTERSECT, EXCEPT])
@@ -184,6 +182,4 @@ def test_set_statement_individual_node_complexity(mock_analyzer, set_operator):
     ]

     plan_node = SetStatement(*set_operands, analyzer=mock_analyzer)
-    assert plan_node.individual_node_complexity == {
-        PlanNodeCategory.SET_OPERATION.value: 1
-    }
+    assert plan_node.individual_node_complexity == {PlanNodeCategory.SET_OPERATION: 1}

From 9035a704a79659414a9479620ff3458ef9abfc1b Mon Sep 17 00:00:00 2001
From: Afroz Alam
Date: Thu, 13 Jun 2024 14:26:04 -0700
Subject: [PATCH 34/37] align complexity stat calculation

---
 .../snowpark/_internal/analyzer/expression.py | 80 +++++++++----------
 .../_internal/analyzer/grouping_set.py        |  5 +-
 .../analyzer/query_plan_analysis_utils.py     |  4 +-
 .../_internal/analyzer/snowflake_plan_node.py |  2 +-
 .../_internal/analyzer/sort_expression.py     | 11 +--
 .../_internal/analyzer/table_function.py      | 23 +++---
 .../analyzer/table_merge_expression.py        | 15 ++--
 .../_internal/analyzer/window_expression.py   | 39 +++++----
 src/snowflake/snowpark/_internal/telemetry.py |  7 +-
 tests/integ/test_query_plan_analysis.py       |  2 +-
 10 files changed, 92 insertions(+), 96 deletions(-)

diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py
index b0fa97fe8e3..f412252f1c3 100644
--- a/src/snowflake/snowpark/_internal/analyzer/expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/expression.py
@@ -99,21 +99,16 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         """
         return {self.plan_node_category: 1}

-    def calculate_cumulative_node_complexity(self):
-        children = self.children or []
-        return sum_node_complexities(
-            self.individual_node_complexity,
-            *(child.cumulative_node_complexity for child in children),
-        )
-
     @property
     def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         """Returns the aggregate sum complexity statistic from the subtree rooted at this
         expression node. Statistic of current node is included in the final aggregate.
         """
         if self._cumulative_node_complexity is None:
-            self._cumulative_node_complexity = (
-                self.calculate_cumulative_node_complexity()
+            children = self.children or []
+            self._cumulative_node_complexity = sum_node_complexities(
+                self.individual_node_complexity,
+                *(child.cumulative_node_complexity for child in children),
             )
         return self._cumulative_node_complexity

@@ -146,7 +141,8 @@ def __init__(self, plan: "SnowflakePlan") -> None:
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return COLUMN_DEPENDENCY_DOLLAR

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return self.plan.cumulative_node_complexity

@@ -158,7 +154,8 @@ def __init__(self, expressions: List[Expression]) -> None:
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return derive_dependent_columns(*self.expressions)

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return sum_node_complexities(
             *(expr.cumulative_node_complexity for expr in self.expressions),
         )
@@ -177,9 +174,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
     def plan_node_category(self) -> PlanNodeCategory:
         return PlanNodeCategory.IN

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return sum_node_complexities(
-            self.individual_node_complexity,
+            {self.plan_node_category: 1},
             self.columns.cumulative_node_complexity,
             *(expr.cumulative_node_complexity for expr in self.values),
         )
@@ -230,15 +228,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:

     @property
     def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
-        if self.expressions:
-            return {}
-        # if there are no expressions, we assign column value = 1 to Star
-        return {PlanNodeCategory.COLUMN: 1}
+        complexity = {} if self.expressions else {PlanNodeCategory.COLUMN: 1}

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return sum_node_complexities(
-            self.individual_node_complexity,
-            *(child.individual_node_complexity for child in self.expressions),
+            complexity,
+            *(expr.cumulative_node_complexity for expr in self.expressions),
         )

@@ -373,9 +367,10 @@ def plan_node_category(self) -> PlanNodeCategory:
         # expr LIKE pattern
         return PlanNodeCategory.LOW_IMPACT

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return sum_node_complexities(
-            self.individual_node_complexity,
+            {self.plan_node_category: 1},
             self.expr.cumulative_node_complexity,
             self.pattern.cumulative_node_complexity,
         )
@@ -395,9 +390,10 @@ def plan_node_category(self) -> PlanNodeCategory:
         # expr REG_EXP pattern
         return PlanNodeCategory.LOW_IMPACT

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return sum_node_complexities(
-            self.individual_node_complexity,
+            {self.plan_node_category: 1},
             self.expr.cumulative_node_complexity,
             self.pattern.cumulative_node_complexity,
         )
@@ -417,9 +413,10 @@ def plan_node_category(self) -> PlanNodeCategory:
         # expr COLLATE collate_spec
         return PlanNodeCategory.LOW_IMPACT

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return sum_node_complexities(
-            self.expr.cumulative_node_complexity, self.individual_node_complexity
+            {self.plan_node_category: 1}, self.expr.cumulative_node_complexity
         )

@@ -437,10 +434,11 @@ def plan_node_category(self) -> PlanNodeCategory:
         # the literal corresponds to the contribution from self.field
         return PlanNodeCategory.LITERAL

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         # self.expr ( self.field )
         return sum_node_complexities(
-            self.expr.cumulative_node_complexity, self.individual_node_complexity
+            {self.plan_node_category: 1}, self.expr.cumulative_node_complexity
         )

@@ -458,10 +456,11 @@ def plan_node_category(self) -> PlanNodeCategory:
         # the literal corresponds to the contribution from self.field
         return PlanNodeCategory.LITERAL

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         # self.expr ( self.field )
         return sum_node_complexities(
-            self.expr.cumulative_node_complexity, self.individual_node_complexity
+            {self.plan_node_category: 1}, self.expr.cumulative_node_complexity
         )

@@ -516,9 +515,10 @@ def plan_node_category(self) -> PlanNodeCategory:
         # expr WITHIN GROUP (ORDER BY cols)
         return PlanNodeCategory.ORDER_BY

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return sum_node_complexities(
-            self.individual_node_complexity,
+            {self.plan_node_category: 1},
             self.expr.cumulative_node_complexity,
             *(col.cumulative_node_complexity for col in self.order_by_cols),
         )
@@ -546,9 +546,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
     def plan_node_category(self) -> PlanNodeCategory:
         return PlanNodeCategory.CASE_WHEN

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         complexity = sum_node_complexities(
-            self.individual_node_complexity,
+            {self.plan_node_category: 1},
             *(
                 sum_node_complexities(
                     condition.cumulative_node_complexity,
@@ -590,12 +591,6 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
     def plan_node_category(self) -> PlanNodeCategory:
         return PlanNodeCategory.FUNCTION

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
-        return sum_node_complexities(
-            self.individual_node_complexity,
-            *(expr.cumulative_node_complexity for expr in self.children),
-        )
-

 class ListAgg(Expression):
     def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None:
@@ -611,7 +606,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
     def plan_node_category(self) -> PlanNodeCategory:
         return PlanNodeCategory.FUNCTION

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return sum_node_complexities(
-            self.col.cumulative_node_complexity, self.individual_node_complexity
+            {self.plan_node_category: 1}, self.col.cumulative_node_complexity
         )
diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py
index 56428cfaf05..84cd63fd87d 100644
--- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py
+++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py
@@ -45,15 +45,16 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         flattened_args = [exp for sublist in self.args for exp in sublist]
         return derive_dependent_columns(*flattened_args)

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return sum_node_complexities(
+            {self.plan_node_category: 1},
             *(
                 sum_node_complexities(
                     *(expr.cumulative_node_complexity for expr in arg)
                 )
                 for arg in self.args
             ),
-            self.individual_node_complexity,
         )

     @property
diff --git a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py
index f3b4980b698..057a4c37bc7 100644
--- a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py
+++ b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py
@@ -36,8 +36,8 @@ class PlanNodeCategory(Enum):
     LOW_IMPACT = "low_impact"
     OTHERS = "others"

-    def __repr__(self):
-        return self.value
+    def __repr__(self) -> str:
+        return self.name


 def sum_node_complexities(
diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
index 5add46ba5eb..8dba1cf2b0b 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
@@ -79,7 +79,7 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
             PlanNodeCategory.ORDER_BY: 1,
             PlanNodeCategory.LITERAL: 3,  # step, start, count
             PlanNodeCategory.COLUMN: 1,  # id column
-            PlanNodeCategory.LOW_IMPACT: 2,  # ROW_NUMBER, GENERATOR
+            PlanNodeCategory.FUNCTION: 3,  # ROW_NUMBER, SEQ, GENERATOR
         }

diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py
index 44f352a6813..1d06f7290a0 100644
--- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py
@@ -2,16 +2,12 @@
 # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
 #

-from typing import AbstractSet, Dict, Optional, Type
+from typing import AbstractSet, Optional, Type

 from snowflake.snowpark._internal.analyzer.expression import (
     Expression,
     derive_dependent_columns,
 )
-from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
-    PlanNodeCategory,
-    sum_node_complexities,
-)


 class NullOrdering:
@@ -59,8 +55,3 @@ def __init__(

     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return derive_dependent_columns(self.child)
-
-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
-        return sum_node_complexities(
-            self.child.cumulative_node_complexity, self.individual_node_complexity
-        )
diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py
index 540aba83b6f..1e7d342ce8b 100644
--- a/src/snowflake/snowpark/_internal/analyzer/table_function.py
+++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py
@@ -34,7 +34,8 @@ def __init__(
         self.partition_spec = partition_spec
         self.order_spec = order_spec

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         if not self.over:
             return {}
         complexity = {PlanNodeCategory.WINDOW: 1}
@@ -89,9 +90,10 @@ def __init__(
         self.recursive = recursive
         self.mode = mode

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return sum_node_complexities(
-            self.individual_node_complexity, self.input.cumulative_node_complexity
+            {self.plan_node_category: 1}, self.input.cumulative_node_complexity
         )


@@ -105,10 +107,11 @@ def __init__(
         super().__init__(func_name, partition_spec)
         self.args = args

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         complexity = sum_node_complexities(
+            {self.plan_node_category: 1},
             *(arg.cumulative_node_complexity for arg in self.args),
-            self.individual_node_complexity,
         )
         complexity = (
             sum_node_complexities(
@@ -130,10 +133,11 @@ def __init__(
         super().__init__(func_name, partition_spec)
         self.args = args

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         complexity = sum_node_complexities(
+            {self.plan_node_category: 1},
             *(arg.cumulative_node_complexity for arg in self.args.values()),
-            self.individual_node_complexity,
         )
         complexity = (
             sum_node_complexities(
@@ -151,10 +155,11 @@ def __init__(self, args: Dict[str, Expression], operators: List[str]) -> None:
         self.args = args
         self.operators = operators

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         complexity = sum_node_complexities(
+            {self.plan_node_category: 1},
             *(arg.cumulative_node_complexity for arg in self.args.values()),
-            self.individual_node_complexity,
         )
         complexity = (
             sum_node_complexities(
diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py
index d30bf8d0ccd..9b49ecb7898 100644
--- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py
@@ -24,9 +24,10 @@ def __init__(self, condition: Optional[Expression]) -> None:
     def plan_node_category(self) -> PlanNodeCategory:
         return PlanNodeCategory.LOW_IMPACT

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         # WHEN MATCHED [AND condition] THEN DEL
-        complexity = self.individual_node_complexity
+        complexity = {self.plan_node_category: 1}
         complexity = (
             sum_node_complexities(complexity, self.condition.cumulative_node_complexity)
             if self.condition
@@ -42,10 +43,11 @@ def __init__(
         super().__init__(condition)
         self.assignments = assignments

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         # WHEN MATCHED [AND condition] THEN UPDATE SET COMMA.join(k=v for k,v in assignments)
         complexity = sum_node_complexities(
-            self.individual_node_complexity,
+            {self.plan_node_category: 1},
             *(
                 sum_node_complexities(
                     key_expr.cumulative_node_complexity,
@@ -77,10 +79,11 @@ def __init__(
         self.keys = keys
         self.values = values

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         # WHEN NOT MATCHED [AND cond] THEN INSERT [(COMMA.join(key))] VALUES (COMMA.join(values))
         complexity = sum_node_complexities(
-            self.individual_node_complexity,
+            {self.plan_node_category: 1},
             *(key.cumulative_node_complexity for key in self.keys),
             *(val.cumulative_node_complexity for val in self.values),
         )
diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py
index 474ac110049..69db3f265ce 100644
--- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py
@@ -75,10 +75,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
     def plan_node_category(self) -> PlanNodeCategory:
         return PlanNodeCategory.LOW_IMPACT

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         # frame_type BETWEEN lower AND upper
         return sum_node_complexities(
-            self.individual_node_complexity,
+            {self.plan_node_category: 1},
             self.lower.cumulative_node_complexity,
             self.upper.cumulative_node_complexity,
         )
@@ -103,28 +104,28 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:

     @property
     def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
-        complexity = {}
+        # partition_spec order_by_spec frame_spec
+        complexity = self.frame_spec.cumulative_node_complexity
         complexity = (
-            sum_node_complexities(complexity, {PlanNodeCategory.PARTITION_BY: 1})
+            sum_node_complexities(
+                complexity,
+                {PlanNodeCategory.PARTITION_BY: 1},
+                *(expr.cumulative_node_complexity for expr in self.partition_spec),
+            )
             if self.partition_spec
             else complexity
         )
         complexity = (
-            sum_node_complexities(complexity, {PlanNodeCategory.ORDER_BY: 1})
+            sum_node_complexities(
+                complexity,
+                {PlanNodeCategory.ORDER_BY: 1},
+                *(expr.cumulative_node_complexity for expr in self.order_spec),
+            )
             if self.order_spec
             else complexity
         )
         return complexity

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
-        # partition_spec order_by_spec frame_spec
-        return sum_node_complexities(
-            self.individual_node_complexity,
-            self.frame_spec.cumulative_node_complexity,
-            *(expr.cumulative_node_complexity for expr in self.partition_spec),
-            *(expr.cumulative_node_complexity for expr in self.order_spec),
-        )
-

 class WindowExpression(Expression):
     def __init__(
@@ -141,12 +142,13 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
     def plan_node_category(self) -> PlanNodeCategory:
         return PlanNodeCategory.WINDOW

-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+    @property
+    def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         # window_function OVER ( window_spec )
         return sum_node_complexities(
+            {self.plan_node_category: 1},
             self.window_function.cumulative_node_complexity,
             self.window_spec.cumulative_node_complexity,
-            self.individual_node_complexity,
         )


@@ -186,12 +188,9 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
             if self.ignore_nulls
             else complexity
         )
-        return complexity
-
-    def calculate_cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         # func_name (expr [, offset] [, default]) [IGNORE NULLS]
         complexity = sum_node_complexities(
-            self.individual_node_complexity, self.expr.cumulative_node_complexity
+            complexity, self.expr.cumulative_node_complexity
         )
         complexity = (
             sum_node_complexities(complexity, self.default.cumulative_node_complexity)
diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py
index d18ae0e6c55..ff1b5ab47b9 100644
--- a/src/snowflake/snowpark/_internal/telemetry.py
+++ b/src/snowflake/snowpark/_internal/telemetry.py
@@ -161,9 +161,10 @@ def wrap(*args, **kwargs):
                     api_calls[0][
                         TelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value
                     ] = plan.num_duplicate_nodes
-                    api_calls[0][TelemetryField.QUERY_PLAN_COMPLEXITY.value] = dict(
-                        plan.cumulative_node_complexity
-                    )
+                    api_calls[0][TelemetryField.QUERY_PLAN_COMPLEXITY.value] = {
+                        key.value: value
+                        for key, value in plan.cumulative_node_complexity.items()
+                    }
                 except Exception:
                     pass
             args[0]._session._conn._telemetry_client.send_function_usage_telemetry(
diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py
index 08374163383..0e8bb0d902d 100644
--- a/tests/integ/test_query_plan_analysis.py
+++ b/tests/integ/test_query_plan_analysis.py
@@ -91,7 +91,7 @@ def test_range_statement(session: Session):
         {
             PlanNodeCategory.COLUMN: 1,
             PlanNodeCategory.LITERAL: 3,
-            PlanNodeCategory.LOW_IMPACT: 2,
+            PlanNodeCategory.FUNCTION: 3,
             PlanNodeCategory.ORDER_BY: 1,
             PlanNodeCategory.WINDOW: 1,
         },

From b12b7f6083a5fc34a6a970e9fde7fb1c6aa089a6 Mon Sep 17 00:00:00 2001
From: Afroz Alam
Date: Thu, 13 Jun 2024 14:49:39 -0700
Subject: [PATCH 35/37] fix telemetry test

---
 tests/integ/test_telemetry.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py
index 40ec07d9de7..61debb4d131 100644
--- a/tests/integ/test_telemetry.py
+++ b/tests/integ/test_telemetry.py
@@ -595,7 +595,8 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled):
         "query_plan_num_duplicate_nodes": 0,
         "query_plan_complexity": {
             "filter": 1,
-            "low_impact": 5,
+            "low_impact": 3,
+            "function": 3,
             "column": 3,
             "literal": 5,
             "window": 1,
@@ -616,7 +617,8 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled):
         "query_plan_num_duplicate_nodes": 0,
         "query_plan_complexity": {
             "filter": 1,
-            "low_impact": 5,
+            "low_impact": 3,
+            "function": 3,
             "column": 3,
             "literal": 5,
             "window": 1,
@@ -637,7 +639,8 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled):
         "query_plan_num_duplicate_nodes": 0,
         "query_plan_complexity": {
             "filter": 1,
-            "low_impact": 5,
+            "low_impact": 3,
+            "function": 3,
             "column": 3,
             "literal": 5,
             "window": 1,
@@ -658,7 +661,8 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled):
         "query_plan_num_duplicate_nodes": 0,
         "query_plan_complexity": {
             "filter": 1,
-            "low_impact": 5,
+            "low_impact": 3,
+            "function": 3,
             "column": 3,
             "literal": 5,
             "window": 1,
@@ -679,7 +683,8 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled):
         "query_plan_num_duplicate_nodes": 0,
         "query_plan_complexity": {
             "filter": 1,
-            "low_impact": 5,
+            "low_impact": 3,
+            "function": 3,
             "column": 3,
             "literal": 5,
             "window": 1,

From 3f6c9976fe0ad1968dafabe993b0a0690ac1ed42 Mon Sep 17 00:00:00 2001
From: Afroz Alam
Date: Thu, 13 Jun 2024 14:57:59 -0700
Subject: [PATCH 36/37] update comment

---
 src/snowflake/snowpark/_internal/analyzer/expression.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py
index f412252f1c3..841d886c605 100644
--- a/src/snowflake/snowpark/_internal/analyzer/expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/expression.py
@@ -102,7 +102,9 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
     @property
     def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         """Returns the aggregate sum complexity statistic from the subtree rooted at this
-        expression node. Statistic of current node is included in the final aggregate.
+        expression node. It is computed by adding the individual complexity of the current
+        node and the cumulative complexities of all of its children. To maintain this
+        statistic correctly, override the individual_node_complexity property in derived classes.
         """
         if self._cumulative_node_complexity is None:
             children = self.children or []

From 8a03068a96a4dc588435b3006690a37213342460 Mon Sep 17 00:00:00 2001
From: Afroz Alam
Date: Fri, 14 Jun 2024 10:05:37 -0700
Subject: [PATCH 37/37] fix unit test

---
 tests/unit/test_query_plan_analysis.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tests/unit/test_query_plan_analysis.py b/tests/unit/test_query_plan_analysis.py
index 6b396692dff..adac50184e0 100644
--- a/tests/unit/test_query_plan_analysis.py
+++ b/tests/unit/test_query_plan_analysis.py
@@ -62,7 +62,7 @@ def set_children(node, node_type, children):
     """
       o   o
-     / \\ / \
+     / \\ / \\
     o   o   x   o
        /|\
       o o o     ->
@@ -99,9 +99,7 @@ def test_selectable_entity_individual_node_complexity(mock_analyzer):


 def test_select_sql_individual_node_complexity(mock_analyzer):
-    plan_node = SelectSQL(
-        "non-select statement", convert_to_select=True, analyzer=mock_analyzer
-    )
+    plan_node = SelectSQL("non-select statement", analyzer=mock_analyzer)
     assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN: 1}

     plan_node = SelectSQL("select 1 as A, 2 as B", analyzer=mock_analyzer)
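
Taken together, these patches settle on one convention: every plan or expression node reports an individual_node_complexity dictionary keyed by PlanNodeCategory, and cumulative_node_complexity folds the children's dictionaries into it via sum_node_complexities. The standalone sketch below illustrates those merge semantics only; it is not the library source, and it assumes sum_node_complexities simply adds per-category counts (PlanNodeCategory here lists just a subset of the real members).

from collections import Counter
from enum import Enum
from typing import Dict


class PlanNodeCategory(Enum):
    # Subset of the categories exercised by the tests above.
    COLUMN = "column"
    LITERAL = "literal"
    FUNCTION = "function"
    LOW_IMPACT = "low_impact"


def sum_node_complexities(
    *complexities: Dict[PlanNodeCategory, int]
) -> Dict[PlanNodeCategory, int]:
    # Add per-category counts; a category absent from one dict contributes 0.
    total: Counter = Counter()
    for complexity in complexities:
        total.update(complexity)
    return dict(total)


# Merging a projection estimate with a filter estimate:
merged = sum_node_complexities(
    {PlanNodeCategory.COLUMN: 3, PlanNodeCategory.FUNCTION: 1},
    {PlanNodeCategory.COLUMN: 1, PlanNodeCategory.LITERAL: 2},
)
assert merged == {
    PlanNodeCategory.COLUMN: 4,
    PlanNodeCategory.FUNCTION: 1,
    PlanNodeCategory.LITERAL: 2,
}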