diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index a828d706686..6882be1ffea 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -939,7 +939,7 @@ def do_resolve_with_resolved_children( ) if isinstance(logical_plan, UnresolvedRelation): - return self.plan_builder.table(logical_plan.name) + return self.plan_builder.table(logical_plan.name, logical_plan) if isinstance(logical_plan, SnowflakeCreateTable): return self.plan_builder.save_as_table( diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 1e197400e83..3ed969caada 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -8,6 +8,9 @@ Expression, derive_dependent_columns, ) +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, +) class BinaryExpression(Expression): @@ -26,6 +29,10 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.left, self.right) + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT + class BinaryArithmeticExpression(BinaryExpression): pass diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py index 46212fc6113..f7d37c4ae14 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_plan_node.py @@ -2,9 +2,13 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# -from typing import List, Optional +from typing import Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import Expression +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, + sum_node_complexities, +) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages @@ -69,7 +73,10 @@ def __init__(self, left: LogicalPlan, right: LogicalPlan) -> None: class SetOperation(BinaryNode): - pass + @property + def plan_node_category(self) -> PlanNodeCategory: + # (left) operator (right) + return PlanNodeCategory.SET_OPERATION class Except(SetOperation): @@ -187,3 +194,33 @@ def __init__( @property def sql(self) -> str: return self.join_type.sql + + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.JOIN + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # SELECT * FROM (left) AS left_alias join_type_sql JOIN (right) AS right_alias match_cond, using_cond, join_cond + complexity = {self.plan_node_category: 1} + if isinstance(self.join_type, UsingJoin) and self.join_type.using_columns: + complexity = sum_node_complexities( + complexity, + {PlanNodeCategory.COLUMN: len(self.join_type.using_columns)}, + ) + complexity = ( + sum_node_complexities( + complexity, self.join_condition.cumulative_node_complexity + ) + if self.join_condition + else complexity + ) + + complexity = ( + sum_node_complexities( + complexity, self.match_condition.cumulative_node_complexity + ) + if self.match_condition + else complexity + ) + return complexity diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 294df47c2a0..03a2fbd68bc 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -4,9 +4,13 @@ import copy import uuid -from typing import TYPE_CHECKING, AbstractSet, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, AbstractSet, Any, Dict, List, Optional, Tuple import snowflake.snowpark._internal.utils +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, + sum_node_complexities, +) if TYPE_CHECKING: from snowflake.snowpark._internal.analyzer.snowflake_plan import ( @@ -58,6 +62,7 @@ def __init__(self, child: Optional["Expression"] = None) -> None: self.nullable = True self.children = [child] if child else None self.datatype: Optional[DataType] = None + self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None def dependent_column_names(self) -> Optional[AbstractSet[str]]: # TODO: consider adding it to __init__ or use cached_property. @@ -83,6 +88,36 @@ def sql(self) -> str: def __str__(self) -> str: return self.pretty_name + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.OTHERS + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + """Returns the individual contribution of the expression node towards the overall + compilation complexity of the generated sql. + """ + return {self.plan_node_category: 1} + + @property + def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: + """Returns the aggregate sum complexity statistic from the subtree rooted at this + expression node. 
It is computed by adding the individual complexity of the current node + and the cumulative complexities of all child nodes. To correctly maintain this statistic, + override the individual_node_complexity property in the derived Expression class. + """ + if self._cumulative_node_complexity is None: + children = self.children or [] + self._cumulative_node_complexity = sum_node_complexities( + self.individual_node_complexity, + *(child.cumulative_node_complexity for child in children), + ) + return self._cumulative_node_complexity + + @cumulative_node_complexity.setter + def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]): + self._cumulative_node_complexity = value + class NamedExpression: name: str @@ -108,6 +143,10 @@ def __init__(self, plan: "SnowflakePlan") -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return self.plan.cumulative_node_complexity + class MultipleExpression(Expression): def __init__(self, expressions: List[Expression]) -> None: @@ -117,6 +156,12 @@ def __init__(self, expressions: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return sum_node_complexities( + *(expr.cumulative_node_complexity for expr in self.expressions), + ) + class InExpression(Expression): def __init__(self, columns: Expression, values: List[Expression]) -> None: @@ -127,6 +172,18 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.columns, *self.values) + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.IN + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return sum_node_complexities( + {self.plan_node_category: 1}, + self.columns.cumulative_node_complexity, + *(expr.cumulative_node_complexity for expr in self.values), + ) + class Attribute(Expression, NamedExpression): def __init__(self, name: str, datatype: DataType, nullable: bool = True) -> None: @@ -155,6 +212,10 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return {self.name} + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.COLUMN + class Star(Expression): def __init__( @@ -167,6 +228,15 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = {} if self.expressions else {PlanNodeCategory.COLUMN: 1} + + return sum_node_complexities( + complexity, + *(expr.cumulative_node_complexity for expr in self.expressions), + ) + class UnresolvedAttribute(Expression, NamedExpression): def __init__( @@ -201,6 +271,10 @@ def __hash__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return self._dependent_column_names + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.COLUMN + class Literal(Expression): def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None: @@ -224,6 +298,10 @@ def __init__(self, value: Any, datatype: Optional[DataType] = None) -> None: else: self.datatype = infer_type(value) + @property
+ def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LITERAL + class Interval(Expression): def __init__( @@ -272,6 +350,10 @@ def sql(self) -> str: def __str__(self) -> str: return self.sql + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT + class Like(Expression): def __init__(self, expr: Expression, pattern: Expression) -> None: @@ -282,6 +364,19 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + @property + def plan_node_category(self) -> PlanNodeCategory: + # expr LIKE pattern + return PlanNodeCategory.LOW_IMPACT + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return sum_node_complexities( + {self.plan_node_category: 1}, + self.expr.cumulative_node_complexity, + self.pattern.cumulative_node_complexity, + ) + class RegExp(Expression): def __init__(self, expr: Expression, pattern: Expression) -> None: @@ -292,6 +387,19 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + @property + def plan_node_category(self) -> PlanNodeCategory: + # expr REG_EXP pattern + return PlanNodeCategory.LOW_IMPACT + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return sum_node_complexities( + {self.plan_node_category: 1}, + self.expr.cumulative_node_complexity, + self.pattern.cumulative_node_complexity, + ) + class Collate(Expression): def __init__(self, expr: Expression, collation_spec: str) -> None: @@ -302,6 +410,17 @@ def __init__(self, expr: Expression, collation_spec: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + @property + def plan_node_category(self) -> PlanNodeCategory: + # expr COLLATE collate_spec + return PlanNodeCategory.LOW_IMPACT + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return sum_node_complexities( + {self.plan_node_category: 1}, self.expr.cumulative_node_complexity + ) + class SubfieldString(Expression): def __init__(self, expr: Expression, field: str) -> None: @@ -312,6 +431,18 @@ def __init__(self, expr: Expression, field: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + @property + def plan_node_category(self) -> PlanNodeCategory: + # the literal corresponds to the contribution from self.field + return PlanNodeCategory.LITERAL + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # self.expr ( self.field ) + return sum_node_complexities( + {self.plan_node_category: 1}, self.expr.cumulative_node_complexity + ) + class SubfieldInt(Expression): def __init__(self, expr: Expression, field: int) -> None: @@ -322,6 +453,18 @@ def __init__(self, expr: Expression, field: int) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + @property + def plan_node_category(self) -> PlanNodeCategory: + # the literal corresponds to the contribution from self.field + return PlanNodeCategory.LITERAL + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # self.expr ( self.field ) + return sum_node_complexities( + {self.plan_node_category: 1}, 
self.expr.cumulative_node_complexity + ) + class FunctionExpression(Expression): def __init__( @@ -354,6 +497,10 @@ def sql(self) -> str: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.FUNCTION + class WithinGroup(Expression): def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None: @@ -365,6 +512,19 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, *self.order_by_cols) + @property + def plan_node_category(self) -> PlanNodeCategory: + # expr WITHIN GROUP (ORDER BY cols) + return PlanNodeCategory.ORDER_BY + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return sum_node_complexities( + {self.plan_node_category: 1}, + self.expr.cumulative_node_complexity, + *(col.cumulative_node_complexity for col in self.order_by_cols), + ) + class CaseWhen(Expression): def __init__( @@ -384,6 +544,31 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: exps.append(self.else_value) return derive_dependent_columns(*exps) + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.CASE_WHEN + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = sum_node_complexities( + {self.plan_node_category: 1}, + *( + sum_node_complexities( + condition.cumulative_node_complexity, + value.cumulative_node_complexity, + ) + for condition, value in self.branches + ), + ) + complexity = ( + sum_node_complexities( + complexity, self.else_value.cumulative_node_complexity + ) + if self.else_value + else complexity + ) + return complexity + class SnowflakeUDF(Expression): def __init__( @@ -404,6 +589,10 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.FUNCTION + class ListAgg(Expression): def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None: @@ -415,6 +604,16 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.col) + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.FUNCTION + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return sum_node_complexities( + {self.plan_node_category: 1}, self.col.cumulative_node_complexity + ) + class ColumnSum(Expression): def __init__(self, exprs: List[Expression]) -> None: @@ -423,3 +622,9 @@ def __init__(self, exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.exprs) + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return sum_node_complexities( + *(expr.cumulative_node_complexity for expr in self.exprs) + ) diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 653fbde1ca3..84cd63fd87d 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -2,12 +2,16 @@ # Copyright (c) 2012-2024 
Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, List, Optional +from typing import AbstractSet, Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, ) +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, + sum_node_complexities, +) class GroupingSet(Expression): @@ -19,6 +23,10 @@ def __init__(self, group_by_exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.group_by_exprs) + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT + class Cube(GroupingSet): pass @@ -36,3 +44,19 @@ def __init__(self, args: List[List[Expression]]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: flattened_args = [exp for sublist in self.args for exp in sublist] return derive_dependent_columns(*flattened_args) + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return sum_node_complexities( + {self.plan_node_category: 1}, + *( + sum_node_complexities( + *(expr.cumulative_node_complexity for expr in arg) + ) + for arg in self.args + ), + ) + + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT diff --git a/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py new file mode 100644 index 00000000000..057a4c37bc7 --- /dev/null +++ b/src/snowflake/snowpark/_internal/analyzer/query_plan_analysis_utils.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +from collections import Counter +from enum import Enum +from typing import Dict + + +class PlanNodeCategory(Enum): + """This enum class is used to account for the different types of sql + text generated by expressions and logical plan nodes. A bottom-up + aggregation of the number of occurrences of each enum type is + done in the Expression and LogicalPlan classes to calculate a + statistic of the overall query complexity of the generated sql. + """ + + FILTER = "filter" + ORDER_BY = "order_by" + JOIN = "join" + SET_OPERATION = "set_operation" # UNION, EXCEPT, INTERSECT, UNION ALL + SAMPLE = "sample" + PIVOT = "pivot" + UNPIVOT = "unpivot" + WINDOW = "window" + GROUP_BY = "group_by" + PARTITION_BY = "partition_by" + CASE_WHEN = "case_when" + LITERAL = "literal" # covers all literals like numbers, constant strings, etc. + COLUMN = "column" # covers all cases where a table column is referred + FUNCTION = ( + "function" # covers all snowflake built-in functions, table functions and UDXFs + ) + IN = "in" + LOW_IMPACT = "low_impact" + OTHERS = "others" + + def __repr__(self) -> str: + return self.name + + +def sum_node_complexities( + *node_complexities: Dict[PlanNodeCategory, int] +) -> Dict[PlanNodeCategory, int]: + """This is a helper function to sum complexity values from all complexity dictionaries.
A node + complexity is a dictionary mapping node category to node count.""" + counter_sum = sum( + (Counter(complexity) for complexity in node_complexities), Counter() + ) + return dict(counter_sum) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 833009648b6..a64f3d084cb 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict from copy import copy, deepcopy @@ -21,6 +22,10 @@ import snowflake.snowpark._internal.utils from snowflake.snowpark._internal.analyzer.cte_utils import encode_id +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, + sum_node_complexities, +) from snowflake.snowpark._internal.analyzer.table_function import ( TableFunctionExpression, TableFunctionJoin, @@ -35,8 +40,6 @@ Analyzer, ) # pragma: no cover -import sys - from snowflake.snowpark._internal.analyzer import analyzer_utils from snowflake.snowpark._internal.analyzer.analyzer_utils import ( result_scan_statement, @@ -200,6 +203,7 @@ def __init__( str, Dict[str, str] ] = defaultdict(dict) self._api_calls = api_calls.copy() if api_calls is not None else None + self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None def __eq__(self, other: "Selectable") -> bool: if self._id is not None and other._id is not None: @@ -290,6 +294,19 @@ def plan_height(self) -> int: def num_duplicate_nodes(self) -> int: return self.snowflake_plan.num_duplicate_nodes + @property + def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: + if self._cumulative_node_complexity is None: + self._cumulative_node_complexity = sum_node_complexities( + self.individual_node_complexity, + *(node.cumulative_node_complexity for node in self.children_plan_nodes), + ) + return self._cumulative_node_complexity + + @cumulative_node_complexity.setter + def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]): + self._cumulative_node_complexity = value + @property def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: """ @@ -348,6 +365,11 @@ def sql_in_subquery(self) -> str: def schema_query(self) -> str: return self.sql_query + @property + def plan_node_category(self) -> PlanNodeCategory: + # SELECT * FROM entity + return PlanNodeCategory.COLUMN + @property def query_params(self) -> Optional[Sequence[Any]]: return None @@ -403,6 +425,10 @@ def query_params(self) -> Optional[Sequence[Any]]: def schema_query(self) -> str: return self._schema_query + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.COLUMN + def to_subqueryable(self) -> "SelectSQL": """Convert this SelectSQL to a new one that can be used as a subquery. Refer to __init__.""" if self.convert_to_select or is_sql_select_statement(self._sql_query): @@ -465,6 +491,10 @@ def schema_query(self) -> str: def query_params(self) -> Optional[Sequence[Any]]: return self._query_params + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return self.snowflake_plan.individual_node_complexity + class SelectStatement(Selectable): """The main logic plan to be used by a DataFrame.
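As a concrete illustration of the helper introduced above: sum_node_complexities merges any number of per-category count dictionaries with collections.Counter, so counts for the same category add up and absent categories default to zero. A minimal sketch follows; the input dictionaries are hand-written stand-ins for what individual_node_complexity might return for a WHERE clause and its condition, not values read from a real plan:

    from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
        PlanNodeCategory,
        sum_node_complexities,
    )

    # +1 FILTER for the WHERE clause itself, mirroring the SelectStatement
    # filter component in the hunks below
    clause = {PlanNodeCategory.FILTER: 1}
    # a condition like `a > 1`: one column reference, one literal, and one
    # low-impact binary operator
    condition = {
        PlanNodeCategory.COLUMN: 1,
        PlanNodeCategory.LITERAL: 1,
        PlanNodeCategory.LOW_IMPACT: 1,
    }

    # counts of matching categories are added together
    assert sum_node_complexities(
        clause, condition, {PlanNodeCategory.COLUMN: 1}
    ) == {
        PlanNodeCategory.FILTER: 1,
        PlanNodeCategory.COLUMN: 2,
        PlanNodeCategory.LITERAL: 1,
        PlanNodeCategory.LOW_IMPACT: 1,
    }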
@@ -658,6 +688,61 @@ def schema_query(self) -> str: def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return [self.from_] + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = {} + # projection component + complexity = ( + sum_node_complexities( + complexity, + *( + getattr( + expr, + "cumulative_node_complexity", + {PlanNodeCategory.COLUMN: 1}, + ) # type: ignore + for expr in self.projection + ), + ) + if self.projection + else complexity + ) + + # filter component - add +1 for WHERE clause and sum of expression complexity for where expression + complexity = ( + sum_node_complexities( + complexity, + {PlanNodeCategory.FILTER: 1}, + self.where.cumulative_node_complexity, + ) + if self.where + else complexity + ) + + # order by component - add complexity for each sort expression + complexity = ( + sum_node_complexities( + complexity, + *(expr.cumulative_node_complexity for expr in self.order_by), + {PlanNodeCategory.ORDER_BY: 1}, + ) + if self.order_by + else complexity + ) + + # limit/offset component + complexity = ( + sum_node_complexities(complexity, {PlanNodeCategory.LOW_IMPACT: 1}) + if self.limit_ + else complexity + ) + complexity = ( + sum_node_complexities(complexity, {PlanNodeCategory.LOW_IMPACT: 1}) + if self.offset + else complexity + ) + return complexity + def to_subqueryable(self) -> "Selectable": """When this SelectStatement's subquery is not subqueryable (can't be used in `from` clause of the sql), convert it to subqueryable and create a new SelectStatement with from_ being the new subqueryable.""" @@ -960,6 +1045,10 @@ def schema_query(self) -> str: def query_params(self) -> Optional[Sequence[Any]]: return self.snowflake_plan.queries[-1].params + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return self.snowflake_plan.individual_node_complexity + class SetOperand: def __init__(self, selectable: Selectable, operator: Optional[str] = None) -> None: @@ -1039,6 +1128,11 @@ def query_params(self) -> Optional[Sequence[Any]]: def children_plan_nodes(self) -> List[Union["Selectable", SnowflakePlan]]: return self._nodes + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # we add (#set_operands - 1) set operators in the sql query + return {PlanNodeCategory.SET_OPERATION: len(self.set_operands) - 1} + class DeriveColumnDependencyError(Exception): """When deriving column dependencies from the subquery.""" diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 3d6941eb915..7d8c5ab7e7f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -21,6 +21,10 @@ Union, ) +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, + sum_node_complexities, +) from snowflake.snowpark._internal.analyzer.table_function import ( GeneratorTableFunction, TableFunctionRelation, ) @@ -232,6 +236,7 @@ def __init__( self.placeholder_query = placeholder_query # encode an id for CTE optimization self._id = encode_id(queries[-1].sql, queries[-1].params) + self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None def __eq__(self, other: "SnowflakePlan") -> bool: if self._id is not None and other._id is not None: @@ -350,6 +355,25 @@ def plan_height(self) -> int: def num_duplicate_nodes(self) -> int: return len(find_duplicate_subtrees(self)) +
@property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + if self.source_plan: + return self.source_plan.individual_node_complexity + return {} + + @property + def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: + if self._cumulative_node_complexity is None: + self._cumulative_node_complexity = sum_node_complexities( + self.individual_node_complexity, + *(node.cumulative_node_complexity for node in self.children_plan_nodes), + ) + return self._cumulative_node_complexity + + @cumulative_node_complexity.setter + def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]): + self._cumulative_node_complexity = value + def __copy__(self) -> "SnowflakePlan": if self.session._cte_optimization_enabled: return SnowflakePlan( @@ -583,8 +607,8 @@ def large_local_relation_plan( source_plan=source_plan, ) - def table(self, table_name: str) -> SnowflakePlan: - return self.query(project_statement([], table_name), None) + def table(self, table_name: str, source_plan: LogicalPlan) -> SnowflakePlan: + return self.query(project_statement([], table_name), source_plan) def file_operation_plan( self, command: str, file_name: str, stage_location: str, options: Dict[str, str] diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index ee23f391af9..8dba1cf2b0b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -8,6 +8,10 @@ from typing import Any, Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import Attribute, Expression +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, + sum_node_complexities, +) from snowflake.snowpark.row import Row from snowflake.snowpark.types import StructType @@ -23,6 +27,34 @@ class LogicalPlan: def __init__(self) -> None: self.children = [] + self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None + + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.OTHERS + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + """Returns the individual contribution of the logical plan node towards the + overall compilation complexity of the generated sql. + """ + return {self.plan_node_category: 1} + + @property + def cumulative_node_complexity(self) -> Dict[PlanNodeCategory, int]: + """Returns the aggregate sum complexity statistic from the subtree rooted at this + logical plan node. The statistic of the current node is included in the final aggregate.
+ """ + if self._cumulative_node_complexity is None: + self._cumulative_node_complexity = sum_node_complexities( + self.individual_node_complexity, + *(node.cumulative_node_complexity for node in self.children), + ) + return self._cumulative_node_complexity + + @cumulative_node_complexity.setter + def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]): + self._cumulative_node_complexity = value class LeafNode(LogicalPlan): @@ -39,12 +71,28 @@ def __init__(self, start: int, end: int, step: int, num_slices: int = 1) -> None self.step = step self.num_slices = num_slices + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # SELECT ( ROW_NUMBER() OVER ( ORDER BY SEQ8() ) - 1 ) * (step) + (start) AS id FROM ( TABLE (GENERATOR(ROWCOUNT => count))) + return { + PlanNodeCategory.WINDOW: 1, + PlanNodeCategory.ORDER_BY: 1, + PlanNodeCategory.LITERAL: 3, # step, start, count + PlanNodeCategory.COLUMN: 1, # id column + PlanNodeCategory.FUNCTION: 3, # ROW_NUMBER, SEQ, GENERATOR + } + class UnresolvedRelation(LeafNode): def __init__(self, name: str) -> None: super().__init__() self.name = name + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # SELECT * FROM name + return {PlanNodeCategory.COLUMN: 1} + class SnowflakeValues(LeafNode): def __init__( @@ -58,6 +106,15 @@ def __init__( self.data = data self.schema_query = schema_query + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm) + # TODO: use ARRAY_BIND_THRESHOLD + return { + PlanNodeCategory.COLUMN: len(self.output), + PlanNodeCategory.LITERAL: len(self.data) * len(self.output), + } + class SaveMode(Enum): APPEND = "append" @@ -71,7 +128,7 @@ class SnowflakeCreateTable(LogicalPlan): def __init__( self, table_name: Iterable[str], - column_names: Optional[Iterable[str]], + column_names: Optional[List[str]], mode: SaveMode, query: Optional[LogicalPlan], table_type: str = "", @@ -88,6 +145,27 @@ def __init__( self.clustering_exprs = clustering_exprs or [] self.comment = comment + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # CREATE OR REPLACE table_type TABLE table_name (col definition) clustering_expr AS SELECT * FROM (query) + complexity = {PlanNodeCategory.COLUMN: 1} + complexity = ( + sum_node_complexities( + complexity, {PlanNodeCategory.COLUMN: len(self.column_names)} + ) + if self.column_names + else complexity + ) + complexity = ( + sum_node_complexities( + complexity, + *(expr.cumulative_node_complexity for expr in self.clustering_exprs), + ) + if self.clustering_exprs + else complexity + ) + return complexity + class Limit(LogicalPlan): def __init__( @@ -99,6 +177,15 @@ def __init__( self.child = child self.children.append(child) + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # for limit and offset + return sum_node_complexities( + {PlanNodeCategory.LOW_IMPACT: 2}, + self.limit_expr.cumulative_node_complexity, + self.offset_expr.cumulative_node_complexity, + ) + class CopyIntoTableNode(LeafNode): def __init__( diff --git a/src/snowflake/snowpark/_internal/analyzer/table_function.py b/src/snowflake/snowpark/_internal/analyzer/table_function.py index 2c6381ed345..1e7d342ce8b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_function.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_function.py @@ -6,6 +6,10 @@ from typing import Dict, List, Optional from 
snowflake.snowpark._internal.analyzer.expression import Expression +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, + sum_node_complexities, +) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -30,6 +34,31 @@ def __init__( self.partition_spec = partition_spec self.order_spec = order_spec + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + if not self.over: + return {} + complexity = {PlanNodeCategory.WINDOW: 1} + complexity = ( + sum_node_complexities( + complexity, + *(expr.cumulative_node_complexity for expr in self.partition_spec), + {PlanNodeCategory.PARTITION_BY: 1}, + ) + if self.partition_spec + else complexity + ) + complexity = ( + sum_node_complexities( + complexity, + *(expr.cumulative_node_complexity for expr in self.order_spec), + {PlanNodeCategory.ORDER_BY: 1}, + ) + if self.order_spec + else complexity + ) + return complexity + class TableFunctionExpression(Expression): def __init__( @@ -45,6 +74,10 @@ def __init__( self.aliases = aliases self.api_call_source = api_call_source + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.FUNCTION + class FlattenFunction(TableFunctionExpression): def __init__( @@ -57,6 +90,12 @@ def __init__( self.recursive = recursive self.mode = mode + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + return sum_node_complexities( + {self.plan_node_category: 1}, self.input.cumulative_node_complexity + ) + class PosArgumentsTableFunction(TableFunctionExpression): def __init__( @@ -68,6 +107,21 @@ def __init__( super().__init__(func_name, partition_spec) self.args = args + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = sum_node_complexities( + {self.plan_node_category: 1}, + *(arg.cumulative_node_complexity for arg in self.args), + ) + complexity = ( + sum_node_complexities( + complexity, self.partition_spec.cumulative_node_complexity + ) + if self.partition_spec + else complexity + ) + return complexity + class NamedArgumentsTableFunction(TableFunctionExpression): def __init__( @@ -79,6 +133,21 @@ def __init__( super().__init__(func_name, partition_spec) self.args = args + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = sum_node_complexities( + {self.plan_node_category: 1}, + *(arg.cumulative_node_complexity for arg in self.args.values()), + ) + complexity = ( + sum_node_complexities( + complexity, self.partition_spec.cumulative_node_complexity + ) + if self.partition_spec + else complexity + ) + return complexity + class GeneratorTableFunction(TableFunctionExpression): def __init__(self, args: Dict[str, Expression], operators: List[str]) -> None: @@ -86,12 +155,35 @@ def __init__(self, args: Dict[str, Expression], operators: List[str]) -> None: self.args = args self.operators = operators + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = sum_node_complexities( + {self.plan_node_category: 1}, + *(arg.cumulative_node_complexity for arg in self.args.values()), + ) + complexity = ( + sum_node_complexities( + complexity, self.partition_spec.cumulative_node_complexity + ) + if self.partition_spec + else complexity + ) + complexity = sum_node_complexities( + complexity, {PlanNodeCategory.COLUMN: len(self.operators)} + ) + return complexity + class 
TableFunctionRelation(LogicalPlan): def __init__(self, table_function: TableFunctionExpression) -> None: super().__init__() self.table_function = table_function + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # SELECT * FROM table_function + return self.table_function.cumulative_node_complexity + class TableFunctionJoin(LogicalPlan): def __init__( @@ -107,6 +199,17 @@ def __init__( self.left_cols = left_cols if left_cols is not None else ["*"] self.right_cols = right_cols if right_cols is not None else ["*"] + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # SELECT left_cols, right_cols FROM child as left_alias JOIN table(func(...)) as right_alias + return sum_node_complexities( + { + PlanNodeCategory.COLUMN: len(self.left_cols) + len(self.right_cols), + PlanNodeCategory.JOIN: 1, + }, + self.table_function.cumulative_node_complexity, + ) + class Lateral(LogicalPlan): def __init__( @@ -115,3 +218,11 @@ def __init__( super().__init__() self.children = [child] self.table_function = table_function + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # SELECT * FROM (child), LATERAL table_func_expression + return sum_node_complexities( + {PlanNodeCategory.COLUMN: 1}, + self.table_function.cumulative_node_complexity, + ) diff --git a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py index 2d2554e43a1..9b49ecb7898 100644 --- a/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/table_merge_expression.py @@ -5,6 +5,10 @@ from typing import Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import Expression +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, + sum_node_complexities, +) from snowflake.snowpark._internal.analyzer.snowflake_plan import ( LogicalPlan, SnowflakePlan, ) @@ -16,6 +20,21 @@ def __init__(self, condition: Optional[Expression]) -> None: super().__init__() self.condition = condition + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # WHEN MATCHED [AND condition] THEN DELETE + complexity = {self.plan_node_category: 1} + complexity = ( + sum_node_complexities(complexity, self.condition.cumulative_node_complexity) + if self.condition + else complexity + ) + return complexity + class UpdateMergeExpression(MergeExpression): def __init__( @@ -24,6 +43,26 @@ def __init__( super().__init__(condition) self.assignments = assignments + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # WHEN MATCHED [AND condition] THEN UPDATE SET COMMA.join(k=v for k,v in assignments) + complexity = sum_node_complexities( + {self.plan_node_category: 1}, + *( + sum_node_complexities( + key_expr.cumulative_node_complexity, + val_expr.cumulative_node_complexity, + ) + for key_expr, val_expr in self.assignments.items() + ), + ) + complexity = ( + sum_node_complexities(complexity, self.condition.cumulative_node_complexity) + if self.condition + else complexity + ) + return complexity + class DeleteMergeExpression(MergeExpression): pass @@ -40,6 +79,21 @@ def __init__( self.keys = keys self.values = values + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # WHEN NOT MATCHED [AND cond]
THEN INSERT [(COMMA.join(key))] VALUES (COMMA.join(values)) + complexity = sum_node_complexities( + {self.plan_node_category: 1}, + *(key.cumulative_node_complexity for key in self.keys), + *(val.cumulative_node_complexity for val in self.values), + ) + complexity = ( + sum_node_complexities(complexity, self.condition.cumulative_node_complexity) + if self.condition + else complexity + ) + return complexity + class TableUpdate(LogicalPlan): def __init__( @@ -56,6 +110,24 @@ def __init__( self.source_data = source_data self.children = [source_data] if source_data else [] + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # UPDATE table_name SET COMMA.join(k, v in assignments) [source_data] [WHERE condition] + complexity = sum_node_complexities( + *( + sum_node_complexities( + k.cumulative_node_complexity, v.cumulative_node_complexity + ) + for k, v in self.assignments.items() + ), + ) + complexity = ( + sum_node_complexities(complexity, self.condition.cumulative_node_complexity) + if self.condition + else complexity + ) + return complexity + class TableDelete(LogicalPlan): def __init__( @@ -70,6 +142,11 @@ def __init__( self.source_data = source_data self.children = [source_data] if source_data else [] + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # DELETE FROM table_name [USING source_data] [WHERE condition] + return self.condition.cumulative_node_complexity if self.condition else {} + class TableMerge(LogicalPlan): def __init__( @@ -85,3 +162,11 @@ def __init__( self.join_expr = join_expr self.clauses = clauses self.children = [source] if source else [] + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # MERGE INTO table_name USING (source) ON join_expr clauses + return sum_node_complexities( + self.join_expr.cumulative_node_complexity, + *(clause.cumulative_node_complexity for clause in self.clauses), + ) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index e8f5ebcd2c1..e5886e11069 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,13 +2,16 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# -from typing import AbstractSet, Optional +from typing import AbstractSet, Dict, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, derive_dependent_columns, ) +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, +) from snowflake.snowpark.types import DataType @@ -33,6 +36,10 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT + class Cast(UnaryExpression): sql_operator = "CAST" @@ -80,6 +87,11 @@ def __init__(self, child: Expression, name: str) -> None: def __str__(self): return f"{self.child} {self.sql_operator} {self.name}" + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # do not add additional complexity for alias + return {} + class UnresolvedAlias(UnaryExpression, NamedExpression): sql_operator = "AS" @@ -88,3 +100,8 @@ class UnresolvedAlias(UnaryExpression, NamedExpression): def __init__(self, child: Expression) -> None: super().__init__(child) self.name = child.sql + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # this is a wrapper around child + return {} diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index fddf3caa8b3..46ff69498bb 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -4,7 +4,15 @@ from typing import Dict, List, Optional, Union -from snowflake.snowpark._internal.analyzer.expression import Expression, NamedExpression +from snowflake.snowpark._internal.analyzer.expression import ( + Expression, + NamedExpression, + ScalarSubquery, +) +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, + sum_node_complexities, +) from snowflake.snowpark._internal.analyzer.snowflake_plan import LogicalPlan from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -29,12 +37,30 @@ def __init__( self.row_count = row_count self.seed = seed + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # SELECT * FROM (child) SAMPLE (probability) -- if probability is provided + # SELECT * FROM (child) SAMPLE (row_count ROWS) -- if not probability but row count is provided + return { + PlanNodeCategory.SAMPLE: 1, + PlanNodeCategory.LITERAL: 1, + PlanNodeCategory.COLUMN: 1, + } + class Sort(UnaryNode): def __init__(self, order: List[SortOrder], child: LogicalPlan) -> None: super().__init__(child) self.order = order + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # child ORDER BY COMMA.join(order) + return sum_node_complexities( + {PlanNodeCategory.ORDER_BY: 1}, + *(col.cumulative_node_complexity for col in self.order), + ) + class Aggregate(UnaryNode): def __init__( @@ -47,13 +73,41 @@ def __init__( self.grouping_expressions = grouping_expressions self.aggregate_expressions = aggregate_expressions + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + if self.grouping_expressions: + # GROUP BY grouping_exprs + complexity = sum_node_complexities( + {PlanNodeCategory.GROUP_BY: 1}, + *( + expr.cumulative_node_complexity + for expr in self.grouping_expressions + ), + ) + else: + # LIMIT 1 + complexity = 
{PlanNodeCategory.LOW_IMPACT: 1} + + complexity = sum_node_complexities( + complexity, + *( + getattr( + expr, + "cumulative_node_complexity", + {PlanNodeCategory.COLUMN: 1}, + ) # type: ignore + for expr in self.aggregate_expressions + ), + ) + return complexity + class Pivot(UnaryNode): def __init__( self, grouping_columns: List[Expression], pivot_column: Expression, - pivot_values: Optional[Union[List[Expression], LogicalPlan]], + pivot_values: Optional[Union[List[Expression], ScalarSubquery]], aggregates: List[Expression], default_on_null: Optional[Expression], child: LogicalPlan, @@ -65,6 +119,46 @@ def __init__( self.aggregates = aggregates self.default_on_null = default_on_null + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + complexity = {} + # child complexity adjustment if grouping cols + if self.grouping_columns and self.aggregates and self.aggregates[0].children: + # for additional projecting cols when grouping cols is not empty + complexity = sum_node_complexities( + self.pivot_column.cumulative_node_complexity, + self.aggregates[0].children[0].cumulative_node_complexity, + *(col.cumulative_node_complexity for col in self.grouping_columns), + ) + + # pivot values component + if isinstance(self.pivot_values, ScalarSubquery): + complexity = sum_node_complexities( + complexity, self.pivot_values.cumulative_node_complexity + ) + elif isinstance(self.pivot_values, List): + complexity = sum_node_complexities( + complexity, + *(val.cumulative_node_complexity for val in self.pivot_values), + ) + else: + # if pivot values is None, then we add LOW_IMPACT for ANY + complexity = sum_node_complexities( + complexity, {PlanNodeCategory.LOW_IMPACT: 1} + ) + + # aggregate complexity + complexity = sum_node_complexities( + complexity, + *(expr.cumulative_node_complexity for expr in self.aggregates), + ) + + # SELECT * FROM (child) PIVOT (aggregate FOR pivot_col in values) + complexity = sum_node_complexities( + complexity, {PlanNodeCategory.COLUMN: 2, PlanNodeCategory.PIVOT: 1} + ) + return complexity + class Unpivot(UnaryNode): def __init__( @@ -79,6 +173,14 @@ def __init__( self.name_column = name_column self.column_list = column_list + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # SELECT * FROM (child) UNPIVOT (value_column FOR name_column IN (COMMA.join(column_list))) + return sum_node_complexities( + {PlanNodeCategory.UNPIVOT: 1, PlanNodeCategory.COLUMN: 3}, + *(expr.cumulative_node_complexity for expr in self.column_list), + ) + class Rename(UnaryNode): def __init__( @@ -89,18 +191,50 @@ def __init__( super().__init__(child) self.column_map = column_map + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # SELECT * RENAME (before AS after, ...)
FROM child + return { + PlanNodeCategory.COLUMN: 1 + len(self.column_map), + PlanNodeCategory.LOW_IMPACT: 1 + len(self.column_map), + } + class Filter(UnaryNode): def __init__(self, condition: Expression, child: LogicalPlan) -> None: super().__init__(child) self.condition = condition + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # child WHERE condition + return sum_node_complexities( + {PlanNodeCategory.FILTER: 1}, + self.condition.cumulative_node_complexity, + ) + class Project(UnaryNode): def __init__(self, project_list: List[NamedExpression], child: LogicalPlan) -> None: super().__init__(child) self.project_list = project_list + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + if not self.project_list: + return {PlanNodeCategory.COLUMN: 1} + + return sum_node_complexities( + *( + getattr( + col, + "cumulative_node_complexity", + {PlanNodeCategory.COLUMN: 1}, + ) # type: ignore + for col in self.project_list + ), + ) + class ViewType: def __str__(self): diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 452919f3313..69db3f265ce 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -2,12 +2,16 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, List, Optional +from typing import AbstractSet, Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, ) +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, + sum_node_complexities, +) from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder @@ -17,6 +21,10 @@ class SpecialFrameBoundary(Expression): def __init__(self) -> None: super().__init__() + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT + class UnboundedPreceding(SpecialFrameBoundary): sql = "UNBOUNDED PRECEDING" @@ -63,6 +71,19 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.lower, self.upper) + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.LOW_IMPACT + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # frame_type BETWEEN lower AND upper + return sum_node_complexities( + {self.plan_node_category: 1}, + self.lower.cumulative_node_complexity, + self.upper.cumulative_node_complexity, + ) + class WindowSpecDefinition(Expression): def __init__( @@ -81,6 +102,30 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: *self.partition_spec, *self.order_spec, self.frame_spec ) + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # partition_spec order_by_spec frame_spec + complexity = self.frame_spec.cumulative_node_complexity + complexity = ( + sum_node_complexities( + complexity, + {PlanNodeCategory.PARTITION_BY: 1}, + *(expr.cumulative_node_complexity for expr in self.partition_spec), + ) + if self.partition_spec + else complexity + ) + complexity = ( + sum_node_complexities( + complexity, + {PlanNodeCategory.ORDER_BY: 1}, + *(expr.cumulative_node_complexity for expr in self.order_spec), + ) + if self.order_spec + else complexity + ) + return complexity + class WindowExpression(Expression): def __init__( @@ -93,6 
+138,19 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.window_function, self.window_spec) + @property + def plan_node_category(self) -> PlanNodeCategory: + return PlanNodeCategory.WINDOW + + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # window_function OVER ( window_spec ) + return sum_node_complexities( + {self.plan_node_category: 1}, + self.window_function.cumulative_node_complexity, + self.window_spec.cumulative_node_complexity, + ) + class RankRelatedFunctionExpression(Expression): sql: str @@ -113,6 +171,34 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) + @property + def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + # for func_name + complexity = {PlanNodeCategory.FUNCTION: 1} + # for offset + complexity = ( + sum_node_complexities(complexity, {PlanNodeCategory.LITERAL: 1}) + if self.offset + else complexity + ) + + # for ignore nulls + complexity = ( + sum_node_complexities(complexity, {PlanNodeCategory.LOW_IMPACT: 1}) + if self.ignore_nulls + else complexity + ) + # func_name (expr [, offset] [, default]) [IGNORE NULLS] + complexity = sum_node_complexities( + complexity, self.expr.cumulative_node_complexity + ) + complexity = ( + sum_node_complexities(complexity, self.default.cumulative_node_complexity) + if self.default + else complexity + ) + return complexity + class Lag(RankRelatedFunctionExpression): sql = "LAG" diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 370a1340af1..2ce6036da62 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -68,6 +68,7 @@ class TelemetryField(Enum): # dataframe query stats QUERY_PLAN_HEIGHT = "query_plan_height" QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes" + QUERY_PLAN_COMPLEXITY = "query_plan_complexity" # These DataFrame APIs call other DataFrame APIs @@ -155,6 +156,13 @@ def wrap(*args, **kwargs): api_calls[0][TelemetryField.SQL_SIMPLIFIER_ENABLED.value] = args[ 0 ]._session.sql_simplifier_enabled + try: + api_calls[0][TelemetryField.QUERY_PLAN_COMPLEXITY.value] = { + key.value: value + for key, value in plan.cumulative_node_complexity.items() + } + except Exception: + pass args[0]._session._conn._telemetry_client.send_function_usage_telemetry( f"action_{func.__name__}", TelemetryField.FUNC_CAT_ACTION.value, diff --git a/tests/integ/scala/test_snowflake_plan_suite.py b/tests/integ/scala/test_snowflake_plan_suite.py index e9d9e845fb3..35610cb66cb 100644 --- a/tests/integ/scala/test_snowflake_plan_suite.py +++ b/tests/integ/scala/test_snowflake_plan_suite.py @@ -53,7 +53,7 @@ def test_single_query(session): # build plan plans = session._plan_builder - table_plan = plans.table(table_name) + table_plan = plans.table(table_name, df._plan.source_plan) project = plans.project(["num"], table_plan, None) assert len(project.queries) == 1 diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py new file mode 100644 index 00000000000..0e8bb0d902d --- /dev/null +++ b/tests/integ/test_query_plan_analysis.py @@ -0,0 +1,543 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
+# + + +from typing import Dict + +import pytest + +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, + sum_node_complexities, +) +from snowflake.snowpark._internal.analyzer.select_statement import ( + SET_EXCEPT, + SET_INTERSECT, + SET_UNION, + SET_UNION_ALL, +) +from snowflake.snowpark.dataframe import DataFrame +from snowflake.snowpark.functions import avg, col, lit, seq1, table_function, uniform +from snowflake.snowpark.session import Session +from snowflake.snowpark.window import Window +from tests.utils import Utils + +pytestmark = [ + pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="Breaking down queries is done for SQL translation", + run=False, + ) +] + + +@pytest.fixture(autouse=True) +def setup(session): + is_simplifier_enabled = session._sql_simplifier_enabled + session._sql_simplifier_enabled = True + yield + session._sql_simplifier_enabled = is_simplifier_enabled + + +@pytest.fixture(scope="module") +def sample_table(session): + table_name = Utils.random_table_name() + Utils.create_table( + session, table_name, "a int, b int, c int, d int", is_temporary=True + ) + session._run_query( + f"insert into {table_name}(a, b, c, d) values " "(1, 2, 3, 4), (5, 6, 7, 8)" + ) + yield table_name + Utils.drop_table(session, table_name) + + +def get_cumulative_node_complexity(df: DataFrame) -> Dict[str, int]: + return df._plan.cumulative_node_complexity + + +def assert_df_subtree_query_complexity(df: DataFrame, estimate: Dict[str, int]): + assert ( + get_cumulative_node_complexity(df) == estimate + ), f"query = {df.queries['queries'][-1]}" + + +def test_create_dataframe_from_values(session: Session): + df1 = session.create_dataframe([[1], [2], [3]], schema=["a"]) + # SELECT "A" FROM ( SELECT $1 AS "A" FROM VALUES (1 :: INT), (2 :: INT), (3 :: INT)) + assert_df_subtree_query_complexity( + df1, {PlanNodeCategory.LITERAL: 3, PlanNodeCategory.COLUMN: 2} + ) + + df2 = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", "b"]) + # SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, 2 :: INT), (3 :: INT, 4 :: INT), (5 :: INT, 6 :: INT)) + assert_df_subtree_query_complexity( + df2, {PlanNodeCategory.LITERAL: 6, PlanNodeCategory.COLUMN: 4} + ) + + +def test_session_table(session: Session, sample_table: str): + # select * from sample_table + df = session.table(sample_table) + assert_df_subtree_query_complexity(df, {PlanNodeCategory.COLUMN: 1}) + + +def test_range_statement(session: Session): + df = session.range(1, 5, 2) + # SELECT ( ROW_NUMBER() OVER ( ORDER BY SEQ8() ) - 1 ) * (2) + (1) AS id FROM ( TABLE (GENERATOR(ROWCOUNT => 2))) + assert_df_subtree_query_complexity( + df, + { + PlanNodeCategory.COLUMN: 1, + PlanNodeCategory.LITERAL: 3, + PlanNodeCategory.FUNCTION: 3, + PlanNodeCategory.ORDER_BY: 1, + PlanNodeCategory.WINDOW: 1, + }, + ) + + +def test_generator_table_function(session: Session): + df1 = session.generator( + seq1(1).as_("seq"), uniform(1, 10, 2).as_("uniform"), rowcount=150 + ) + assert_df_subtree_query_complexity( + df1, + { + PlanNodeCategory.COLUMN: 2, + PlanNodeCategory.FUNCTION: 1, + PlanNodeCategory.LITERAL: 1, + }, + ) + + df2 = df1.order_by("seq") + # adds SELECT * from () ORDER BY seq ASC NULLS FIRST + assert_df_subtree_query_complexity( + df2, + sum_node_complexities( + get_cumulative_node_complexity(df1), + { + PlanNodeCategory.ORDER_BY: 1, + PlanNodeCategory.COLUMN: 1, + PlanNodeCategory.OTHERS: 1, + }, + ), + ) + + +def 
test_join_table_function(session: Session): + df1 = session.sql( + "select 'James' as name, 'address1 address2 address3' as addresses" + ) + # SelectSQL chooses num active columns as the best estimate + assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN: 1}) + + split_to_table = table_function("split_to_table") + + # SELECT "SEQ", "INDEX", "VALUE" FROM ( + # SELECT T_RIGHT."SEQ", T_RIGHT."INDEX", T_RIGHT."VALUE" FROM + # (select 'James' as name, 'address1 address2 address3' as addresses) AS T_LEFT + # JOIN TABLE (split_to_table("ADDRESSES", ' ') ) AS T_RIGHT) + df2 = df1.select(split_to_table(col("addresses"), lit(" "))) + assert_df_subtree_query_complexity( + df2, + { + PlanNodeCategory.COLUMN: 8, + PlanNodeCategory.JOIN: 1, + PlanNodeCategory.FUNCTION: 1, + PlanNodeCategory.LITERAL: 1, + }, + ) + + # SELECT T_LEFT.*, T_RIGHT.* FROM (select 'James' as name, 'address1 address2 address3' as addresses) AS T_LEFT + # JOIN TABLE (split_to_table("ADDRESS", ' ') OVER (PARTITION BY "LAST_NAME" ORDER BY "FIRST_NAME" ASC NULLS FIRST)) AS T_RIGHT + df3 = df1.join_table_function( + split_to_table(col("address"), lit(" ")).over( + partition_by="last_name", order_by="first_name" + ) + ) + assert_df_subtree_query_complexity( + df3, + { + PlanNodeCategory.COLUMN: 6, + PlanNodeCategory.JOIN: 1, + PlanNodeCategory.FUNCTION: 1, + PlanNodeCategory.LITERAL: 1, + PlanNodeCategory.PARTITION_BY: 1, + PlanNodeCategory.ORDER_BY: 1, + PlanNodeCategory.WINDOW: 1, + PlanNodeCategory.OTHERS: 1, + }, + ) + + +@pytest.mark.parametrize( + "set_operator", [SET_UNION, SET_UNION_ALL, SET_EXCEPT, SET_INTERSECT] +) +def test_set_operators(session: Session, sample_table: str, set_operator: str): + df1 = session.table(sample_table) + df2 = session.table(sample_table) + if set_operator == SET_UNION: + df = df1.union(df2) + elif set_operator == SET_UNION_ALL: + df = df1.union_all(df2) + elif set_operator == SET_EXCEPT: + df = df1.except_(df2) + else: + df = df1.intersect(df2) + + # ( SELECT * FROM SNOWPARK_TEMP_TABLE_9DJO2Y35IT) set_operator ( SELECT * FROM SNOWPARK_TEMP_TABLE_9DJO2Y35IT) + assert_df_subtree_query_complexity( + df, {PlanNodeCategory.COLUMN: 2, PlanNodeCategory.SET_OPERATION: 1} + ) + + +def test_agg(session: Session, sample_table: str): + df = session.table(sample_table) + df1 = df.agg(avg("a")) + df2 = df.agg(avg("a") + 1) + df3 = df.agg(avg("a"), avg(col("b") + lit(1)).as_("avg_b")) + df4 = df.group_by(["a", "b"]).agg(avg("c")) + + # SELECT avg("A") AS "AVG(A)" FROM ( SELECT * FROM sample_table) LIMIT 1 + assert_df_subtree_query_complexity( + df1, + { + PlanNodeCategory.COLUMN: 2, + PlanNodeCategory.LOW_IMPACT: 1, + PlanNodeCategory.FUNCTION: 1, + }, + ) + # SELECT (avg("A") + 1 :: INT) AS "ADD(AVG(A), LITERAL())" FROM ( SELECT * FROM sample_table) LIMIT 1 + assert_df_subtree_query_complexity( + df2, + { + PlanNodeCategory.COLUMN: 2, + PlanNodeCategory.LOW_IMPACT: 2, + PlanNodeCategory.FUNCTION: 1, + PlanNodeCategory.LITERAL: 1, + }, + ) + # SELECT avg("A") AS "AVG(A)", avg(("B" + 1 :: INT)) AS "AVG_B" FROM ( SELECT * FROM sample_table) LIMIT 1 + assert_df_subtree_query_complexity( + df3, + { + PlanNodeCategory.COLUMN: 3, + PlanNodeCategory.LOW_IMPACT: 2, + PlanNodeCategory.FUNCTION: 2, + PlanNodeCategory.LITERAL: 1, + }, + ) + # SELECT "A", "B", avg("C") AS "AVG(C)" FROM ( SELECT * FROM SNOWPARK_TEMP_TABLE_EV1NO4AID6) GROUP BY "A", "B" + assert_df_subtree_query_complexity( + df4, + { + PlanNodeCategory.COLUMN: 6, + PlanNodeCategory.GROUP_BY: 1, + PlanNodeCategory.FUNCTION: 1, + }, + ) + + +def 
test_window_function(session: Session): + window1 = ( + Window.partition_by("value").order_by("key").rows_between(Window.CURRENT_ROW, 2) + ) + window2 = Window.order_by(col("key").desc()).range_between( + Window.UNBOUNDED_PRECEDING, Window.UNBOUNDED_FOLLOWING + ) + df = session.create_dataframe( + [(1, "1"), (2, "2"), (1, "3"), (2, "4")], schema=["key", "value"] + ) + table_name = Utils.random_table_name() + try: + df.write.save_as_table(table_name, table_type="temp", mode="overwrite") + + df1 = session.table(table_name).select( + avg("value").over(window1).as_("window1") + ) + # SELECT avg("VALUE") OVER (PARTITION BY "VALUE" ORDER BY "KEY" ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING ) AS "WINDOW1" FROM table_name + assert_df_subtree_query_complexity( + df1, + { + PlanNodeCategory.PARTITION_BY: 1, + PlanNodeCategory.ORDER_BY: 1, + PlanNodeCategory.WINDOW: 1, + PlanNodeCategory.FUNCTION: 1, + PlanNodeCategory.COLUMN: 4, + PlanNodeCategory.LITERAL: 1, + PlanNodeCategory.LOW_IMPACT: 2, + PlanNodeCategory.OTHERS: 1, + }, + ) + + # SELECT avg("VALUE") OVER ( ORDER BY "KEY" DESC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS "WINDOW2" FROM ( + # SELECT avg("VALUE") OVER (PARTITION BY "VALUE" ORDER BY "KEY" ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING ) AS "WINDOW1" FROM table_name) + df2 = df1.select(avg("value").over(window2).as_("window2")) + assert_df_subtree_query_complexity( + df2, + sum_node_complexities( + get_cumulative_node_complexity(df1), + { + PlanNodeCategory.ORDER_BY: 1, + PlanNodeCategory.WINDOW: 1, + PlanNodeCategory.FUNCTION: 1, + PlanNodeCategory.COLUMN: 2, + PlanNodeCategory.LOW_IMPACT: 3, + PlanNodeCategory.OTHERS: 1, + }, + ), + ) + finally: + Utils.drop_table(session, table_name) + + +def test_join_statement(session: Session, sample_table: str): + # SELECT * FROM table + df1 = session.table(sample_table) + assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN: 1}) + # SELECT A, B, E FROM (SELECT $1 AS "A", $2 AS "B", $3 AS "E" FROM VALUES (1 :: INT, 2 :: INT, 5 :: INT), (3 :: INT, 4 :: INT, 9 :: INT)) + df2 = session.create_dataframe([[1, 2, 5], [3, 4, 9]], schema=["a", "b", "e"]) + assert_df_subtree_query_complexity( + df2, {PlanNodeCategory.COLUMN: 6, PlanNodeCategory.LITERAL: 6} + ) + + df3 = df1.join(df2) + # SELECT * FROM (( SELECT "A" AS "l_fkl0_A", "B" AS "l_fkl0_B", "C" AS "C", "D" AS "D" FROM sample_table) AS SNOWPARK_LEFT + # INNER JOIN ( + # SELECT "A" AS "r_co85_A", "B" AS "r_co85_B", "E" AS "E" FROM ( + # SELECT $1 AS "A", $2 AS "B", $3 AS "E" FROM VALUES (1 :: INT, 2 :: INT, 5 :: INT), (3 :: INT, 4 :: INT, 9 :: INT))) AS SNOWPARK_RIGHT) + assert_df_subtree_query_complexity( + df3, + { + PlanNodeCategory.COLUMN: 11, + PlanNodeCategory.LITERAL: 6, + PlanNodeCategory.JOIN: 1, + }, + ) + + df4 = df1.join(df2, on=((df1["a"] == df2["a"]) & (df1["b"] == df2["b"]))) + # SELECT * FROM ((ch1) AS SNOWPARK_LEFT INNER JOIN ( ch2) AS SNOWPARK_RIGHT ON (("l_k7b8_A" = "r_e09m_A") AND ("l_k7b8_B" = "r_e09m_B"))) + assert_df_subtree_query_complexity( + df4, + sum_node_complexities( + get_cumulative_node_complexity(df3), + {PlanNodeCategory.COLUMN: 4, PlanNodeCategory.LOW_IMPACT: 3}, + ), + ) + + df5 = df1.join(df2, using_columns=["a", "b"]) + # SELECT * FROM ( (ch1) AS SNOWPARK_LEFT INNER JOIN (ch2) AS SNOWPARK_RIGHT USING (a, b)) + assert_df_subtree_query_complexity( + df5, + sum_node_complexities( + get_cumulative_node_complexity(df3), {PlanNodeCategory.COLUMN: 2} + ), + ) + + +def test_pivot(session: Session): + 
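+    # The two pivots below differ only in the length of the PIVOT ... IN (...) list:
+    # per the assertions that follow, each additional pivot value contributes exactly
+    # one more PlanNodeCategory.LITERAL while every other category count stays fixed.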
+    try:
+        session.sql(
+            """create or replace temp table monthly_sales(empid int, amount int, month text)
+            as select * from values
+            (1, 10000, 'JAN'),
+            (1, 400, 'JAN'),
+            (2, 4500, 'JAN'),
+            (2, 35000, 'JAN'),
+            (1, 5000, 'FEB'),
+            (1, 3000, 'FEB'),
+            (2, 200, 'FEB')"""
+        ).collect()
+
+        df_pivot1 = (
+            session.table("monthly_sales").pivot("month", ["JAN", "FEB"]).sum("amount")
+        )
+        # SELECT * FROM ( SELECT * FROM monthly_sales) PIVOT (sum("AMOUNT") FOR "MONTH" IN ('JAN', 'FEB'))
+        assert_df_subtree_query_complexity(
+            df_pivot1,
+            {
+                PlanNodeCategory.PIVOT: 1,
+                PlanNodeCategory.COLUMN: 4,
+                PlanNodeCategory.LITERAL: 2,
+                PlanNodeCategory.FUNCTION: 1,
+            },
+        )
+
+        df_pivot2 = (
+            session.table("monthly_sales")
+            .pivot("month", ["JAN", "FEB", "MARCH"])
+            .sum("amount")
+        )
+        # SELECT * FROM ( SELECT * FROM monthly_sales) PIVOT (sum("AMOUNT") FOR "MONTH" IN ('JAN', 'FEB', 'MARCH'))
+        assert_df_subtree_query_complexity(
+            df_pivot2,
+            {
+                PlanNodeCategory.PIVOT: 1,
+                PlanNodeCategory.COLUMN: 4,
+                PlanNodeCategory.LITERAL: 3,
+                PlanNodeCategory.FUNCTION: 1,
+            },
+        )
+    finally:
+        Utils.drop_table(session, "monthly_sales")
+
+
+def test_unpivot(session: Session):
+    try:
+        session.sql(
+            """create or replace temp table sales_for_month(empid int, dept varchar, jan int, feb int)
+            as select * from values
+            (1, 'electronics', 100, 200),
+            (2, 'clothes', 100, 300)"""
+        ).collect()
+
+        df_unpivot1 = session.table("sales_for_month").unpivot(
+            "sales", "month", ["jan", "feb"]
+        )
+        # SELECT * FROM ( SELECT * FROM (sales_for_month)) UNPIVOT (sales FOR month IN ("JAN", "FEB"))
+        assert_df_subtree_query_complexity(
+            df_unpivot1,
+            {PlanNodeCategory.UNPIVOT: 1, PlanNodeCategory.COLUMN: 6},
+        )
+    finally:
+        Utils.drop_table(session, "sales_for_month")
+
+
+def test_sample(session: Session, sample_table):
+    df = session.table(sample_table)
+    df_sample_frac = df.sample(0.5)
+    # SELECT * FROM ( SELECT * FROM (sample_table)) SAMPLE (50.0)
+    assert_df_subtree_query_complexity(
+        df_sample_frac,
+        {
+            PlanNodeCategory.SAMPLE: 1,
+            PlanNodeCategory.LITERAL: 1,
+            PlanNodeCategory.COLUMN: 2,
+        },
+    )
+
+    df_sample_rows = df.sample(n=1)
+    # SELECT * FROM ( SELECT * FROM (sample_table)) SAMPLE (1 ROWS)
+    assert_df_subtree_query_complexity(
+        df_sample_rows,
+        {
+            PlanNodeCategory.SAMPLE: 1,
+            PlanNodeCategory.LITERAL: 1,
+            PlanNodeCategory.COLUMN: 2,
+        },
+    )
+
+
+def test_select_statement_with_multiple_operations(session: Session, sample_table: str):
+    df = session.table(sample_table)
+
+    # add select
+    # SELECT "A", "B", "C", "D" FROM sample_table
+    # note that the column count is 5 even though only 4 columns are selected, because
+    # we count one extra column for the SELECT * FROM sample_table that is flattened
+    # out. This is a known limitation, but acceptable since the estimate is not off by much.
+    df1 = df.select(df["*"])
+    assert_df_subtree_query_complexity(df1, {PlanNodeCategory.COLUMN: 5})
+
+    # SELECT "A", "B", "C" FROM sample_table
+    df2 = df1.select("a", "b", "c")
+    assert_df_subtree_query_complexity(df2, {PlanNodeCategory.COLUMN: 4})
+
+    # one less active column
+    df3 = df2.select("b", "c")
+    assert_df_subtree_query_complexity(df3, {PlanNodeCategory.COLUMN: 3})
+
+    # add sort
+    # for the additional ORDER BY "B" ASC NULLS FIRST
+    df4 = df3.sort(col("b").asc())
+    assert_df_subtree_query_complexity(
+        df4,
+        sum_node_complexities(
+            get_cumulative_node_complexity(df3),
+            {
+                PlanNodeCategory.COLUMN: 1,
+                PlanNodeCategory.ORDER_BY: 1,
+                PlanNodeCategory.OTHERS: 1,
+            },
+        ),
+    )
+
+    # for the additional , "C" DESC NULLS LAST
+    df5 = df4.sort(col("c").desc())
+    assert_df_subtree_query_complexity(
+        df5,
+        sum_node_complexities(
+            get_cumulative_node_complexity(df4),
+            {PlanNodeCategory.COLUMN: 1, PlanNodeCategory.OTHERS: 1},
+        ),
+    )
+
+    # add filter
+    # for WHERE ("B" > 2)
+    df6 = df5.filter(col("b") > 2)
+    assert_df_subtree_query_complexity(
+        df6,
+        sum_node_complexities(
+            get_cumulative_node_complexity(df5),
+            {
+                PlanNodeCategory.FILTER: 1,
+                PlanNodeCategory.COLUMN: 1,
+                PlanNodeCategory.LITERAL: 1,
+                PlanNodeCategory.LOW_IMPACT: 1,
+            },
+        ),
+    )
+
+    # for the additional filter condition - AND ("C" > 3)
+    df7 = df6.filter(col("c") > 3)
+    assert_df_subtree_query_complexity(
+        df7,
+        sum_node_complexities(
+            get_cumulative_node_complexity(df6),
+            {
+                PlanNodeCategory.COLUMN: 1,
+                PlanNodeCategory.LITERAL: 1,
+                PlanNodeCategory.LOW_IMPACT: 2,
+            },
+        ),
+    )
+
+    # add set operations
+    df8 = df3.union_all(df4).union_all(df5)
+    assert_df_subtree_query_complexity(
+        df8,
+        sum_node_complexities(
+            *(get_cumulative_node_complexity(df) for df in [df3, df4, df5]),
+            {PlanNodeCategory.SET_OPERATION: 2},
+        ),
+    )
+
+    # +2 SET_OPERATION for the 2 unions, plus the sum of the individual df complexities
+    df9 = df8.union_all(df6).union_all(df7)
+    assert_df_subtree_query_complexity(
+        df9,
+        sum_node_complexities(
+            *(get_cumulative_node_complexity(df) for df in [df6, df7, df8]),
+            {PlanNodeCategory.SET_OPERATION: 2},
+        ),
+    )
+
+    # for limit
+    df10 = df9.limit(2)
+    assert_df_subtree_query_complexity(
+        df10,
+        sum_node_complexities(
+            get_cumulative_node_complexity(df9), {PlanNodeCategory.LOW_IMPACT: 1}
+        ),
+    )
+
+    # for limit + offset
+    df11 = df9.limit(3, offset=1)
+    assert_df_subtree_query_complexity(
+        df11,
+        sum_node_complexities(
+            get_cumulative_node_complexity(df9), {PlanNodeCategory.LOW_IMPACT: 2}
+        ),
+    )
diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py
index 4ba51c3b8d3..f5ba026a813 100644
--- a/tests/integ/test_telemetry.py
+++ b/tests/integ/test_telemetry.py
@@ -590,6 +590,15 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled):
         {
             "name": "Session.range",
             "sql_simplifier_enabled": session.sql_simplifier_enabled,
+            "query_plan_complexity": {
+                "filter": 1,
+                "low_impact": 3,
+                "function": 3,
+                "column": 3,
+                "literal": 5,
+                "window": 1,
+                "order_by": 1,
+            },
         },
         {"name": "DataFrame.filter"},
         {"name": "DataFrame.filter"},
@@ -601,6 +610,15 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled):
         {
             "name": "Session.range",
             "sql_simplifier_enabled": session.sql_simplifier_enabled,
+            "query_plan_complexity": {
+                "filter": 1,
+                "low_impact": 3,
+                "function": 3,
+                "column": 3,
+                "literal": 5,
+                "window": 1,
+                "order_by": 1,
+            },
         },
         {"name": "DataFrame.filter"},
         {"name": "DataFrame.filter"},
@@ -612,6 +630,15 @@ def
test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "query_plan_complexity": { + "filter": 1, + "low_impact": 3, + "function": 3, + "column": 3, + "literal": 5, + "window": 1, + "order_by": 1, + }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -623,6 +650,15 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "query_plan_complexity": { + "filter": 1, + "low_impact": 3, + "function": 3, + "column": 3, + "literal": 5, + "window": 1, + "order_by": 1, + }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -634,6 +670,15 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "query_plan_complexity": { + "filter": 1, + "low_impact": 3, + "function": 3, + "column": 3, + "literal": 5, + "window": 1, + "order_by": 1, + }, }, {"name": "DataFrame.filter"}, {"name": "DataFrame.filter"}, @@ -768,6 +813,7 @@ def test_dataframe_stat_functions_api_calls(session): { "name": "Session.create_dataframe[values]", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "query_plan_complexity": {"group_by": 1, "column": 6, "literal": 48}, }, { "name": "DataFrameStatFunctions.crosstab", @@ -783,6 +829,7 @@ def test_dataframe_stat_functions_api_calls(session): { "name": "Session.create_dataframe[values]", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "query_plan_complexity": {"group_by": 1, "column": 6, "literal": 48}, } ] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 6419b77d42f..aab964bad09 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -8,7 +8,10 @@ from snowflake.connector import SnowflakeConnection from snowflake.connector.cursor import SnowflakeCursor +from snowflake.snowpark._internal.analyzer.analyzer import Analyzer +from snowflake.snowpark._internal.analyzer.snowflake_plan import Query from snowflake.snowpark._internal.server_connection import ServerConnection +from snowflake.snowpark.session import Session @pytest.fixture @@ -21,3 +24,25 @@ def mock_server_connection() -> ServerConnection: ) fake_snowflake_connection.is_closed.return_value = False return ServerConnection({}, fake_snowflake_connection) + + +@pytest.fixture(scope="module") +def mock_analyzer() -> Analyzer: + fake_analyzer = mock.create_autospec(Analyzer) + return fake_analyzer + + +@pytest.fixture(scope="module") +def mock_session(mock_analyzer) -> Session: + fake_session = mock.create_autospec(Session) + fake_session._analyzer = mock_analyzer + mock_analyzer.session = fake_session + return fake_session + + +@pytest.fixture(scope="module") +def mock_query(): + fake_query = mock.create_autospec(Query) + fake_query.sql = "dummy sql" + fake_query.params = "dummy params" + return fake_query diff --git a/tests/unit/test_query_plan_analysis.py b/tests/unit/test_query_plan_analysis.py new file mode 100644 index 00000000000..adac50184e0 --- /dev/null +++ b/tests/unit/test_query_plan_analysis.py @@ -0,0 +1,183 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
+#
+
+from unittest import mock
+
+import pytest
+
+from snowflake.snowpark._internal.analyzer.analyzer_utils import (
+    EXCEPT,
+    INTERSECT,
+    UNION,
+    UNION_ALL,
+)
+from snowflake.snowpark._internal.analyzer.expression import Expression, NamedExpression
+from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
+    PlanNodeCategory,
+)
+from snowflake.snowpark._internal.analyzer.select_statement import (
+    Selectable,
+    SelectableEntity,
+    SelectSnowflakePlan,
+    SelectSQL,
+    SelectStatement,
+    SelectTableFunction,
+    SetOperand,
+    SetStatement,
+)
+from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan
+from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan
+from snowflake.snowpark._internal.analyzer.table_function import TableFunctionExpression
+from snowflake.snowpark._internal.analyzer.unary_plan_node import Project
+
+
+@pytest.mark.parametrize("node_type", [LogicalPlan, SnowflakePlan, Selectable])
+def test_assign_custom_cumulative_node_complexity(
+    mock_session, mock_analyzer, mock_query, node_type
+):
+    def get_node_for_type(node_type):
+        if node_type == LogicalPlan:
+            return LogicalPlan()
+        if node_type == SnowflakePlan:
+            return SnowflakePlan(
+                [mock_query], "", source_plan=LogicalPlan(), session=mock_session
+            )
+        return SelectSnowflakePlan(
+            SnowflakePlan(
+                [mock_query], "", source_plan=LogicalPlan(), session=mock_session
+            ),
+            analyzer=mock_analyzer,
+        )
+
+    def set_children(node, node_type, children):
+        if node_type == LogicalPlan:
+            node.children = children
+        elif node_type == SnowflakePlan:
+            node.source_plan.children = children
+        else:
+            node.snowflake_plan.source_plan.children = children
+
+    nodes = [get_node_for_type(node_type) for _ in range(7)]
+
+    # The tree wired up below; overriding nodes[1] (marked x on the right) must
+    # not disturb the cached cumulative values of any other node.
+    """
+            o                   o
+           / \\                / \\
+          o   o      ->      x   o
+         /|\\
+        o o o
+          |
+          o
+    """
+    set_children(nodes[0], node_type, [nodes[1], nodes[2]])
+    set_children(nodes[1], node_type, [nodes[3], nodes[4], nodes[5]])
+    set_children(nodes[2], node_type, [])
+    set_children(nodes[3], node_type, [])
+    set_children(nodes[4], node_type, [nodes[6]])
+    set_children(nodes[5], node_type, [])
+    set_children(nodes[6], node_type, [])
+
+    assert nodes[0].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 7}
+    assert nodes[1].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 5}
+    assert nodes[2].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 1}
+    assert nodes[3].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 1}
+    assert nodes[4].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 2}
+    assert nodes[5].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 1}
+    assert nodes[6].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 1}
+
+    nodes[1].cumulative_node_complexity = {PlanNodeCategory.COLUMN: 1}
+
+    # assert that only the value that was reset is changed
+    assert nodes[0].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 7}
+    assert nodes[1].cumulative_node_complexity == {PlanNodeCategory.COLUMN: 1}
+    assert nodes[2].cumulative_node_complexity == {PlanNodeCategory.OTHERS: 1}
+
+
+def test_selectable_entity_individual_node_complexity(mock_analyzer):
+    plan_node = SelectableEntity(entity_name="dummy entity", analyzer=mock_analyzer)
+    assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN: 1}
+
+
+def test_select_sql_individual_node_complexity(mock_analyzer):
+    plan_node = SelectSQL("non-select statement", analyzer=mock_analyzer)
+    assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN: 1}
+
+    plan_node = SelectSQL("select 1 as A, 2 as B", analyzer=mock_analyzer)
+    assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN: 1}
+
+
+def test_select_snowflake_plan_individual_node_complexity(
+    mock_session, mock_analyzer, mock_query
+):
+    source_plan = Project([NamedExpression(), NamedExpression()], LogicalPlan())
+    snowflake_plan = SnowflakePlan(
+        [mock_query], "", source_plan=source_plan, session=mock_session
+    )
+    plan_node = SelectSnowflakePlan(snowflake_plan, analyzer=mock_analyzer)
+    assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN: 2}
+
+
+@pytest.mark.parametrize(
+    "attribute,value,expected_stat",
+    [
+        ("projection", [NamedExpression()], {PlanNodeCategory.COLUMN: 1}),
+        ("projection", [Expression()], {PlanNodeCategory.OTHERS: 1}),
+        (
+            "order_by",
+            [Expression()],
+            {PlanNodeCategory.OTHERS: 1, PlanNodeCategory.ORDER_BY: 1},
+        ),
+        (
+            "where",
+            Expression(),
+            {PlanNodeCategory.OTHERS: 1, PlanNodeCategory.FILTER: 1},
+        ),
+        ("limit_", 10, {PlanNodeCategory.LOW_IMPACT: 1}),
+        ("offset", 2, {PlanNodeCategory.LOW_IMPACT: 1}),
+    ],
+)
+def test_select_statement_individual_node_complexity(
+    mock_analyzer, attribute, value, expected_stat
+):
+    from_ = mock.create_autospec(Selectable)
+    from_.pre_actions = None
+    from_.post_actions = None
+    from_.expr_to_alias = {}
+    from_.df_aliased_col_name_to_real_col_name = {}
+
+    plan_node = SelectStatement(from_=from_, analyzer=mock_analyzer)
+    setattr(plan_node, attribute, value)
+    assert plan_node.individual_node_complexity == expected_stat
+
+
+def test_select_table_function_individual_node_complexity(
+    mock_analyzer, mock_session, mock_query
+):
+    func_expr = mock.create_autospec(TableFunctionExpression)
+    source_plan = Project([NamedExpression(), NamedExpression()], LogicalPlan())
+    snowflake_plan = SnowflakePlan(
+        [mock_query], "", source_plan=source_plan, session=mock_session
+    )
+
+    def mocked_resolve(*args, **kwargs):
+        return snowflake_plan
+
+    with mock.patch.object(mock_analyzer, "resolve", side_effect=mocked_resolve):
+        plan_node = SelectTableFunction(func_expr, analyzer=mock_analyzer)
+        assert plan_node.individual_node_complexity == {PlanNodeCategory.COLUMN: 2}
+
+
+@pytest.mark.parametrize("set_operator", [UNION, UNION_ALL, INTERSECT, EXCEPT])
+def test_set_statement_individual_node_complexity(mock_analyzer, set_operator):
+    mock_selectable = mock.create_autospec(Selectable)
+    mock_selectable.pre_actions = None
+    mock_selectable.post_actions = None
+    mock_selectable.expr_to_alias = {}
+    mock_selectable.df_aliased_col_name_to_real_col_name = {}
+    set_operands = [
+        SetOperand(mock_selectable, set_operator),
+        SetOperand(mock_selectable, set_operator),
+    ]
+    plan_node = SetStatement(*set_operands, analyzer=mock_analyzer)
+
+    assert plan_node.individual_node_complexity == {PlanNodeCategory.SET_OPERATION: 1}
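
Note for reviewers: every assertion in the new tests leans on sum_node_complexities merging per-category counters by key-wise addition. The snippet below is a minimal sketch of that merge semantics, not the patch's implementation; the name sum_node_complexities_sketch is hypothetical, and plain string keys stand in for PlanNodeCategory members:

    from collections import Counter
    from typing import Dict

    def sum_node_complexities_sketch(*complexities: Dict[str, int]) -> Dict[str, int]:
        # Key-wise addition; a category absent from an operand contributes 0.
        total = Counter()
        for complexity in complexities:
            total.update(complexity)
        return dict(total)

    # Merging a two-column child, a one-column child, and one JOIN marker:
    assert sum_node_complexities_sketch(
        {"column": 2, "literal": 1}, {"column": 1}, {"join": 1}
    ) == {"column": 3, "literal": 1, "join": 1}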