SNOW-1546396: Remove sortedcontainers dependency #2198

Merged
@@ -6,8 +6,6 @@
 from collections import defaultdict
 from typing import List, Optional, Tuple
 
-from sortedcontainers import SortedList
-
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     drop_table_if_exists_statement,
 )
@@ -201,11 +199,11 @@ def _find_node_to_breakdown(self, root: TreeNode) -> Optional[TreeNode]:

         1. Traverse the plan tree and find the valid nodes for partitioning.
         2. If no valid node is found, return None.
-        3. Keep valid nodes in a sorted list based on the complexity score.
-        4. Return the node with the highest complexity score.
+        3. Return the node with the highest complexity score.
         """
         current_level = [root]
-        pipeline_breaker_list = SortedList(key=lambda x: x[0])
+        candidate_node = None
+        candidate_score = -1  # start with -1 since score is always > 0
 
         while current_level:
             next_level = []
@@ -215,23 +213,20 @@ def _find_node_to_breakdown(self, root: TreeNode) -> Optional[TreeNode]:
                     self._parent_map[child].add(node)
                     valid_to_breakdown, score = self._is_node_valid_to_breakdown(child)
                     if valid_to_breakdown:
-                        # Append score and child to the pipeline breaker sorted list
-                        # so that the valid child with the highest complexity score
-                        # is at the end of the list.
-                        pipeline_breaker_list.add((score, child))
+                        # If the score for a valid node is higher than the current
+                        # candidate's, update the candidate node and score.
+                        if score > candidate_score:
+                            candidate_score = score
+                            candidate_node = child
                     else:
                         # don't traverse subtrees if parent is a valid candidate
                         next_level.append(child)
 
             current_level = next_level
 
-        if not pipeline_breaker_list:
-            # Return None if no valid node is found for partitioning.
-            return None
-
-        # Get the node with the highest complexity score
-        _, child = pipeline_breaker_list.pop()
-        return child
+        # If no valid node is found, candidate_node will be None.
+        # Otherwise, return the node with the highest complexity score.
+        return candidate_node
 
     def _get_partitioned_plan(self, root: TreeNode, child: TreeNode) -> SnowflakePlan:
         """This method cuts the child out from the root, creates a temp table plan for the
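The hunk above is the heart of the PR: `_find_node_to_breakdown` only ever consumed the maximum element of the SortedList, so keeping every valid node sorted was wasted work. A minimal sketch of the replacement pattern, using toy data rather than the library's types:

    # Running-max scan: when only the largest element is ever consumed,
    # a single pass with two variables replaces a sorted container.
    candidates = [(120, "nodeA"), (450, "nodeB"), (310, "nodeC")]  # (score, node)

    best_score = -1  # scores are always > 0, so -1 is a safe sentinel
    best_node = None
    for score, node in candidates:
        if score > best_score:
            best_score, best_node = score, node

    assert best_node == "nodeB"

This runs in O(n) time with O(1) extra space, versus O(log n) per insertion for SortedList, and it lets the PR drop the sortedcontainers dependency entirely.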
tests/integ/test_large_query_breakdown.py (34 changes: 23 additions & 11 deletions)
@@ -47,6 +47,7 @@ def setup(session):
     cte_optimization_enabled = session._cte_optimization_enabled
     is_query_compilation_stage_enabled = session._query_compilation_stage_enabled
     session._query_compilation_stage_enabled = True
+    session._large_query_breakdown_enabled = True
     yield
     session._query_compilation_stage_enabled = is_query_compilation_stage_enabled
     session._cte_optimization_enabled = cte_optimization_enabled
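With the setup fixture now enabling `session._large_query_breakdown_enabled` for every test, the per-test `session._large_query_breakdown_enabled = True` lines become redundant; the 11 deletions in the hunks below are exactly those lines.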
@@ -77,11 +78,32 @@ def check_result_with_and_without_breakdown(session, df):
         session._large_query_breakdown_enabled = large_query_enabled
 
 
+def test_no_valid_nodes_found(session, large_query_df, caplog):
+    """Test that large query breakdown is skipped when no valid node is found"""
+    set_bounds(300, 600)
+
+    base_df = session.sql("select 1 as A, 2 as B")
+    df1 = base_df.with_column("A", col("A") + lit(1))
+    df2 = base_df.with_column("B", col("B") + lit(1))
+
+    for i in range(102):
+        df1 = df1.with_column("A", col("A") + lit(i))
+        df2 = df2.with_column("B", col("B") + lit(i))
+
+    union_df = df1.union_all(df2)
+    final_df = union_df.with_column("A", col("A") + lit(1))
+
+    with caplog.at_level(logging.DEBUG):
+        queries = final_df.queries
+    assert len(queries["queries"]) == 1, queries["queries"]
+    assert len(queries["post_actions"]) == 0, queries["post_actions"]
+    assert "Could not find a valid node for partitioning" in caplog.text
+
+
 def test_large_query_breakdown_with_cte_optimization(session):
     """Test large query breakdown works with cte optimized plan"""
     set_bounds(300, 600)
     session._cte_optimization_enabled = True
-    session._large_query_breakdown_enabled = True
     df0 = session.sql("select 2 as b, 32 as c")
     df1 = session.sql("select 1 as a, 2 as b").filter(col("a") == 1)
     df1 = df1.join(df0, on=["b"], how="inner")
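The new test above builds a plan whose candidate nodes all fall outside the configured complexity bounds. A hypothetical sketch of the score check involved; the real `_is_node_valid_to_breakdown` returns a validity flag together with the score and also considers the node type, so this models only the bounds portion:

    # Hypothetical sketch: a node only qualifies for partitioning when its
    # subtree complexity score lies strictly between the two bounds.
    LOWER_BOUND, UPPER_BOUND = 300, 600  # set_bounds(300, 600) in the test

    def is_score_valid(score: int) -> bool:
        return LOWER_BOUND < score < UPPER_BOUND

    assert not is_score_valid(120)  # too cheap to be worth materializing
    assert not is_score_valid(900)  # too expensive to cut at this node
    assert is_score_valid(450)

When nothing qualifies, `_find_node_to_breakdown` returns None, the plan compiles to a single query with no temp-table post-actions, and the debug log records "Could not find a valid node for partitioning", which is exactly what the three assertions check.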
@@ -108,7 +130,6 @@ def test_large_query_breakdown_with_cte_optimization(session):

 def test_save_as_table(session, large_query_df):
     set_bounds(300, 600)
-    session._large_query_breakdown_enabled = True
     table_name = Utils.random_table_name()
     with session.query_history() as history:
         large_query_df.write.save_as_table(table_name, mode="overwrite")
@@ -164,7 +185,6 @@ def test_update_delete_merge(session, large_query_df):

 def test_copy_into_location(session, large_query_df):
     set_bounds(300, 600)
-    session._large_query_breakdown_enabled = True
     remote_file_path = f"{session.get_session_stage()}/df.parquet"
     with session.query_history() as history:
         large_query_df.write.copy_into_location(
@@ -183,7 +203,6 @@ def test_pivot_unpivot(session):

 def test_pivot_unpivot(session):
     set_bounds(300, 600)
-    session._large_query_breakdown_enabled = True
     session.sql(
         """create or replace temp table monthly_sales(A int, B int, month text)
         as select * from values
@@ -223,7 +242,6 @@ def test_pivot_unpivot(session):

 def test_sort(session):
     set_bounds(300, 600)
-    session._large_query_breakdown_enabled = True
     base_df = session.sql("select 1 as A, 2 as B")
     df1 = base_df.with_column("A", col("A") + lit(1))
     df2 = base_df.with_column("B", col("B") + lit(1))
@@ -258,7 +276,6 @@ def test_sort(session):
 def test_multiple_query_plan(session, large_query_df):
     set_bounds(300, 600)
     original_threshold = analyzer.ARRAY_BIND_THRESHOLD
-    session._large_query_breakdown_enabled = True
     try:
         analyzer.ARRAY_BIND_THRESHOLD = 2
         base_df = session.create_dataframe([[1, 2], [3, 4]], schema=["A", "B"])
@@ -296,7 +313,6 @@ def test_multiple_query_plan(session, large_query_df):
 def test_optimization_skipped_with_transaction(session, large_query_df, caplog):
     """Test large query breakdown is skipped when a transaction is active"""
     set_bounds(300, 600)
-    session._large_query_breakdown_enabled = True
     session.sql("begin").collect()
     assert Utils.is_active_transaction(session)
     with caplog.at_level(logging.DEBUG):
@@ -316,7 +332,6 @@ def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog):
     source_table = Utils.random_table_name()
     table_name = Utils.random_table_name()
     view_name = Utils.random_view_name()
-    session._large_query_breakdown_enabled = True
     try:
         session.sql("select 1 as a, 2 as b").write.save_as_table(source_table)
         df = session.table(source_table)
@@ -344,7 +359,6 @@ def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog):
 def test_async_job_with_large_query_breakdown(session, large_query_df):
     """Test large query breakdown gives the same result for async and non-async jobs"""
     set_bounds(300, 600)
-    session._large_query_breakdown_enabled = True
     job = large_query_df.collect(block=False)
     result = job.result()
     assert result == large_query_df.collect()
@@ -362,7 +376,6 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df):
     Also test that when partitions are added, drop table queries are added.
     """
     set_bounds(300, 600)
-    session._large_query_breakdown_enabled = True
     assert len(large_query_df.queries["queries"]) == 2
     assert len(large_query_df.queries["post_actions"]) == 1
     assert large_query_df.queries["queries"][0].startswith("CREATE TEMP TABLE")
@@ -371,7 +384,6 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df):
     )
 
     set_bounds(300, 412)
-    session._large_query_breakdown_enabled = True
     assert len(large_query_df.queries["queries"]) == 3
     assert len(large_query_df.queries["post_actions"]) == 2
     assert large_query_df.queries["queries"][0].startswith("CREATE TEMP TABLE")
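The last two hunks show how the bounds drive the partition count: with bounds (300, 600) the optimizer makes one cut (two queries, one drop-table post-action), while tightening the upper bound to 412 forces a second cut. A toy model of that relationship, with made-up scores; the real algorithm cuts at valid pipeline-breaker nodes rather than subtracting fixed amounts:

    # Toy model: keep cutting while the remaining plan is too complex.
    def count_cuts(total_score: int, upper_bound: int, cut_score: int) -> int:
        cuts = 0
        while total_score > upper_bound:
            total_score -= cut_score  # each cut materializes one subtree
            cuts += 1
        return cuts

    assert count_cuts(800, upper_bound=600, cut_score=350) == 1  # 2 queries, 1 post-action
    assert count_cuts(800, upper_bound=412, cut_score=350) == 2  # 3 queries, 2 post-actions

Each cut adds one CREATE TEMP TABLE query up front and one matching DROP TABLE post-action, which is why the queries and post_actions counts grow in lockstep in the assertions.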