diff --git a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py index 34d27862ced..5707d71dc33 100644 --- a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py +++ b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py @@ -6,8 +6,6 @@ from collections import defaultdict from typing import List, Optional, Tuple -from sortedcontainers import SortedList - from snowflake.snowpark._internal.analyzer.analyzer_utils import ( drop_table_if_exists_statement, ) @@ -201,11 +199,11 @@ def _find_node_to_breakdown(self, root: TreeNode) -> Optional[TreeNode]: 1. Traverse the plan tree and find the valid nodes for partitioning. 2. If no valid node is found, return None. - 3. Keep valid nodes in a sorted list based on the complexity score. - 4. Return the node with the highest complexity score. + 3. Return the node with the highest complexity score. """ current_level = [root] - pipeline_breaker_list = SortedList(key=lambda x: x[0]) + candidate_node = None + candidate_score = -1 # start with -1 since score is always > 0 while current_level: next_level = [] @@ -215,23 +213,20 @@ def _find_node_to_breakdown(self, root: TreeNode) -> Optional[TreeNode]: self._parent_map[child].add(node) valid_to_breakdown, score = self._is_node_valid_to_breakdown(child) if valid_to_breakdown: - # Append score and child to the pipeline breaker sorted list - # so that the valid child with the highest complexity score - # is at the end of the list. - pipeline_breaker_list.add((score, child)) + # If the score for valid node is higher than the last candidate, + # update the candidate node and score. + if score > candidate_score: + candidate_score = score + candidate_node = child else: # don't traverse subtrees if parent is a valid candidate next_level.append(child) current_level = next_level - if not pipeline_breaker_list: - # Return None if no valid node is found for partitioning. - return None - - # Get the node with the highest complexity score - _, child = pipeline_breaker_list.pop() - return child + # If no valid node is found, candidate_node will be None. + # Otherwise, return the node with the highest complexity score. + return candidate_node def _get_partitioned_plan(self, root: TreeNode, child: TreeNode) -> SnowflakePlan: """This method takes cuts the child out from the root, creates a temp table plan for the diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py index 1368bf460f2..72ade31d456 100644 --- a/tests/integ/test_large_query_breakdown.py +++ b/tests/integ/test_large_query_breakdown.py @@ -47,6 +47,7 @@ def setup(session): cte_optimization_enabled = session._cte_optimization_enabled is_query_compilation_stage_enabled = session._query_compilation_stage_enabled session._query_compilation_stage_enabled = True + session._large_query_breakdown_enabled = True yield session._query_compilation_stage_enabled = is_query_compilation_stage_enabled session._cte_optimization_enabled = cte_optimization_enabled @@ -77,11 +78,32 @@ def check_result_with_and_without_breakdown(session, df): session._large_query_breakdown_enabled = large_query_enabled +def test_no_valid_nodes_found(session, large_query_df, caplog): + """Test large query breakdown works with default bounds""" + set_bounds(300, 600) + + base_df = session.sql("select 1 as A, 2 as B") + df1 = base_df.with_column("A", col("A") + lit(1)) + df2 = base_df.with_column("B", col("B") + lit(1)) + + for i in range(102): + df1 = df1.with_column("A", col("A") + lit(i)) + df2 = df2.with_column("B", col("B") + lit(i)) + + union_df = df1.union_all(df2) + final_df = union_df.with_column("A", col("A") + lit(1)) + + with caplog.at_level(logging.DEBUG): + queries = final_df.queries + assert len(queries["queries"]) == 1, queries["queries"] + assert len(queries["post_actions"]) == 0, queries["post_actions"] + assert "Could not find a valid node for partitioning" in caplog.text + + def test_large_query_breakdown_with_cte_optimization(session): """Test large query breakdown works with cte optimized plan""" set_bounds(300, 600) session._cte_optimization_enabled = True - session._large_query_breakdown_enabled = True df0 = session.sql("select 2 as b, 32 as c") df1 = session.sql("select 1 as a, 2 as b").filter(col("a") == 1) df1 = df1.join(df0, on=["b"], how="inner") @@ -108,7 +130,6 @@ def test_large_query_breakdown_with_cte_optimization(session): def test_save_as_table(session, large_query_df): set_bounds(300, 600) - session._large_query_breakdown_enabled = True table_name = Utils.random_table_name() with session.query_history() as history: large_query_df.write.save_as_table(table_name, mode="overwrite") @@ -164,7 +185,6 @@ def test_update_delete_merge(session, large_query_df): def test_copy_into_location(session, large_query_df): set_bounds(300, 600) - session._large_query_breakdown_enabled = True remote_file_path = f"{session.get_session_stage()}/df.parquet" with session.query_history() as history: large_query_df.write.copy_into_location( @@ -183,7 +203,6 @@ def test_copy_into_location(session, large_query_df): def test_pivot_unpivot(session): set_bounds(300, 600) - session._large_query_breakdown_enabled = True session.sql( """create or replace temp table monthly_sales(A int, B int, month text) as select * from values @@ -223,7 +242,6 @@ def test_pivot_unpivot(session): def test_sort(session): set_bounds(300, 600) - session._large_query_breakdown_enabled = True base_df = session.sql("select 1 as A, 2 as B") df1 = base_df.with_column("A", col("A") + lit(1)) df2 = base_df.with_column("B", col("B") + lit(1)) @@ -258,7 +276,6 @@ def test_sort(session): def test_multiple_query_plan(session, large_query_df): set_bounds(300, 600) original_threshold = analyzer.ARRAY_BIND_THRESHOLD - session._large_query_breakdown_enabled = True try: analyzer.ARRAY_BIND_THRESHOLD = 2 base_df = session.create_dataframe([[1, 2], [3, 4]], schema=["A", "B"]) @@ -296,7 +313,6 @@ def test_multiple_query_plan(session, large_query_df): def test_optimization_skipped_with_transaction(session, large_query_df, caplog): """Test large query breakdown is skipped when transaction is enabled""" set_bounds(300, 600) - session._large_query_breakdown_enabled = True session.sql("begin").collect() assert Utils.is_active_transaction(session) with caplog.at_level(logging.DEBUG): @@ -316,7 +332,6 @@ def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog): source_table = Utils.random_table_name() table_name = Utils.random_table_name() view_name = Utils.random_view_name() - session._large_query_breakdown_enabled = True try: session.sql("select 1 as a, 2 as b").write.save_as_table(source_table) df = session.table(source_table) @@ -344,7 +359,6 @@ def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog): def test_async_job_with_large_query_breakdown(session, large_query_df): """Test large query breakdown gives same result for async and non-async jobs""" set_bounds(300, 600) - session._large_query_breakdown_enabled = True job = large_query_df.collect(block=False) result = job.result() assert result == large_query_df.collect() @@ -362,7 +376,6 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): Also test that when partitions are added, drop table queries are added. """ set_bounds(300, 600) - session._large_query_breakdown_enabled = True assert len(large_query_df.queries["queries"]) == 2 assert len(large_query_df.queries["post_actions"]) == 1 assert large_query_df.queries["queries"][0].startswith("CREATE TEMP TABLE") @@ -371,7 +384,6 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): ) set_bounds(300, 412) - session._large_query_breakdown_enabled = True assert len(large_query_df.queries["queries"]) == 3 assert len(large_query_df.queries["post_actions"]) == 2 assert large_query_df.queries["queries"][0].startswith("CREATE TEMP TABLE")