diff --git a/src/snowflake/snowpark/_internal/compiler/cte_utils.py b/src/snowflake/snowpark/_internal/compiler/cte_utils.py index c075b4fbc77..6075c5dbd87 100644 --- a/src/snowflake/snowpark/_internal/compiler/cte_utils.py +++ b/src/snowflake/snowpark/_internal/compiler/cte_utils.py @@ -45,9 +45,41 @@ def find_duplicate_subtrees( This function is used to only include nodes that should be converted to CTEs. """ - id_count_map = defaultdict(int) + id_node_map = defaultdict(list) id_parents_map = defaultdict(set) - id_complexity_map = defaultdict(int) + + from snowflake.snowpark._internal.analyzer.select_statement import ( + Selectable, + SelectStatement, + SelectableEntity, + SelectSnowflakePlan, + ) + from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan + + def is_simple_select_entity(node: "TreeNode") -> bool: + """ + Check if the current node is a simple select on top of a SelectEntity, for example: + select * from TABLE. This check only works with selectable when sql simplifier is enabled. + """ + if isinstance(node, SelectableEntity): + return True + if ( + isinstance(node, SelectStatement) + and (node.projection is None) + and isinstance(node.from_, SelectableEntity) + ): + return True + if ( + isinstance(node, SnowflakePlan) + and (node.source_plan is not None) + and isinstance(node.source_plan, (SnowflakePlan, Selectable)) + ): + return is_simple_select_entity(node.source_plan) + + if isinstance(node, SelectSnowflakePlan): + return is_simple_select_entity(node.snowflake_plan) + + return False def traverse(root: "TreeNode") -> None: """ @@ -57,15 +89,7 @@ def traverse(root: "TreeNode") -> None: while len(current_level) > 0: next_level = [] for node in current_level: - id_count_map[node.encoded_node_id_with_query] += 1 - if propagate_complexity_hist and ( - node.encoded_node_id_with_query not in id_complexity_map - ): - # if propagate_complexity_hist is true, and the complexity score is not - # recorded for the current node id, record the complexity - id_complexity_map[ - node.encoded_node_id_with_query - ] = get_complexity_score(node) + id_node_map[node.encoded_node_id_with_query].append(node) for child in node.children_plan_nodes: id_parents_map[child.encoded_node_id_with_query].add( node.encoded_node_id_with_query @@ -77,13 +101,15 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: # when a sql query is a select statement, its encoded_node_id_with_query # contains _, which is used to separate the query id and node type name. is_valid_candidate = "_" in encoded_node_id_with_query - if not is_valid_candidate: + if not is_valid_candidate or is_simple_select_entity( + id_node_map[encoded_node_id_with_query][0] + ): return False - is_duplicate_node = id_count_map[encoded_node_id_with_query] > 1 + is_duplicate_node = len(id_node_map[encoded_node_id_with_query]) > 1 if is_duplicate_node: is_any_parent_unique_node = any( - id_count_map[id] == 1 + len(id_node_map[id]) == 1 for id in id_parents_map[encoded_node_id_with_query] ) if is_any_parent_unique_node: @@ -97,7 +123,7 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: traverse(root) duplicated_node_ids = { encoded_node_id_with_query - for encoded_node_id_with_query in id_count_map + for encoded_node_id_with_query in id_node_map if is_duplicate_subtree(encoded_node_id_with_query) } @@ -105,7 +131,7 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: return ( duplicated_node_ids, get_duplicated_node_complexity_distribution( - duplicated_node_ids, id_complexity_map, id_count_map + duplicated_node_ids, id_node_map ), ) else: @@ -114,8 +140,7 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: def get_duplicated_node_complexity_distribution( duplicated_node_id_set: Set[str], - id_complexity_map: Dict[str, int], - id_count_map: Dict[str, int], + id_node_map: Dict[str, List["TreeNode"]], ) -> List[int]: """ Calculate the complexity distribution for the detected repeated node. The complexity are categorized as following: @@ -131,8 +156,8 @@ def get_duplicated_node_complexity_distribution( """ node_complexity_dist = [0] * 7 for node_id in duplicated_node_id_set: - complexity_score = id_complexity_map[node_id] - repeated_count = id_count_map[node_id] + complexity_score = get_complexity_score(id_node_map[node_id][0]) + repeated_count = len(id_node_map[node_id]) if complexity_score <= 10000: node_complexity_dist[0] += repeated_count elif 10000 < complexity_score <= 100000: @@ -151,7 +176,7 @@ def get_duplicated_node_complexity_distribution( return node_complexity_dist -def encode_query_id(node) -> Optional[str]: +def encode_query_id(node: "TreeNode") -> Optional[str]: """ Encode the query and its query parameter into an id using sha256. diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index b5799d816a8..67139110f80 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -700,13 +700,14 @@ def test_table(session): check_result( session, df_result, - expect_cte_optimized=True, + expect_cte_optimized=False if session.sql_simplifier_enabled else True, query_count=1, describe_count=0, union_count=1, join_count=0, ) - assert count_number_of_ctes(df_result.queries["queries"][-1]) == 1 + if not session.sql_simplifier_enabled: + assert count_number_of_ctes(df_result.queries["queries"][-1]) == 1 @pytest.mark.parametrize( @@ -1005,6 +1006,25 @@ def test_time_series_aggregation_grouping(session, enable_sql_simplifier): session.sql_simplifier_enabled = original_sql_simplifier_enabled +def test_table_select_cte(session): + table_name = random_name_for_temp_object(TempObjectType.TABLE) + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + df.write.save_as_table(table_name, table_type="temp") + df = session.table(table_name) + df_result = df.with_column("add_one", col("a") + 1).union( + df.with_column("add_two", col("a") + 2) + ) + check_result( + session, + df_result, + expect_cte_optimized=False if session.sql_simplifier_enabled else True, + query_count=1, + describe_count=0, + union_count=1, + join_count=0, + ) + + @pytest.mark.skipif( IS_IN_STORED_PROC, reason="SNOW-609328: support caplog in SP regression test" )