Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1819531: propagate referenced ctes to all root nodes #2670

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d2c924b
debug
sfc-gh-aalam Nov 12, 2024
7b7ad95
add print
sfc-gh-aalam Nov 15, 2024
ce30000
Merge branch 'main' into aalam-fix-proj-complexity
sfc-gh-aalam Nov 21, 2024
09c9752
fix valid replacement and relaxed condition bug
sfc-gh-aalam Nov 22, 2024
3ca612a
remove print
sfc-gh-aalam Nov 22, 2024
8a48d9b
Merge branch 'main' into aalam-SNOW-1819531-lqb-bug-fixes
sfc-gh-aalam Nov 22, 2024
f7fe2eb
fix referenced_cte bug
sfc-gh-aalam Nov 22, 2024
2fad709
code coverage
sfc-gh-aalam Nov 23, 2024
850294d
Merge branch 'main' into aalam-SNOW-1819531-lqb-bug-fixes
sfc-gh-aalam Nov 25, 2024
8f6ca76
update test
sfc-gh-aalam Nov 25, 2024
16909dc
update test
sfc-gh-aalam Nov 25, 2024
4db6421
fix type-hint and local-test
sfc-gh-aalam Nov 26, 2024
78d52aa
remove unnecessary updates
sfc-gh-aalam Nov 27, 2024
ffcb577
remove unnecessary updates
sfc-gh-aalam Nov 27, 2024
eadcd8f
Merge branch 'main' into aalam-SNOW-1819531-lqb-bug-fixes
sfc-gh-aalam Nov 27, 2024
28ed69b
fix _is_valid_for_replacement for resolve with_query_block
sfc-gh-aalam Dec 2, 2024
ca9be3a
Merge branch 'main' into aalam-SNOW-1819531-lqb-bug-fixes
sfc-gh-aalam Dec 3, 2024
bfb1412
fix root cause of error
sfc-gh-aalam Dec 3, 2024
d826c21
copy value from child
sfc-gh-aalam Dec 3, 2024
2f1e1d7
address comments
sfc-gh-aalam Dec 3, 2024
5daf97e
Deprecate propagate referenced ctes
sfc-gh-aalam Dec 4, 2024
4528959
Deprecate propagate referenced ctes
sfc-gh-aalam Dec 4, 2024
8f0a23b
move to_selectable to query generator
sfc-gh-aalam Dec 4, 2024
9ff53d8
fix unit tests
sfc-gh-aalam Dec 4, 2024
fb83d0e
deprecate propagate referenced ctes
sfc-gh-aalam Dec 5, 2024
ee472b8
assert query starts correctly
sfc-gh-aalam Dec 10, 2024
1623ca6
add more test
sfc-gh-aalam Dec 10, 2024
340769e
Merge branch 'main' into aalam-SNOW-1844090-deprecate-propagate-refer…
sfc-gh-aalam Dec 10, 2024
49072b9
Merge branch 'aalam-SNOW-1844090-deprecate-propagate-referenced-ctes'…
sfc-gh-aalam Dec 10, 2024
dd6b2f9
fix
sfc-gh-aalam Dec 10, 2024
4e913f8
Merge branch 'main' into aalam-SNOW-1819531-lqb-bug-fixes
sfc-gh-aalam Dec 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,14 @@ def _is_relaxed_pipeline_breaker(self, node: LogicalPlan) -> bool:
if isinstance(node, SelectStatement):
return True

if isinstance(node, SnowflakePlan):
return node.source_plan is not None and self._is_relaxed_pipeline_breaker(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between "relaxed pipeline breaker" and "pipeline breaker"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pipeline breaker are all nodes like sort, pivot etc which are listed in the is_pipeline_breaker function.
relaxed pipeline breaker are those nodes which are not pipeline breakers but can be used to cut if no valid pipeline breaker is found. For now SelectStatement is the only relaxed pipeline breaker.

node.source_plan
)

if isinstance(node, SelectSnowflakePlan):
return self._is_relaxed_pipeline_breaker(node.snowflake_plan)
sfc-gh-aalam marked this conversation as resolved.
Show resolved Hide resolved

return False

def _is_node_pipeline_breaker(self, node: LogicalPlan) -> bool:
Expand Down
15 changes: 14 additions & 1 deletion src/snowflake/snowpark/_internal/compiler/query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

from snowflake.snowpark._internal.analyzer.analyzer import Analyzer
from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark._internal.analyzer.select_statement import Selectable
from snowflake.snowpark._internal.analyzer.select_statement import (
SelectSnowflakePlan,
Selectable,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan import (
PlanQueryType,
Query,
Expand Down Expand Up @@ -66,6 +69,16 @@ def __init__(
# between the CTE definition is satisfied.
self.resolved_with_query_block: Dict[str, Query] = {}

def to_selectable(self, plan: LogicalPlan) -> Selectable:
"""Given a LogicalPlan, convert it to a Selectable."""
if isinstance(plan, Selectable):
return plan

snowflake_plan = self.resolve(plan)
selectable = SelectSnowflakePlan(snowflake_plan, analyzer=self)
selectable._is_valid_for_replacement = snowflake_plan._is_valid_for_replacement
return selectable

def generate_queries(
self, logical_plans: List[LogicalPlan]
) -> Dict[PlanQueryType, List[Query]]:
Expand Down
20 changes: 8 additions & 12 deletions src/snowflake/snowpark/_internal/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,6 @@ def replace_child(
based on the parent node type.
"""

def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selectable:
"""Given a LogicalPlan, convert it to a Selectable."""
if isinstance(plan, Selectable):
return plan

snowflake_plan = query_generator.resolve(plan)
return SelectSnowflakePlan(snowflake_plan, analyzer=query_generator)

if not parent._is_valid_for_replacement:
raise ValueError(f"parent node {parent} is not valid for replacement.")

Expand All @@ -143,13 +135,13 @@ def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selecta
replace_child(parent.source_plan, old_child, new_child, query_generator)

elif isinstance(parent, SelectStatement):
parent.from_ = to_selectable(new_child, query_generator)
parent.from_ = query_generator.to_selectable(new_child)
# once the subquery is updated, set _merge_projection_complexity_with_subquery to False to
# disable the projection complexity merge
parent._merge_projection_complexity_with_subquery = False

elif isinstance(parent, SetStatement):
new_child_as_selectable = to_selectable(new_child, query_generator)
new_child_as_selectable = query_generator.to_selectable(new_child)
parent._nodes = [
node if node != old_child else new_child_as_selectable
for node in parent._nodes
Expand Down Expand Up @@ -421,6 +413,9 @@ def get_name(node: Optional[LogicalPlan]) -> str:
name = f"{name} :: ({'| '.join(properties)})"

score = get_complexity_score(node)
num_ref_ctes = "nil"
if isinstance(node, (SnowflakePlan, Selectable)):
num_ref_ctes = len(node.referenced_ctes)
sql_text = ""
if isinstance(node, Selectable):
sql_text = node.sql_query
Expand All @@ -429,7 +424,7 @@ def get_name(node: Optional[LogicalPlan]) -> str:
sql_size = len(sql_text)
sql_preview = sql_text[:50]

return f"{name=}\n{score=}, {sql_size=}\n{sql_preview=}"
return f"{name=}\n{score=}, {num_ref_ctes=}, {sql_size=}\n{sql_preview=}"

g = graphviz.Graph(format="png")

Expand All @@ -439,7 +434,8 @@ def get_name(node: Optional[LogicalPlan]) -> str:
next_level = []
for node in curr_level:
node_id = hex(id(node))
g.node(node_id, get_stat(node))
color = "lightblue" if node._is_valid_for_replacement else "red"
g.node(node_id, get_stat(node), color=color)
if isinstance(node, (Selectable, SnowflakePlan)):
children = node.children_plan_nodes
else:
Expand Down
23 changes: 17 additions & 6 deletions tests/integ/scala/test_datatype_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,11 +600,14 @@ def test_iceberg_nested_fields(
Utils.drop_table(structured_type_session, transformed_table_name)


@pytest.mark.skip(
reason="SNOW-1819531: Error in _contains_external_cte_ref when analyzing lqb"
@pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="local testing does not fully support structured types yet.",
run=False,
)
@pytest.mark.parametrize("cte_enabled", [True, False])
def test_struct_dtype_iceberg_lqb(
structured_type_session, local_testing_mode, structured_type_support
structured_type_session, local_testing_mode, structured_type_support, cte_enabled
):
if not (
structured_type_support
Expand Down Expand Up @@ -641,12 +644,14 @@ def test_struct_dtype_iceberg_lqb(
is_query_compilation_stage_enabled = (
structured_type_session._query_compilation_stage_enabled
)
is_cte_optimization_enabled = structured_type_session._cte_optimization_enabled
is_large_query_breakdown_enabled = (
structured_type_session._large_query_breakdown_enabled
)
original_bounds = structured_type_session._large_query_breakdown_complexity_bounds
try:
structured_type_session._query_compilation_stage_enabled = True
structured_type_session._cte_optimization_enabled = cte_enabled
structured_type_session._large_query_breakdown_enabled = True
structured_type_session._large_query_breakdown_complexity_bounds = (300, 600)

Expand Down Expand Up @@ -695,9 +700,14 @@ def test_struct_dtype_iceberg_lqb(
)

queries = union_df.queries
# assert that the queries are broken down into 2 queries and 1 post action
assert len(queries["queries"]) == 2, queries["queries"]
assert len(queries["post_actions"]) == 1
if cte_enabled:
# when CTE is enabled, WithQueryBlock makes pipeline breaker nodes ineligible
assert len(queries["queries"]) == 1
assert len(queries["post_actions"]) == 0
else:
# assert that the queries are broken down into 2 queries and 1 post action
assert len(queries["queries"]) == 2, queries["queries"]
assert len(queries["post_actions"]) == 1
final_df = structured_type_session.table(write_table)

# assert that
Expand All @@ -707,6 +717,7 @@ def test_struct_dtype_iceberg_lqb(
structured_type_session._query_compilation_stage_enabled = (
is_query_compilation_stage_enabled
)
structured_type_session._cte_optimization_enabled = is_cte_optimization_enabled
structured_type_session._large_query_breakdown_enabled = (
is_large_query_breakdown_enabled
)
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/compiler/test_large_query_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,42 @@ def test_pipeline_breaker_node(mock_session, mock_analyzer, node_generator, expe
large_query_breakdown._is_node_pipeline_breaker(select_snowflake_plan)
is expected
), "SelectSnowflakePlan node is not detected as a pipeline breaker node"


@pytest.mark.parametrize(
"node_generator,expected",
[
(
lambda x: SelectStatement(
from_=empty_selectable, order_by=[empty_expression], analyzer=x
),
True,
),
],
)
def test_relaxed_pipeline_breaker_node(
mock_session, mock_analyzer, node_generator, expected
):
large_query_breakdown = LargeQueryBreakdown(
mock_session,
mock_analyzer,
[],
mock_session.large_query_breakdown_complexity_bounds,
)
node = node_generator(mock_analyzer)

assert (
large_query_breakdown._is_relaxed_pipeline_breaker(node) is expected
), f"Node {type(node)} is not detected as a pipeline breaker node"

resolved_node = mock_analyzer.resolve(node)
assert isinstance(resolved_node, SnowflakePlan)
assert (
large_query_breakdown._is_relaxed_pipeline_breaker(resolved_node) is expected
), f"Resolved node of {type(node)} is not detected as a pipeline breaker node"

select_snowflake_plan = SelectSnowflakePlan(resolved_node, analyzer=mock_analyzer)
assert (
large_query_breakdown._is_relaxed_pipeline_breaker(select_snowflake_plan)
is expected
), "SelectSnowflakePlan node is not detected as a pipeline breaker node"
5 changes: 5 additions & 0 deletions tests/unit/compiler/test_replace_child_and_update_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#

import copy
from functools import partial
from unittest import mock

import pytest
Expand Down Expand Up @@ -67,6 +68,7 @@ def mock_snowflake_plan() -> SnowflakePlan:
with_query_block = WithQueryBlock(name="TEST_CTE", child=LogicalPlan())
fake_snowflake_plan.referenced_ctes = {with_query_block: 1}
fake_snowflake_plan._cumulative_node_complexity = {}
fake_snowflake_plan._is_valid_for_replacement = True
return fake_snowflake_plan


Expand All @@ -82,6 +84,9 @@ def mock_resolve(x):
fake_query_generator = mock.create_autospec(QueryGenerator)
fake_query_generator.resolve.side_effect = mock_resolve
fake_query_generator.session = mock_session
fake_query_generator.to_selectable = partial(
QueryGenerator.to_selectable, fake_query_generator
)
return fake_query_generator


Expand Down
Loading