diff --git a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py index 170760111b7..12c5390e5b5 100644 --- a/src/snowflake/snowpark/_internal/analyzer/cte_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/cte_utils.py @@ -5,7 +5,7 @@ import hashlib import logging from collections import defaultdict -from typing import TYPE_CHECKING, Any, Optional, Sequence, Set, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Tuple, Union from snowflake.snowpark._internal.analyzer.analyzer_utils import ( SPACE, @@ -24,7 +24,9 @@ TreeNode = Union[SnowflakePlan, Selectable] -def find_duplicate_subtrees(root: "TreeNode") -> Set["TreeNode"]: +def find_duplicate_subtrees( + root: "TreeNode", +) -> Tuple[Set["TreeNode"], Dict["TreeNode", Set["TreeNode"]]]: """ - Returns a set containing all duplicate subtrees in query plan tree. + Returns a tuple of a set containing all duplicate subtrees in the query plan tree, + and a map from each node to the set of its parent nodes. The root of a duplicate subtree is defined as a duplicate node, if @@ -79,7 +81,8 @@ def is_duplicate_subtree(node: "TreeNode") -> bool: return False traverse(root) - return {node for node in node_count_map if is_duplicate_subtree(node)} + duplicated_node = {node for node in node_count_map if is_duplicate_subtree(node)} + return duplicated_node, node_parents_map def create_cte_query(root: "TreeNode", duplicate_plan_set: Set["TreeNode"]) -> str: diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 8574a7a97a4..c976b0f977a 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -313,6 +313,7 @@ def get_snowflake_plan(self, skip_schema_query) -> SnowflakePlan: df_aliased_col_name_to_real_col_name=self.df_aliased_col_name_to_real_col_name, source_plan=self, placeholder_query=self.placeholder_query, + referenced_ctes=self.referenced_ctes, ) # set api_calls to self._snowflake_plan outside of the above constructor # because the constructor copies api_calls. @@ -373,6 +374,12 @@ def column_states(self, value: ColumnStateDict): """ self._column_states = deepcopy(value) + @property + @abstractmethod + def referenced_ctes(self) -> Set[str]: + """Return the set of CTEs referenced by the whole selectable subtree, including itself and its children.""" + pass + class SelectableEntity(Selectable): """Query from a table, view, or any other Snowflake objects. @@ -385,7 +392,8 @@ def __init__( *, analyzer: "Analyzer", ) -> None: - # currently only selecting from a table is supported for this class + # currently only selecting from a table or CTE entity + # is supported for this class assert isinstance(entity, SnowflakeTable) super().__init__(analyzer) self.entity = entity @@ -421,6 +429,12 @@ def plan_node_category(self) -> PlanNodeCategory: def query_params(self) -> Optional[Sequence[Any]]: return None + @property + def referenced_ctes(self) -> Set[str]: + # the SelectableEntity only allows selecting from a base table, so no + # CTE table will be referenced. + return set() + class SelectSQL(Selectable): """Query from a SQL. 
Mainly used by session.sql()""" @@ -518,6 +532,12 @@ def to_subqueryable(self) -> "SelectSQL": new._api_calls = self._api_calls return new + @property + def referenced_ctes(self) -> Set[str]: + # SelectSQL runs the given sql query directly, so no + # auto-created CTE tables are referenced + return set() + class SelectSnowflakePlan(Selectable): """Wrap a SnowflakePlan to a subclass of Selectable.""" @@ -578,6 +598,10 @@ def query_params(self) -> Optional[Sequence[Any]]: def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return self.snowflake_plan.individual_node_complexity + @property + def referenced_ctes(self) -> Set[str]: + return self._snowflake_plan.referenced_ctes + class SelectStatement(Selectable): """The main logical plan to be used by a DataFrame. @@ -712,7 +736,11 @@ def sql_query(self) -> str: self._sql_query = self.from_.sql_query return self._sql_query from_clause = self.from_.sql_in_subquery - if self.analyzer.session._cte_optimization_enabled and self.from_._id: + if ( + self.analyzer.session._cte_optimization_enabled + and (not self.analyzer.session._query_compilation_stage_enabled) + and self.from_._id + ): placeholder = f"{analyzer_utils.LEFT_PARENTHESIS}{self.from_._id}{analyzer_utils.RIGHT_PARENTHESIS}" self._sql_query = self.placeholder_query.replace(placeholder, from_clause) else: @@ -844,6 +872,10 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: ) return complexity + @property + def referenced_ctes(self) -> Set[str]: + return self.from_.referenced_ctes + def to_subqueryable(self) -> "Selectable": """When this SelectStatement's subquery is not subqueryable (can't be used in `from` clause of the sql), convert it to subqueryable and create a new SelectStatement with from_ being the new subqueryable. @@ -1169,6 +1201,10 @@ def query_params(self) -> Optional[Sequence[Any]]: def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return self.snowflake_plan.individual_node_complexity + @property + def referenced_ctes(self) -> Set[str]: + return self._snowflake_plan.referenced_ctes + class SetOperand: def __init__(self, selectable: Selectable, operator: Optional[str] = None) -> None: @@ -1261,6 +1297,11 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # we add #set_operands - 1 additional operators in sql query return {PlanNodeCategory.SET_OPERATION: len(self.set_operands) - 1} + @property + def referenced_ctes(self) -> Set[str]: + # get a union of referenced cte tables from all child nodes + return set().union(*[node.referenced_ctes for node in self._nodes]) + class DeriveColumnDependencyError(Exception): """When deriving column dependencies from the subquery.""" diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 8d33c73927b..c8062330996 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -18,6 +18,7 @@ List, Optional, Sequence, + Set, Tuple, Union, ) @@ -222,6 +223,10 @@ def __init__( # TODO (SNOW-1541096): Remove placeholder_query once CTE is supported with the # new compilation step. placeholder_query: Optional[str] = None, + # This field records all the CTE tables that are referenced by the + # current SnowflakePlan tree. This is needed for the final query + # generation to generate the correct sql query with CTE definition. 
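+ # For example (illustrative values): a plan whose final query selects from a + # CTE named SNOWPARK_TEMP_CTE_ABC would carry referenced_ctes={"SNOWPARK_TEMP_CTE_ABC"}, + # so that the matching WITH clause can be prepended during final query generation. 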
+ referenced_ctes: Optional[Set[str]] = None, *, session: "snowflake.snowpark.session.Session", ) -> None: @@ -249,6 +254,9 @@ self.placeholder_query = placeholder_query # encode an id for CTE optimization self._id = encode_id(queries[-1].sql, queries[-1].params) + self.referenced_ctes: Set[str] = ( + referenced_ctes.copy() if referenced_ctes else set() + ) self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None def __eq__(self, other: "SnowflakePlan") -> bool: @@ -295,7 +303,13 @@ def children_plan_nodes(self) -> List[Union["Selectable", "SnowflakePlan"]]: def replace_repeated_subquery_with_cte(self) -> "SnowflakePlan": # parameter protection - if not self.session._cte_optimization_enabled: + # the common subquery elimination will only be applied if cte_optimization is enabled + # and the new compilation stage is not enabled. When the new compilation stage is enabled, + # the common subquery elimination will be done through the new plan transformation. + if ( + not self.session._cte_optimization_enabled + or self.session._query_compilation_stage_enabled + ): return self # if source_plan or placeholder_query is none, it must be a leaf node, @@ -323,7 +337,7 @@ return self # if there is no duplicate node, no optimization will be performed - duplicate_plan_set = find_duplicate_subtrees(self) + duplicate_plan_set, _ = find_duplicate_subtrees(self) if not duplicate_plan_set: return self @@ -345,7 +359,7 @@ def with_subqueries(self, subquery_plans: List["SnowflakePlan"]) -> "SnowflakePl for query in plan.queries[:-1]: if query not in pre_queries: pre_queries.append(query) - # when self.schema_query is None, that means no schema query is propogated during + # when self.schema_query is None, that means no schema query is propagated during # the process, there is no need to update the schema query. if (new_schema_query is not None) and (plan.schema_query is not None): new_schema_query = new_schema_query.replace( @@ -404,7 +418,8 @@ def plan_height(self) -> int: @cached_property def num_duplicate_nodes(self) -> int: - return len(find_duplicate_subtrees(self)) + duplicated_nodes, _ = find_duplicate_subtrees(self) + return len(duplicated_nodes) @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: @@ -438,6 +453,7 @@ def __copy__(self) -> "SnowflakePlan": self.df_aliased_col_name_to_real_col_name, session=self.session, placeholder_query=self.placeholder_query, + referenced_ctes=self.referenced_ctes, ) else: return SnowflakePlan( @@ -451,6 +467,7 @@ self.df_aliased_col_name_to_real_col_name, session=self.session, placeholder_query=self.placeholder_query, + referenced_ctes=self.referenced_ctes, ) def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006 @@ -478,6 +495,7 @@ # note that there is no copy of the session object, be careful when using the # session object after deepcopy session=self.session, + referenced_ctes=self.referenced_ctes, ) copied_plan._is_valid_for_replacement = True if copied_source_plan: @@ -511,6 +529,11 @@ def build( source_plan: Optional[LogicalPlan], schema_query: Optional[str] = None, is_ddl_on_temp_object: bool = False, + # Whether to propagate the referenced CTEs from the child to the newly built plan. 
+ # In general, the referenced CTEs should be propagated from the child, but for cases + # like SnowflakeCreateTable, the CTEs should not be propagated, because + # the CTEs are already embedded and consumed in the child. + propagate_referenced_ctes: bool = True, ) -> SnowflakePlan: select_child = self.add_result_scan_if_not_select(child) queries = select_child.queries[:-1] + [ @@ -547,6 +570,9 @@ def build( df_aliased_col_name_to_real_col_name=child.df_aliased_col_name_to_real_col_name, session=self.session, placeholder_query=placeholder_query, + referenced_ctes=child.referenced_ctes + if propagate_referenced_ctes + else None, ) @SnowflakePlan.Decorator.wrap_exception @@ -559,27 +585,13 @@ def build_binary( ) -> SnowflakePlan: select_left = self.add_result_scan_if_not_select(left) select_right = self.add_result_scan_if_not_select(right) - queries = ( - select_left.queries[:-1] - + select_right.queries[:-1] - + [ - Query( - sql_generator( - select_left.queries[-1].sql, select_right.queries[-1].sql - ), - params=[ - *select_left.queries[-1].params, - *select_right.queries[-1].params, - ], - ) - ] - ) if self._skip_schema_query: schema_query = None else: left_schema_query = schema_value_statement(select_left.attributes) right_schema_query = schema_value_statement(select_right.attributes) schema_query = sql_generator(left_schema_query, right_schema_query) + placeholder_query = ( sql_generator(select_left._id, select_right._id) if self.session._cte_optimization_enabled @@ -601,15 +613,73 @@ } api_calls = [*select_left.api_calls, *select_right.api_calls] + # This is a temporary workaround for query comparison. The query_id_place_holder + # field of Query will be a randomly generated id if not provided, which could cause the + # comparison of two queries to fail even if the sql and is_ddl_on_temp_object + # values are the same. + # TODO (SNOW-1570952): Find a uniform way for the query comparison + def _query_exists(current_query: Query, existing_queries: List[Query]) -> bool: + for existing_query in existing_queries: + if ( + (current_query.sql == existing_query.sql) + and ( + current_query.is_ddl_on_temp_object + == existing_query.is_ddl_on_temp_object + ) + and (current_query.params == existing_query.params) + ): + return True + + return False + + referenced_ctes: Set[str] = set() + if ( + self.session.cte_optimization_enabled + and self.session._query_compilation_stage_enabled + ): + # When the cte optimization and the new compilation stage are enabled, the + # queries, referenced cte tables, and post actions propagated from + # left and right can contain duplicated queries if there is a common CTE block referenced + # by both left and right. + # We deduplicate them to avoid repeating the same query. 
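+ # For example (illustrative): if left and right both carry the same + # CREATE TEMP TABLE pre-action query for a shared child, _query_exists detects + # the duplicate and only one copy is kept in merged_queries below. 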
+ merged_queries = select_left.queries[:-1].copy() + for query in select_right.queries[:-1]: + if not _query_exists(query, merged_queries): + merged_queries.append(copy.copy(query)) + + referenced_ctes.update(select_left.referenced_ctes) + referenced_ctes.update(select_right.referenced_ctes) + + post_actions = select_left.post_actions.copy() + for post_action in select_right.post_actions: + if post_action not in post_actions: + post_actions.append(copy.copy(post_action)) + else: + merged_queries = select_left.queries[:-1] + select_right.queries[:-1] + post_actions = select_left.post_actions + select_right.post_actions + + queries = merged_queries + [ + Query( + sql_generator( + select_left.queries[-1].sql, select_right.queries[-1].sql + ), + params=[ + *select_left.queries[-1].params, + *select_right.queries[-1].params, + ], + ) + ] + return SnowflakePlan( queries, schema_query, - select_left.post_actions + select_right.post_actions, + post_actions, new_expr_to_alias, source_plan, api_calls=api_calls, session=self.session, placeholder_query=placeholder_query, + referenced_ctes=referenced_ctes, ) def query( @@ -862,6 +932,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): ), child, source_plan, + propagate_referenced_ctes=False, ) else: return get_create_and_insert_plan(child, replace=False, error=False) @@ -873,6 +944,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): ), child, source_plan, + propagate_referenced_ctes=False, ) else: return self.build( @@ -888,6 +960,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): child, source_plan, is_ddl_on_temp_object=is_temp_table_type, + propagate_referenced_ctes=False, ) elif mode == SaveMode.OVERWRITE: return self.build( @@ -903,6 +976,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): child, source_plan, is_ddl_on_temp_object=is_temp_table_type, + propagate_referenced_ctes=False, ) elif mode == SaveMode.IGNORE: return self.build( @@ -918,6 +992,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): child, source_plan, is_ddl_on_temp_object=is_temp_table_type, + propagate_referenced_ctes=False, ) elif mode == SaveMode.ERROR_IF_EXISTS: if is_generated: @@ -935,6 +1010,7 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True): child, source_plan, is_ddl_on_temp_object=is_temp_table_type, + propagate_referenced_ctes=False, ) def limit( @@ -1013,6 +1089,7 @@ def create_or_replace_view( lambda x: create_or_replace_view_statement(name, x, is_temp, comment), child, source_plan, + propagate_referenced_ctes=False, ) def create_or_replace_dynamic_table( @@ -1295,6 +1372,7 @@ def copy_into_location( query, source_plan, query.schema_query, + propagate_referenced_ctes=False, ) def update( @@ -1316,6 +1394,7 @@ def update( ), source_data, source_plan, + propagate_referenced_ctes=False, ) else: return self.query( @@ -1345,6 +1424,7 @@ def delete( ), source_data, source_plan, + propagate_referenced_ctes=False, ) else: return self.query( @@ -1369,6 +1449,7 @@ def merge( lambda x: merge_statement(table_name, x, join_expr, clauses), source_data, source_plan, + propagate_referenced_ctes=False, ) def lateral( @@ -1432,8 +1513,34 @@ def add_result_scan_if_not_select(self, plan: SnowflakePlan) -> SnowflakePlan: plan.source_plan, api_calls=plan.api_calls, session=self.session, + referenced_ctes=plan.referenced_ctes, + ) + + def with_query_block( + self, name: str, child: 
SnowflakePlan, source_plan: LogicalPlan + ) -> SnowflakePlan: + if not self._skip_schema_query: + raise ValueError( + "schema query for WithQueryBlock is currently not supported" ) + new_query = project_statement([], name) + + queries = child.queries[:-1] + [Query(sql=new_query)] + # propagate the cte table name referenced by this block along with the child's + referenced_ctes = {name}.union(child.referenced_ctes) + + return SnowflakePlan( + queries, + schema_query=None, + post_actions=child.post_actions, + expr_to_alias=child.expr_to_alias, + source_plan=source_plan, + api_calls=child.api_calls, + session=self.session, + referenced_ctes=referenced_ctes, + ) + class PlanQueryType(Enum): # the queries to execute for the plan diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 93cee3d499b..f3e1eb821a6 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -113,6 +113,22 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return {PlanNodeCategory.COLUMN: 1} +class WithQueryBlock(LogicalPlan): + """ + Logical plan node for a common table expression (CTE) like + WITH TEMP_CTE_XXXX AS (SELECT * FROM TEST_TABLE). + + The sql generated for all references of this block is SELECT * FROM TEMP_CTE_XXX, + similar to selecting from a SnowflakeTable. + Note that SnowflakeTable is a leaf node, but this node is not. + """ + + def __init__(self, name: str, child: LogicalPlan) -> None: + super().__init__() + self.name = name + self.children.append(child) + + class SnowflakeValues(LeafNode): def __init__( self, diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 4b11a25c021..88e52c34185 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import copy from typing import Dict, List from snowflake.snowpark._internal.analyzer.snowflake_plan import ( @@ -9,6 +10,11 @@ Query, SnowflakePlan, ) +from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan +from snowflake.snowpark._internal.compiler.repeated_subquery_elimination import ( + RepeatedSubqueryElimination, +) +from snowflake.snowpark._internal.compiler.utils import create_query_generator class PlanCompiler: @@ -29,12 +35,13 @@ class PlanCompiler: def __init__(self, plan: SnowflakePlan) -> None: self._plan = plan - def should_apply_optimizations(self) -> bool: + def should_start_query_compilation(self) -> bool: """ - Whether optimization should be applied to the plan or not. + Whether the query compilation stage should be started for the plan or not. 
Optimization can be applied if 1) there is source logical plan attached to the current snowflake plan - 2) optimizations are enabled in the current session, such as cte_optimization_enabled + 2) the query compilation stage is enabled + 3) optimizations are enabled in the current session, such as cte_optimization_enabled Returns @@ -44,17 +51,33 @@ current_session = self._plan.session return ( - self._plan.source_plan is not None - ) and current_session.cte_optimization_enabled + (self._plan.source_plan is not None) + and current_session._query_compilation_stage_enabled + and current_session.cte_optimization_enabled + ) def compile(self) -> Dict[PlanQueryType, List[Query]]: - final_plan = self._plan - if self.should_apply_optimizations(): - # apply optimizations - final_plan = final_plan.replace_repeated_subquery_with_cte() - # TODO: add other optimization steps and code generation step - - return { - PlanQueryType.QUERIES: final_plan.queries, - PlanQueryType.POST_ACTIONS: final_plan.post_actions, - } + if self.should_start_query_compilation(): + # preparation for compilation + # 1. make a copy of the original plan + logical_plans: List[LogicalPlan] = [copy.deepcopy(self._plan)] + # 2. create a code generator with the original plan + query_generator = create_query_generator(self._plan) + + # apply each optimization if needed + if self._plan.session.cte_optimization_enabled: + repeated_subquery_eliminator = RepeatedSubqueryElimination( + logical_plans, query_generator + ) + logical_plans = repeated_subquery_eliminator.apply() + + # do a final pass of code generation + return query_generator.generate_queries(logical_plans) + else: + final_plan = self._plan + if self._plan.session.cte_optimization_enabled: + final_plan = final_plan.replace_repeated_subquery_with_cte() + return { + PlanQueryType.QUERIES: final_plan.queries, + PlanQueryType.POST_ACTIONS: final_plan.post_actions, + } diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index 0a1980af90b..44b76052e4a 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -1,21 +1,29 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # - +import copy from typing import DefaultDict, Dict, Iterable, List, NamedTuple, Optional from snowflake.snowpark._internal.analyzer.analyzer import Analyzer from snowflake.snowpark._internal.analyzer.expression import Attribute from snowflake.snowpark._internal.analyzer.select_statement import Selectable from snowflake.snowpark._internal.analyzer.snowflake_plan import ( + CreateViewCommand, PlanQueryType, Query, SnowflakePlan, SnowflakePlanBuilder, ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( + CopyIntoLocationNode, LogicalPlan, SnowflakeCreateTable, + WithQueryBlock, +) +from snowflake.snowpark._internal.analyzer.table_merge_expression import ( + TableDelete, + TableMerge, + TableUpdate, ) from snowflake.snowpark.session import Session @@ -49,6 +57,13 @@ def __init__( self._snowflake_create_table_plan_info: Optional[ SnowflakeCreateTablePlanInfo ] = snowflake_create_table_plan_info + # Records the definition of all the with query blocks encountered during the code generation. + # This information will be used to generate the final query of a SnowflakePlan with the + # correct CTE definition. 
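+ # For example (illustrative content), after visiting one with query block the dict + # could hold {"SNOWPARK_TEMP_CTE_ABC": 'SELECT "A", "B" FROM test_table'}. 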
+ # NOTE: the dict used here is an ordered dict, and all with query block definitions are recorded in the + # order in which the with query blocks are visited. The order is important to make sure the dependencies + # between the CTE definitions are satisfied. + self.resolved_with_query_block: Dict[str, str] = {} def generate_queries( self, logical_plans: List[LogicalPlan] ) -> Dict[PlanQueryType, List[Query]]: """ ------- """ + from snowflake.snowpark._internal.compiler.utils import ( + get_snowflake_plan_queries, + ) + # generate queries for each logical plan snowflake_plans = [self.resolve(logical_plan) for logical_plan in logical_plans] # merge all results into final set of queries queries = [] post_actions = [] for snowflake_plan in snowflake_plans: - queries.extend(snowflake_plan.queries) - post_actions.extend(snowflake_plan.post_actions) + plan_queries = get_snowflake_plan_queries( + snowflake_plan, self.resolved_with_query_block + ) + queries.extend(plan_queries[PlanQueryType.QUERIES]) + post_actions.extend(plan_queries[PlanQueryType.POST_ACTIONS]) return { PlanQueryType.QUERIES: queries, @@ -85,13 +107,17 @@ def do_resolve_with_resolved_children( assert logical_plan.source_plan is not None # when encountering a SnowflakePlan with no queries, try to re-resolve # the source plan to construct the result - res = self.resolve(logical_plan.source_plan) + res = self.do_resolve(logical_plan.source_plan) resolved_children[logical_plan] = res - return res + resolved_plan = res else: - return logical_plan + resolved_plan = logical_plan + + elif isinstance(logical_plan, SnowflakeCreateTable): + from snowflake.snowpark._internal.compiler.utils import ( + get_snowflake_plan_queries, + ) - if isinstance(logical_plan, SnowflakeCreateTable): # overwrite the SnowflakeCreateTable resolving, because the child # attribute will be pulled directly from the cache resolved_child = resolved_children[logical_plan.children[0]] @@ -106,7 +132,13 @@ == logical_plan.table_name ) - return self.plan_builder.save_as_table( + # update the resolved child + copied_resolved_child = copy.deepcopy(resolved_child) + final_queries = get_snowflake_plan_queries( + copied_resolved_child, self.resolved_with_query_block + ) + copied_resolved_child.queries = final_queries[PlanQueryType.QUERIES] + resolved_plan = self.plan_builder.save_as_table( logical_plan.table_name, logical_plan.column_names, logical_plan.mode, @@ -116,18 +148,65 @@ for x in logical_plan.clustering_exprs ], logical_plan.comment, - resolved_child, + copied_resolved_child, logical_plan, self.session._use_scoped_temp_objects, logical_plan.is_generated, self._snowflake_create_table_plan_info.child_attributes, ) - if isinstance(logical_plan, Selectable): + elif isinstance( + logical_plan, + ( + CreateViewCommand, + TableUpdate, + TableDelete, + TableMerge, + CopyIntoLocationNode, + ), + ): + from snowflake.snowpark._internal.compiler.utils import ( + get_snowflake_plan_queries, + ) + + # for CreateViewCommand, TableUpdate, TableDelete, TableMerge and CopyIntoLocationNode, + # the with definition must be generated before the create, update, delete, merge and copy into + # query. 
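+ # For example (illustrative shape): the child's final query becomes + # WITH SNOWPARK_TEMP_CTE_ABC AS (...) SELECT ..., so a view creation is emitted as one + # CREATE VIEW ... AS WITH ... SELECT ... statement rather than a separate WITH query. 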
+ resolved_child = resolved_children[logical_plan.children[0]] + copied_resolved_child = copy.deepcopy(resolved_child) + final_queries = get_snowflake_plan_queries( + copied_resolved_child, self.resolved_with_query_block + ) + copied_resolved_child.queries = final_queries[PlanQueryType.QUERIES] + resolved_children[logical_plan.children[0]] = copied_resolved_child + resolved_plan = super().do_resolve_with_resolved_children( + logical_plan, resolved_children, df_aliased_col_name_to_real_col_name + ) + + elif isinstance(logical_plan, Selectable): # overwrite the Selectable resolving to make sure we are not triggering # any schema query build - return logical_plan.get_snowflake_plan(skip_schema_query=True) + resolved_plan = logical_plan.get_snowflake_plan(skip_schema_query=True) - return super().do_resolve_with_resolved_children( - logical_plan, resolved_children, df_aliased_col_name_to_real_col_name - ) + elif isinstance(logical_plan, WithQueryBlock): + resolved_child = resolved_children[logical_plan.children[0]] + # record the CTE definition of the current block + if logical_plan.name not in self.resolved_with_query_block: + self.resolved_with_query_block[ + logical_plan.name + ] = resolved_child.queries[-1].sql + + resolved_plan = self.plan_builder.with_query_block( + logical_plan.name, + resolved_child, + logical_plan, + ) + + else: + resolved_plan = super().do_resolve_with_resolved_children( + logical_plan, resolved_children, df_aliased_col_name_to_real_col_name + ) + + resolved_plan._is_valid_for_replacement = True + + return resolved_plan diff --git a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py new file mode 100644 index 00000000000..38e3b72a32b --- /dev/null +++ b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py @@ -0,0 +1,152 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +from typing import Dict, List, Optional, Set + +from snowflake.snowpark._internal.analyzer.cte_utils import find_duplicate_subtrees +from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( + LogicalPlan, + WithQueryBlock, +) +from snowflake.snowpark._internal.compiler.query_generator import QueryGenerator +from snowflake.snowpark._internal.compiler.utils import ( + TreeNode, + replace_child, + update_resolvable_node, +) +from snowflake.snowpark._internal.utils import ( + TempObjectType, + random_name_for_temp_object, +) + + +class RepeatedSubqueryElimination: + """ + Optimization used to eliminate duplicated queries in the plan. + + When the same dataframe is used in multiple places of the plan, the same subquery + will be generated at each place where it is used; this leads to repeated evaluation + of the same subquery and causes extra performance overhead. This optimization + detects the common sub-dataframes, and uses CTEs to eliminate the repeated + subqueries generated. 
+ For example: + df = session.table("test_table") + df1 = df1.select("a", "b") + df2 = df1.union_all(df1) + originally the generated query for df2 is + (select "a", "b" from "test_table") union all (select "a", "b" from "test_table") + after the optimization, the generated query becomes + with temp_cte_xxx as (select "a", "b" from "test_table") + (select * from temp_cte_xxx) union all (select * from select * from temp_cte_xxx) + """ + + # original logical plans to apply the optimization on + _logical_plans: List[LogicalPlan] + _query_generator: QueryGenerator + + def __init__( + self, + logical_plans: List[LogicalPlan], + query_generator: QueryGenerator, + ) -> None: + self._logical_plans = logical_plans + self._query_generator = query_generator + + def apply(self) -> List[LogicalPlan]: + """ + Applies Common SubDataframe elimination on the set of logical plans one after another. + + Returns: + A set of the new LogicalPlans with common sub dataframe deduplicated with CTE node. + """ + final_logical_plans: List[LogicalPlan] = [] + for logical_plan in self._logical_plans: + # NOTE: the current common sub-dataframe elimination relies on the + # fact that all intermediate steps are resolved properly. Here we + # do a pass of resolve of the logical plan to make sure we get a valid + # resolved plan to start the process. + # If the plan is already a resolved plan, this step will be a no-op. + logical_plan = self._query_generator.resolve(logical_plan) + + # apply the CTE optimization on the resolved plan + duplicated_nodes, node_parents_map = find_duplicate_subtrees(logical_plan) + if len(duplicated_nodes) > 0: + deduplicated_plan = self._replace_duplicate_node_with_cte( + logical_plan, duplicated_nodes, node_parents_map + ) + final_logical_plans.append(deduplicated_plan) + else: + final_logical_plans.append(logical_plan) + + # TODO (SNOW-1566363): Add telemetry for CTE + return final_logical_plans + + def _replace_duplicate_node_with_cte( + self, + root: TreeNode, + duplicated_nodes: Set[TreeNode], + node_parents_map: Dict[TreeNode, Set[TreeNode]], + ) -> LogicalPlan: + """ + Replace all duplicated nodes with a WithQueryBlock (CTE node), to enable + query generation with CTEs. + + NOTE, we use stack to perform a post-order traversal instead of recursive call. + The reason of using the stack approach is that chained CTEs have to be built + from bottom (innermost subquery) to top (outermost query). + This function uses an iterative approach to avoid hitting Python's maximum recursion depth limit. 
+ """ + + stack1, stack2 = [root], [] + + while stack1: + node = stack1.pop() + stack2.append(node) + for child in reversed(node.children_plan_nodes): + stack1.append(child) + + # tack node that is already visited to avoid repeated operation on the same node + visited_nodes: Set[TreeNode] = set() + updated_nodes: Set[TreeNode] = set() + + def _update_parents( + node: TreeNode, + should_replace_child: bool, + new_child: Optional[TreeNode] = None, + ) -> None: + parents = node_parents_map[node] + for parent in parents: + if should_replace_child: + assert ( + new_child is not None + ), "no new child is provided for replacement" + replace_child(parent, node, new_child, self._query_generator) + update_resolvable_node(parent, self._query_generator) + updated_nodes.add(parent) + + while stack2: + node = stack2.pop() + if node in visited_nodes: + continue + + # if the node is a duplicated node and deduplication is not done for the node, + # start the deduplication transformation use CTE + if node in duplicated_nodes: + # create a WithQueryBlock node + with_block = WithQueryBlock( + name=random_name_for_temp_object(TempObjectType.CTE), child=node + ) + with_block._is_valid_for_replacement = True + + resolved_with_block = self._query_generator.resolve(with_block) + _update_parents( + node, should_replace_child=True, new_child=resolved_with_block + ) + elif node in updated_nodes: + # if the node is updated, make sure all nodes up to parent is updated + _update_parents(node, should_replace_child=False) + + visited_nodes.add(node) + + return root diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py index ed5a038978f..fb76012984b 100644 --- a/src/snowflake/snowpark/_internal/compiler/utils.py +++ b/src/snowflake/snowpark/_internal/compiler/utils.py @@ -2,18 +2,22 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import copy +from typing import Dict, List, Optional, Union -from typing import Optional - -from snowflake.snowpark._internal.analyzer.analyzer import Analyzer from snowflake.snowpark._internal.analyzer.binary_plan_node import BinaryNode from snowflake.snowpark._internal.analyzer.select_statement import ( Selectable, SelectSnowflakePlan, SelectStatement, + SelectTableFunction, SetStatement, ) -from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan +from snowflake.snowpark._internal.analyzer.snowflake_plan import ( + PlanQueryType, + Query, + SnowflakePlan, +) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( CopyIntoLocationNode, Limit, @@ -31,6 +35,8 @@ SnowflakeCreateTablePlanInfo, ) +TreeNode = Union[SnowflakePlan, Selectable] + def create_query_generator(plan: SnowflakePlan) -> QueryGenerator: """ @@ -58,26 +64,53 @@ def create_query_generator(plan: SnowflakePlan) -> QueryGenerator: return QueryGenerator(plan.session, snowflake_create_table_plan_info) +def resolve_and_update_snowflake_plan( + node: SnowflakePlan, query_generator: QueryGenerator +) -> None: + """ + Re-resolve the current snowflake plan if it has a source plan attached, and update the fields with + newly resolved value. 
+ """ + + if node.source_plan is None: + return + + new_snowflake_plan = query_generator.resolve(node.source_plan) + + # copy over the newly resolved fields to make it an in-place update + node.queries = new_snowflake_plan.queries + node.post_actions = new_snowflake_plan.post_actions + node.expr_to_alias = new_snowflake_plan.expr_to_alias + node.is_ddl_on_temp_object = new_snowflake_plan.is_ddl_on_temp_object + node._output_dict = new_snowflake_plan._output_dict + node.df_aliased_col_name_to_real_col_name = ( + new_snowflake_plan.df_aliased_col_name_to_real_col_name + ) + node.placeholder_query = new_snowflake_plan.placeholder_query + node.referenced_ctes = new_snowflake_plan.referenced_ctes + node._cumulative_node_complexity = new_snowflake_plan._cumulative_node_complexity + + def replace_child( parent: LogicalPlan, old_child: LogicalPlan, new_child: LogicalPlan, - analyzer: Analyzer, + query_generator: QueryGenerator, ) -> None: """ - Helper function to replace the child node in the plan with a new child. + Helper function to replace the child node of a plan node with a new child. Whenever necessary, we convert the new_child into a Selectable or SnowflakePlan based on the parent node type. """ - def to_selectable(plan: LogicalPlan, analyzer: Analyzer) -> Selectable: + def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selectable: """Given a LogicalPlan, convert it to a Selectable.""" if isinstance(plan, Selectable): return plan - snowflake_plan = analyzer.resolve(plan) - return SelectSnowflakePlan(snowflake_plan, analyzer=analyzer) + snowflake_plan = query_generator.resolve(plan) + return SelectSnowflakePlan(snowflake_plan, analyzer=query_generator) if not parent._is_valid_for_replacement: raise ValueError(f"parent node {parent} is not valid for replacement.") @@ -87,15 +120,13 @@ def to_selectable(plan: LogicalPlan, analyzer: Analyzer) -> Selectable: if isinstance(parent, SnowflakePlan): assert parent.source_plan is not None - replace_child(parent.source_plan, old_child, new_child, analyzer) - return + replace_child(parent.source_plan, old_child, new_child, query_generator) - if isinstance(parent, SelectStatement): - parent.from_ = to_selectable(new_child, analyzer) - return + elif isinstance(parent, SelectStatement): + parent.from_ = to_selectable(new_child, query_generator) - if isinstance(parent, SetStatement): - new_child_as_selectable = to_selectable(new_child, analyzer) + elif isinstance(parent, SetStatement): + new_child_as_selectable = to_selectable(new_child, query_generator) parent._nodes = [ node if node != old_child else new_child_as_selectable for node in parent._nodes @@ -103,19 +134,16 @@ def to_selectable(plan: LogicalPlan, analyzer: Analyzer) -> Selectable: for operand in parent.set_operands: if operand.selectable == old_child: operand.selectable = new_child_as_selectable - return - if isinstance(parent, Selectable): - assert parent.snowflake_plan.source_plan is not None - replace_child(parent.snowflake_plan.source_plan, old_child, new_child, analyzer) - return + elif isinstance(parent, Selectable): + assert parent.snowflake_plan is not None + replace_child(parent.snowflake_plan, old_child, new_child, query_generator) - if isinstance(parent, (UnaryNode, Limit, CopyIntoLocationNode)): + elif isinstance(parent, (UnaryNode, Limit, CopyIntoLocationNode)): parent.children = [new_child] parent.child = new_child - return - if isinstance(parent, BinaryNode): + elif isinstance(parent, BinaryNode): parent.children = [ node if node != old_child else new_child 
for node in parent.children ] @@ -123,29 +151,101 @@ def to_selectable(plan: LogicalPlan, analyzer: Analyzer) -> Selectable: if parent.left == old_child: parent.left = new_child if parent.right == old_child: parent.right = new_child - return - if isinstance(parent, SnowflakeCreateTable): + elif isinstance(parent, SnowflakeCreateTable): parent.children = [new_child] parent.query = new_child - return - if isinstance(parent, (TableUpdate, TableDelete)): - snowflake_plan = analyzer.resolve(new_child) + elif isinstance(parent, (TableUpdate, TableDelete)): + snowflake_plan = query_generator.resolve(new_child) parent.children = [snowflake_plan] parent.source_data = snowflake_plan - return - if isinstance(parent, TableMerge): - snowflake_plan = analyzer.resolve(new_child) + elif isinstance(parent, TableMerge): + snowflake_plan = query_generator.resolve(new_child) parent.children = [snowflake_plan] parent.source = snowflake_plan - return - if isinstance(parent, LogicalPlan): + elif isinstance(parent, LogicalPlan): parent.children = [ node if node != old_child else new_child for node in parent.children ] - return - raise ValueError(f"parent type {type(parent)} not supported") + else: + raise ValueError(f"parent type {type(parent)} not supported") + + +def update_resolvable_node( + node: TreeNode, + query_generator: QueryGenerator, +): + """ + Helper function to re-resolve the resolvable node and do an in-place update for cached fields. + The re-resolve is only needed for SnowflakePlan and Selectable nodes, because only those nodes + are resolved nodes with sql query state. + + Note the update is done recursively until it reaches the nodes in children_plan_nodes; + this is to make sure all nodes between the current node and those children are updated + correctly. For example, with the following plan + SelectSnowflakePlan + | + SnowflakePlan + | + JOIN + update_resolvable_node(SelectSnowflakePlan, query_generator) will re-resolve both the SelectSnowflakePlan and SnowflakePlan nodes. 
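+ For example, calling update_resolvable_node on the SelectSnowflakePlan above + re-resolves the inner SnowflakePlan in place and switches the analyzer of the + SelectSnowflakePlan to the query generator. 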
+ """ + + if not node._is_valid_for_replacement: + raise ValueError(f"node {node} is not valid for update.") + + if not isinstance(node, (SnowflakePlan, Selectable)): + raise ValueError(f"It is not valid to update node with type {type(node)}.") + + if isinstance(node, SnowflakePlan): + assert node.source_plan is not None + if isinstance(node.source_plan, (SnowflakePlan, Selectable)): + update_resolvable_node(node.source_plan, query_generator) + resolve_and_update_snowflake_plan(node, query_generator) + + elif isinstance(node, (SelectStatement, SetStatement)): + # clean up the cached sql query and snowflake plan to allow + # re-calculation of the sql query and snowflake plan + node._sql_query = None + node._snowflake_plan = None + node.analyzer = query_generator + + elif isinstance(node, (SelectSnowflakePlan, SelectTableFunction)): + assert node.snowflake_plan is not None + update_resolvable_node(node.snowflake_plan, query_generator) + node.analyzer = query_generator + + elif isinstance(node, Selectable): + node.analyzer = query_generator + + +def get_snowflake_plan_queries( + plan: SnowflakePlan, resolved_with_query_blocks: Dict[str, str] +) -> Dict[PlanQueryType, List[Query]]: + + from snowflake.snowpark._internal.analyzer.analyzer_utils import cte_statement + + plan_queries = plan.queries + post_action_queries = plan.post_actions + if len(plan.referenced_ctes) > 0: + # make a copy of the original query to avoid any update to the + # original query object + plan_queries = copy.deepcopy(plan.queries) + post_action_queries = copy.deepcopy(plan.post_actions) + table_names = [] + definition_queries = [] + for name, definition_query in resolved_with_query_blocks.items(): + if name in plan.referenced_ctes: + table_names.append(name) + definition_queries.append(definition_query) + with_query = cte_statement(definition_queries, table_names) + plan_queries[-1].sql = with_query + plan_queries[-1].sql + + return { + PlanQueryType.QUERIES: plan_queries, + PlanQueryType.POST_ACTIONS: post_action_queries, + } diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 6895f4ad512..226827bbfd0 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -194,9 +194,12 @@ _PYTHON_SNOWPARK_USE_LOGICAL_TYPE_FOR_CREATE_DATAFRAME_STRING = ( "PYTHON_SNOWPARK_USE_LOGICAL_TYPE_FOR_CREATE_DATAFRAME" ) +# parameter used to turn off the whole new query compilation stage in one shot. If turned +# off the plan won't go through the extra optimization and query generation steps. 
+_PYTHON_SNOWPARK_ENABLE_QUERY_COMPILATION_STAGE = ( + "PYTHON_SNOWPARK_ENABLE_COMPILATION_STAGE" +) _PYTHON_SNOWPARK_USE_CTE_OPTIMIZATION_STRING = "PYTHON_SNOWPARK_USE_CTE_OPTIMIZATION" -# TODO (SNOW-1482588): Add parameter for PYTHON_SNOWPARK_ELIMINATE_NUMERIC_SQL_VALUE_CAST_ENABLED -# at server side _PYTHON_SNOWPARK_ELIMINATE_NUMERIC_SQL_VALUE_CAST_ENABLED = ( "PYTHON_SNOWPARK_ELIMINATE_NUMERIC_SQL_VALUE_CAST_ENABLED" ) @@ -547,6 +550,12 @@ def __init__( ) ) + self._query_compilation_stage_enabled: bool = ( + self._conn._get_client_side_session_parameter( + _PYTHON_SNOWPARK_ENABLE_QUERY_COMPILATION_STAGE, False + ) + ) + self._custom_package_usage_config: Dict = {} self._conf = self.RuntimeConfig(self, options or {}) self._tmpdir_handler: Optional[tempfile.TemporaryDirectory] = None diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 93a90c4495d..ea2a81d4d7e 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -38,13 +38,18 @@ WITH = "WITH" +paramList = [False, True] -@pytest.fixture(autouse=True) -def setup(session): + +@pytest.fixture(params=paramList, autouse=True) +def setup(request, session): is_cte_optimization_enabled = session._cte_optimization_enabled + is_query_compilation_enabled = session._query_compilation_stage_enabled + session._query_compilation_stage_enabled = request.param session._cte_optimization_enabled = True yield session._cte_optimization_enabled = is_cte_optimization_enabled + session._query_compilation_stage_enabled = is_query_compilation_enabled def check_result(session, df, expect_cte_optimized): @@ -186,7 +191,6 @@ def test_same_duplicate_subtree(session): df_result1 = df3.union_all(df3) check_result(session, df_result1, expect_cte_optimized=True) assert count_number_of_ctes(df_result1.queries["queries"][-1]) == 1 - """ root / \ diff --git a/tests/unit/test_replace_child.py b/tests/unit/compiler/test_replace_child_and_update_node.py similarity index 52% rename from tests/unit/test_replace_child.py rename to tests/unit/compiler/test_replace_child_and_update_node.py index d5f37a9b972..0f7781988a7 100644 --- a/tests/unit/test_replace_child.py +++ b/tests/unit/compiler/test_replace_child_and_update_node.py @@ -9,6 +9,7 @@ from snowflake.snowpark._internal.analyzer.binary_plan_node import Inner, Join, Union from snowflake.snowpark._internal.analyzer.select_statement import ( + Selectable, SelectableEntity, SelectSnowflakePlan, SelectSQL, @@ -32,37 +33,102 @@ TableUpdate, ) from snowflake.snowpark._internal.analyzer.unary_plan_node import Project, Sort -from snowflake.snowpark._internal.compiler.utils import replace_child +from snowflake.snowpark._internal.compiler.query_generator import QueryGenerator +from snowflake.snowpark._internal.compiler.utils import ( + replace_child, + update_resolvable_node, +) old_plan = LogicalPlan() irrelevant_plan = LogicalPlan() -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def new_plan(mock_session): - yield SnowflakeTable(name="table", session=mock_session) - - -def assert_precondition(plan, new_plan, analyzer, using_deep_copy=False): + table_node = SnowflakeTable(name="table", session=mock_session) + table_node._is_valid_for_replacement = True + return table_node + + +def mock_snowflake_plan() -> SnowflakePlan: + fake_snowflake_plan = mock.create_autospec(SnowflakePlan) + fake_snowflake_plan._id = "dummy id" + fake_snowflake_plan.expr_to_alias = {} + fake_snowflake_plan.df_aliased_col_name_to_real_col_name = {} + fake_snowflake_plan.queries = [ + Query("FAKE SQL", 
query_id_place_holder="query_id_place_holder_abc") + ] + fake_snowflake_plan.post_actions = [] + fake_snowflake_plan.api_calls = [] + fake_snowflake_plan.is_ddl_on_temp_object = False + fake_snowflake_plan._output_dict = [] + fake_snowflake_plan.placeholder_query = None + fake_snowflake_plan.referenced_ctes = {"TEST_CTE"} + fake_snowflake_plan._cumulative_node_complexity = {} + return fake_snowflake_plan + + +@pytest.fixture(scope="function") +def mock_query_generator(mock_session) -> QueryGenerator: + def mock_resolve(x): + snowflake_plan = mock_snowflake_plan() + snowflake_plan.source_plan = x + return snowflake_plan + + fake_query_generator = mock.create_autospec(QueryGenerator) + fake_query_generator.resolve.side_effect = mock_resolve + fake_query_generator.session = mock_session + return fake_query_generator + + +def assert_precondition(plan, new_plan, query_generator, using_deep_copy=False): original_valid_for_replacement = plan._is_valid_for_replacement try: + # verify when parent is not valid for replacement, an error is thrown plan._is_valid_for_replacement = False with pytest.raises(ValueError, match="is not valid for replacement."): - replace_child(plan, irrelevant_plan, new_plan, analyzer) + replace_child(plan, irrelevant_plan, new_plan, query_generator) + update_resolvable_node(plan, query_generator) + + with pytest.raises(ValueError, match="is not valid for update."): + update_resolvable_node(plan, query_generator) + valid_plan = plan if using_deep_copy: - plan = copy.deepcopy(plan) + valid_plan = copy.deepcopy(plan) else: - plan._is_valid_for_replacement = True + valid_plan._is_valid_for_replacement = True with pytest.raises(ValueError, match="is not a child of parent"): - replace_child(plan, irrelevant_plan, new_plan, analyzer) + replace_child(valid_plan, irrelevant_plan, new_plan, query_generator) + + if not isinstance(valid_plan, (SnowflakePlan, Selectable)): + with pytest.raises( + ValueError, match="It is not valid to update node with type" + ): + update_resolvable_node(valid_plan, query_generator) finally: plan._is_valid_for_replacement = original_valid_for_replacement +def verify_snowflake_plan(plan: SnowflakePlan, expected_plan: SnowflakePlan) -> None: + assert plan.queries == expected_plan.queries + assert plan.post_actions == expected_plan.post_actions + assert plan.expr_to_alias == expected_plan.expr_to_alias + assert plan.is_ddl_on_temp_object == expected_plan.is_ddl_on_temp_object + assert plan._output_dict == expected_plan._output_dict + assert ( + plan.df_aliased_col_name_to_real_col_name + == expected_plan.df_aliased_col_name_to_real_col_name + ) + assert plan.placeholder_query == expected_plan.placeholder_query + assert plan.referenced_ctes == expected_plan.referenced_ctes + assert plan._cumulative_node_complexity == expected_plan._cumulative_node_complexity + assert plan.source_plan is not None + + @pytest.mark.parametrize("using_snowflake_plan", [True, False]) -def test_logical_plan(using_snowflake_plan, mock_query, new_plan, mock_analyzer): +def test_logical_plan(using_snowflake_plan, mock_query, new_plan, mock_query_generator): def get_children(plan): if isinstance(plan, SnowflakePlan): return plan.children_plan_nodes @@ -93,8 +159,8 @@ def get_children(plan): else: join_plan = src_join_plan - assert_precondition(join_plan, new_plan, mock_analyzer) - assert_precondition(project_plan, new_plan, mock_analyzer) + assert_precondition(join_plan, new_plan, mock_query_generator) + assert_precondition(project_plan, new_plan, mock_query_generator) if 
using_snowflake_plan: join_plan = copy.deepcopy(join_plan) @@ -107,11 +173,11 @@ def get_children(plan): assert isinstance(copied_old_plan, LogicalPlan) assert isinstance(copied_project_plan, LogicalPlan) - replace_child(join_plan, copied_old_plan, new_plan, mock_analyzer) + replace_child(join_plan, copied_old_plan, new_plan, mock_query_generator) assert get_children(join_plan) == [new_plan, copied_project_plan] assert project_plan.children == [old_plan] - replace_child(project_plan, old_plan, new_plan, mock_analyzer) + replace_child(project_plan, old_plan, new_plan, mock_query_generator) assert project_plan.children == [new_plan] @@ -123,31 +189,31 @@ def get_children(plan): lambda x: CopyIntoLocationNode(x, "stage_location", copy_options={}), ], ) -def test_unary_plan(plan_initializer, new_plan, mock_analyzer): +def test_unary_plan(plan_initializer, new_plan, mock_query_generator): plan = plan_initializer(old_plan) assert plan.child == old_plan assert plan.children == [old_plan] - assert_precondition(plan, new_plan, mock_analyzer) + assert_precondition(plan, new_plan, mock_query_generator) plan._is_valid_for_replacement = True - replace_child(plan, old_plan, new_plan, mock_analyzer) + replace_child(plan, old_plan, new_plan, mock_query_generator) assert plan.child == new_plan assert plan.children == [new_plan] -def test_binary_plan(new_plan, mock_analyzer): +def test_binary_plan(new_plan, mock_query_generator): left_plan = Project([], LogicalPlan()) plan = Union(left=left_plan, right=old_plan, is_all=False) assert plan.left == left_plan assert plan.right == old_plan - assert_precondition(plan, new_plan, mock_analyzer) + assert_precondition(plan, new_plan, mock_query_generator) plan._is_valid_for_replacement = True - replace_child(plan, old_plan, new_plan, mock_analyzer) + replace_child(plan, old_plan, new_plan, mock_query_generator) assert plan.left == left_plan assert plan.right == new_plan assert plan.children == [left_plan, new_plan] @@ -162,7 +228,7 @@ def test_binary_plan(new_plan, mock_analyzer): ], ) def test_table_delete_update_merge( - plan_initializer, new_plan, mock_analyzer, mock_snowflake_plan + plan_initializer, new_plan, mock_analyzer, mock_query_generator ): def get_source(plan): if hasattr(plan, "source_data"): @@ -175,29 +241,35 @@ def get_source(plan): assert_precondition(plan, new_plan, mock_analyzer) plan._is_valid_for_replacement = True - replace_child(plan, old_plan, new_plan, mock_analyzer) - assert get_source(plan) == mock_snowflake_plan - assert plan.children == [mock_snowflake_plan] - assert mock_snowflake_plan.source_plan == new_plan + replace_child(plan, old_plan, new_plan, mock_query_generator) + assert isinstance(get_source(plan), SnowflakePlan) + assert plan.children == [get_source(plan)] + assert plan.children[0].source_plan == new_plan + verify_snowflake_plan(plan.children[0], mock_snowflake_plan()) -def test_snowflake_create_table(new_plan, mock_analyzer): +def test_snowflake_create_table(new_plan, mock_query_generator): plan = SnowflakeCreateTable(["temp_table"], None, "OVERWRITE", old_plan, "temp") assert plan.query == old_plan assert plan.children == [old_plan] - assert_precondition(plan, new_plan, mock_analyzer) + assert_precondition(plan, new_plan, mock_query_generator) plan._is_valid_for_replacement = True - replace_child(plan, old_plan, new_plan, mock_analyzer) + replace_child(plan, old_plan, new_plan, mock_query_generator) assert plan.query == new_plan assert plan.children == [new_plan] @pytest.mark.parametrize("using_snowflake_plan", [True, 
False]) def test_selectable_entity( + using_snowflake_plan, + mock_session, + mock_analyzer, + mock_query_generator, + mock_query, + new_plan, ): table = SnowflakeTable(name="table", session=mock_session) plan = SelectableEntity(entity=table, analyzer=mock_analyzer) if using_snowflake_plan: plan = SnowflakePlan( [mock_query], schema_query="", post_actions=[], expr_to_alias={}, source_plan=plan, api_calls=None, df_aliased_col_name_to_real_col_name=None, session=mock_session, ) - assert_precondition(plan, new_plan, mock_analyzer, using_deep_copy=True) + assert_precondition(plan, new_plan, mock_query_generator, using_deep_copy=True) # SelectableEntity has no children assert plan.children_plan_nodes == [] + # test update of SelectableEntity + plan = copy.deepcopy(plan) + update_resolvable_node(plan, mock_query_generator) + if using_snowflake_plan: + assert isinstance(plan.source_plan, SelectableEntity) + assert plan.source_plan.analyzer == mock_query_generator + else: + assert plan.analyzer == mock_query_generator @pytest.mark.parametrize("using_snowflake_plan", [True, False]) -def test_select_sql(using_snowflake_plan, mock_session, mock_analyzer, new_plan): +def test_select_sql( + using_snowflake_plan, mock_session, mock_analyzer, mock_query_generator, new_plan +): plan = SelectSQL("FAKE QUERY", analyzer=mock_analyzer) if using_snowflake_plan: plan = SnowflakePlan( [mock_query], schema_query="", post_actions=[], expr_to_alias={}, source_plan=plan, api_calls=None, df_aliased_col_name_to_real_col_name=None, session=mock_session, ) - assert_precondition(plan, new_plan, mock_analyzer, using_deep_copy=True) + assert_precondition(plan, new_plan, mock_query_generator, using_deep_copy=True) # SelectSQL has no children assert plan.children_plan_nodes == [] + # test update of SelectSQL + plan = copy.deepcopy(plan) + update_resolvable_node(plan, mock_query_generator) + if using_snowflake_plan: + assert isinstance(plan.source_plan, SelectSQL) + assert plan.source_plan.analyzer == mock_query_generator + else: + assert plan.analyzer == mock_query_generator + @pytest.mark.parametrize("using_snowflake_plan", [True, False]) def test_select_snowflake_plan( + using_snowflake_plan, + mock_session, + mock_analyzer, + mock_query_generator, + mock_query, + new_plan, ): project_plan = Project([], old_plan) snowflake_plan = SnowflakePlan( @@ -272,23 +368,41 @@ session=mock_session, ) - assert_precondition(plan, new_plan, mock_analyzer, using_deep_copy=True) + assert_precondition(plan, new_plan, mock_query_generator, using_deep_copy=True) plan = copy.deepcopy(plan) # deep copy created a copy of old_plan copied_old_plan = plan.children_plan_nodes[0] + if using_snowflake_plan: + copied_project = plan.source_plan._snowflake_plan.source_plan + copied_select_snowflake_plan = plan.source_plan + else: + copied_project = plan._snowflake_plan.source_plan + copied_select_snowflake_plan = plan + + replace_child(plan, copied_old_plan, new_plan, mock_query_generator) + assert copied_project.children == [new_plan] - replace_child(plan, copied_old_plan, new_plan, mock_analyzer) - assert plan.children_plan_nodes == [new_plan] + # verify node update + update_resolvable_node(plan, mock_query_generator) + expected_snowflake_plan_content = mock_snowflake_plan() + verify_snowflake_plan( + copied_select_snowflake_plan._snowflake_plan, expected_snowflake_plan_content + ) + if using_snowflake_plan: + verify_snowflake_plan(plan, expected_snowflake_plan_content) + + # verify the analyzer of selectable is updated to query generator 
assert copied_select_snowflake_plan.analyzer == mock_query_generator @pytest.mark.parametrize("using_snowflake_plan", [True, False]) def test_select_statement( using_snowflake_plan, mock_session, + mock_query_generator, mock_analyzer, mock_query, new_plan, - mock_snowflake_plan, ): from_ = SelectSnowflakePlan( SnowflakePlan( @@ -319,17 +433,32 @@ def test_select_statement( session=mock_session, ) - assert_precondition(plan, new_plan, mock_analyzer, using_deep_copy=True) + assert_precondition(plan, new_plan, mock_query_generator, using_deep_copy=True) plan = copy.deepcopy(plan) - replace_child(plan, from_, new_plan, mock_analyzer) + replace_child(plan, from_, new_plan, mock_query_generator) assert len(plan.children_plan_nodes) == 1 - assert plan.children_plan_nodes[0].snowflake_plan == mock_snowflake_plan - assert mock_snowflake_plan.source_plan == new_plan + assert isinstance(plan.children_plan_nodes[0], SelectSnowflakePlan) + assert plan.children_plan_nodes[0]._snowflake_plan.source_plan == new_plan + assert plan.children_plan_nodes[0].analyzer == mock_query_generator + + # verify node update + update_resolvable_node(plan, mock_query_generator) + expected_snowflake_plan_content = mock_snowflake_plan() + verify_snowflake_plan( + plan.children_plan_nodes[0]._snowflake_plan, expected_snowflake_plan_content + ) + if using_snowflake_plan: + verify_snowflake_plan(plan, expected_snowflake_plan_content) @pytest.mark.parametrize("using_snowflake_plan", [True, False]) def test_select_table_function( - using_snowflake_plan, mock_session, mock_analyzer, mock_query, new_plan + using_snowflake_plan, + mock_session, + mock_analyzer, + mock_query_generator, + mock_query, + new_plan, ): project_plan = Project([], old_plan) snowflake_plan = SnowflakePlan( @@ -361,14 +490,32 @@ def test_select_table_function( session=mock_session, ) - assert_precondition(plan, new_plan, mock_analyzer, using_deep_copy=True) + assert_precondition(plan, new_plan, mock_query_generator, using_deep_copy=True) plan = copy.deepcopy(plan) # deep copy created a copy of old_plan copied_old_plan = plan.children_plan_nodes[0] + if using_snowflake_plan: + copied_project = plan.source_plan._snowflake_plan.source_plan + else: + copied_project = plan._snowflake_plan.source_plan + + replace_child(plan, copied_old_plan, new_plan, mock_query_generator) + assert copied_project.children == [new_plan] - replace_child(plan, copied_old_plan, new_plan, mock_analyzer) - assert plan.children_plan_nodes == [new_plan] + # verify node update + update_resolvable_node(plan, mock_query_generator) + expected_snowflake_plan_content = mock_snowflake_plan() + if using_snowflake_plan: + verify_snowflake_plan(plan, expected_snowflake_plan_content) + assert isinstance(plan.source_plan, SelectTableFunction) + assert plan.source_plan.analyzer == mock_query_generator + verify_snowflake_plan( + plan.source_plan.snowflake_plan, expected_snowflake_plan_content + ) + else: + assert plan.analyzer == mock_query_generator + verify_snowflake_plan(plan.snowflake_plan, expected_snowflake_plan_content) @pytest.mark.parametrize("using_snowflake_plan", [True, False]) @@ -376,9 +523,9 @@ def test_set_statement( using_snowflake_plan, mock_session, mock_analyzer, + mock_query_generator, mock_query, new_plan, - mock_snowflake_plan, ): selectable1 = SelectableEntity( SnowflakeTable(name="table1", session=mock_session), analyzer=mock_analyzer @@ -403,19 +550,35 @@ def test_set_statement( assert_precondition(plan, new_plan, mock_analyzer, using_deep_copy=True) plan = copy.deepcopy(plan) 
- replace_child(plan, selectable1, new_plan, mock_analyzer) + replace_child(plan, selectable1, new_plan, mock_query_generator) assert len(plan.children_plan_nodes) == 2 - assert plan.children_plan_nodes[0].snowflake_plan == mock_snowflake_plan + assert isinstance(plan.children_plan_nodes[0], SelectSnowflakePlan) assert plan.children_plan_nodes[1] == selectable2 - if not using_snowflake_plan: - assert plan.set_operands[0].selectable.snowflake_plan == mock_snowflake_plan - assert mock_snowflake_plan.source_plan == new_plan + + mocked_snowflake_plan = mock_snowflake_plan() + verify_snowflake_plan( + plan.children_plan_nodes[0].snowflake_plan, mocked_snowflake_plan + ) + + update_resolvable_node(plan, mock_query_generator) + if using_snowflake_plan: + copied_set_statement = plan.source_plan + else: + copied_set_statement = plan + + assert copied_set_statement.analyzer == mock_query_generator + assert copied_set_statement._sql_query is None + assert copied_set_statement._snowflake_plan is None + + if using_snowflake_plan: + # verify the snowflake plan is also updated + verify_snowflake_plan(plan, mocked_snowflake_plan) -def test_replace_child_negative(new_plan, mock_analyzer): +def test_replace_child_negative(new_plan, mock_query_generator): mock_parent = mock.Mock() mock_parent._is_valid_for_replacement = True mock_child = LogicalPlan() mock_parent.children_plan_nodes = [mock_child] with pytest.raises(ValueError, match="not supported"): - replace_child(mock_parent, mock_child, new_plan, mock_analyzer) + replace_child(mock_parent, mock_child, new_plan, mock_query_generator) diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index 5a5a0fd75d8..bab9d5680b3 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -47,5 +47,5 @@ def test_case2(): @pytest.mark.parametrize("test_case", [test_case1(), test_case2()]) def test_find_duplicate_subtrees(test_case): plan, expected_duplicate_subtree_ids = test_case - duplicate_subtrees = find_duplicate_subtrees(plan) + duplicate_subtrees, _ = find_duplicate_subtrees(plan) assert {node._id for node in duplicate_subtrees} == expected_duplicate_subtree_ids diff --git a/tests/unit/test_dataframe.py b/tests/unit/test_dataframe.py index 94c3ea8cdd8..1048e4bba96 100644 --- a/tests/unit/test_dataframe.py +++ b/tests/unit/test_dataframe.py @@ -111,6 +111,7 @@ def query_result(*args, **kwargs): fake_session = mock.create_autospec(snowflake.snowpark.session.Session) fake_session.sql_simplifier_enabled = sql_simplifier_enabled fake_session._cte_optimization_enabled = False + fake_session._query_compilation_stage_enabled = False fake_session._conn = mock.create_autospec(ServerConnection) fake_session._plan_builder = SnowflakePlanBuilder(fake_session) fake_session._analyzer = Analyzer(fake_session)