SNOW-1819531: propagate referenced ctes to all root nodes #2670

Open · wants to merge 31 commits into base: main

Changes from 24 commits (31 commits in total, all by sfc-gh-aalam):
d2c924b  debug (Nov 12, 2024)
7b7ad95  add print (Nov 15, 2024)
ce30000  Merge branch 'main' into aalam-fix-proj-complexity (Nov 21, 2024)
09c9752  fix valid replacement and relaxed condition bug (Nov 22, 2024)
3ca612a  remove print (Nov 22, 2024)
8a48d9b  Merge branch 'main' into aalam-SNOW-1819531-lqb-bug-fixes (Nov 22, 2024)
f7fe2eb  fix referenced_cte bug (Nov 22, 2024)
2fad709  code coverage (Nov 23, 2024)
850294d  Merge branch 'main' into aalam-SNOW-1819531-lqb-bug-fixes (Nov 25, 2024)
8f6ca76  update test (Nov 25, 2024)
16909dc  update test (Nov 25, 2024)
4db6421  fix type-hint and local-test (Nov 26, 2024)
78d52aa  remove unnecessary updates (Nov 27, 2024)
ffcb577  remove unnecessary updates (Nov 27, 2024)
eadcd8f  Merge branch 'main' into aalam-SNOW-1819531-lqb-bug-fixes (Nov 27, 2024)
28ed69b  fix _is_valid_for_replacement for resolve with_query_block (Dec 2, 2024)
ca9be3a  Merge branch 'main' into aalam-SNOW-1819531-lqb-bug-fixes (Dec 3, 2024)
bfb1412  fix root cause of error (Dec 3, 2024)
d826c21  copy value from child (Dec 3, 2024)
2f1e1d7  address comments (Dec 3, 2024)
5daf97e  Deprecate propagate referenced ctes (Dec 4, 2024)
4528959  Deprecate propagate referenced ctes (Dec 4, 2024)
8f0a23b  move to_selectable to query generator (Dec 4, 2024)
9ff53d8  fix unit tests (Dec 4, 2024)
fb83d0e  deprecate propagate referenced ctes (Dec 5, 2024)
ee472b8  assert query starts correctly (Dec 10, 2024)
1623ca6  add more test (Dec 10, 2024)
340769e  Merge branch 'main' into aalam-SNOW-1844090-deprecate-propagate-refer… (Dec 10, 2024)
49072b9  Merge branch 'aalam-SNOW-1844090-deprecate-propagate-referenced-ctes'… (Dec 10, 2024)
dd6b2f9  fix (Dec 10, 2024)
4e913f8  Merge branch 'main' into aalam-SNOW-1819531-lqb-bug-fixes (Dec 12, 2024)
17 changes: 1 addition & 16 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -531,11 +531,6 @@ def build(
source_plan: Optional[LogicalPlan],
schema_query: Optional[str] = None,
is_ddl_on_temp_object: bool = False,
# Whether propagate the referenced ctes from child to the new plan built.
# In general, the referenced should be propagated from child, but for cases
# like SnowflakeCreateTable, the CTEs should not be propagated, because
# the CTEs are already embedded and consumed in the child.
propagate_referenced_ctes: bool = True,
) -> SnowflakePlan:
select_child = self.add_result_scan_if_not_select(child)
queries = select_child.queries[:-1] + [
@@ -565,9 +560,7 @@ def build(
api_calls=select_child.api_calls,
df_aliased_col_name_to_real_col_name=child.df_aliased_col_name_to_real_col_name,
session=self.session,
referenced_ctes=child.referenced_ctes
if propagate_referenced_ctes
else None,
referenced_ctes=child.referenced_ctes,
)

@SnowflakePlan.Decorator.wrap_exception
@@ -941,7 +934,6 @@ def get_create_table_as_select_plan(child: SnowflakePlan, replace, error):
child,
source_plan,
is_ddl_on_temp_object=is_temp_table_type,
propagate_referenced_ctes=False,
)

def get_create_and_insert_plan(child: SnowflakePlan, replace, error):
@@ -1002,7 +994,6 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace, error):
),
child,
source_plan,
propagate_referenced_ctes=False,
)
else:
return get_create_and_insert_plan(child, replace=False, error=False)
@@ -1016,7 +1007,6 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace, error):
),
child,
source_plan,
propagate_referenced_ctes=False,
)
else:
return get_create_table_as_select_plan(child, replace=True, error=True)
@@ -1113,7 +1103,6 @@ def create_or_replace_view(
lambda x: create_or_replace_view_statement(name, x, is_temp, comment),
child,
source_plan,
propagate_referenced_ctes=False,
)

def create_or_replace_dynamic_table(
@@ -1480,7 +1469,6 @@ def copy_into_location(
query,
source_plan,
query.schema_query,
propagate_referenced_ctes=False,
)

def update(
@@ -1501,7 +1489,6 @@
),
source_data,
source_plan,
propagate_referenced_ctes=False,
)
else:
return self.query(
@@ -1530,7 +1517,6 @@ def delete(
),
source_data,
source_plan,
propagate_referenced_ctes=False,
)
else:
return self.query(
@@ -1554,7 +1540,6 @@ def merge(
lambda x: merge_statement(table_name, x, join_expr, clauses),
source_data,
source_plan,
propagate_referenced_ctes=False,
)

def lateral(
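With these removals, propagate_referenced_ctes is gone from the plan builder: referenced CTEs are now always carried from the child into the newly built plan, and the DDL/DML special cases are handled later, at query-generation time. A self-contained toy sketch of that rule (toy classes only, not the actual Snowpark ones; the real referenced_ctes maps WithQueryBlock nodes to reference counts rather than strings):

```python
# Toy model of the behavior change: the builder takes no
# propagate_referenced_ctes flag; every plan it builds carries a copy of
# its child's referenced CTEs.
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class ToyPlan:
    sql: str
    # CTE name -> reference count, mirroring SnowflakePlan.referenced_ctes
    referenced_ctes: Dict[str, int] = field(default_factory=dict)


def build(sql_template: str, child: ToyPlan) -> ToyPlan:
    # Always propagate. Filtering for nodes that already embed the CTE text
    # (CREATE TABLE AS SELECT, views, DML, COPY INTO) happens later, when
    # the final queries are generated.
    return ToyPlan(
        sql=sql_template.format(child=child.sql),
        referenced_ctes=dict(child.referenced_ctes),
    )


child = ToyPlan("SELECT * FROM SNOWPARK_TEMP_CTE_XYZ", {"SNOWPARK_TEMP_CTE_XYZ": 1})
parent = build("SELECT COUNT(*) FROM ({child})", child)
assert parent.referenced_ctes == {"SNOWPARK_TEMP_CTE_XYZ": 1}
```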
@@ -477,6 +477,14 @@ def _is_relaxed_pipeline_breaker(self, node: LogicalPlan) -> bool:
if isinstance(node, SelectStatement):
return True

if isinstance(node, SnowflakePlan):
return node.source_plan is not None and self._is_relaxed_pipeline_breaker(
    node.source_plan
)

Contributor: What's the difference between a "relaxed pipeline breaker" and a "pipeline breaker"?

Contributor Author (sfc-gh-aalam): Pipeline breakers are nodes such as sort and pivot, which are listed in the is_pipeline_breaker function. Relaxed pipeline breakers are nodes that are not pipeline breakers themselves but can still be used as cut points when no valid pipeline breaker is found. For now, SelectStatement is the only relaxed pipeline breaker. (A sketch of this fallback follows this file's diff.)

if isinstance(node, SelectSnowflakePlan):
return self._is_relaxed_pipeline_breaker(node.snowflake_plan)

return False

def _is_node_pipeline_breaker(self, node: LogicalPlan) -> bool:
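To make the distinction from the thread above concrete, here is a minimal illustrative sketch (not the LargeQueryBreakdown implementation) of how a relaxed pipeline breaker is only used as a fallback cut point when no strict pipeline breaker is available:

```python
# Illustrative only: strict breakers (sort, pivot, ...) are preferred as
# cut points; a "relaxed" breaker (currently just SelectStatement) is
# accepted only when the subtree contains no strict breaker.
from typing import Callable, Iterable, Optional


def find_cut_node(
    nodes: Iterable[object],
    is_pipeline_breaker: Callable[[object], bool],
    is_relaxed_pipeline_breaker: Callable[[object], bool],
) -> Optional[object]:
    nodes = list(nodes)
    for node in nodes:
        if is_pipeline_breaker(node):
            return node
    for node in nodes:
        if is_relaxed_pipeline_breaker(node):  # fallback cut point
            return node
    return None


# Example: no strict breaker present, so the relaxed one is chosen.
assert find_cut_node(
    ["select_stmt", "project"],
    is_pipeline_breaker=lambda n: n in {"sort", "pivot"},
    is_relaxed_pipeline_breaker=lambda n: n == "select_stmt",
) == "select_stmt"
```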
17 changes: 16 additions & 1 deletion src/snowflake/snowpark/_internal/compiler/query_generator.py
@@ -6,7 +6,10 @@

from snowflake.snowpark._internal.analyzer.analyzer import Analyzer
from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark._internal.analyzer.select_statement import Selectable
from snowflake.snowpark._internal.analyzer.select_statement import (
SelectSnowflakePlan,
Selectable,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan import (
PlanQueryType,
Query,
@@ -66,6 +69,16 @@ def __init__(
# between the CTE definition is satisfied.
self.resolved_with_query_block: Dict[str, Query] = {}

def to_selectable(self, plan: LogicalPlan) -> Selectable:
"""Given a LogicalPlan, convert it to a Selectable."""
if isinstance(plan, Selectable):
return plan

snowflake_plan = self.resolve(plan)
selectable = SelectSnowflakePlan(snowflake_plan, analyzer=self)
selectable._is_valid_for_replacement = snowflake_plan._is_valid_for_replacement
return selectable

def generate_queries(
self, logical_plans: List[LogicalPlan]
) -> Dict[PlanQueryType, List[Query]]:
@@ -169,6 +182,7 @@ def do_resolve_with_resolved_children(
iceberg_config=logical_plan.iceberg_config,
table_exists=logical_plan.table_exists,
)
resolved_plan.referenced_ctes = resolved_child.referenced_ctes
Collaborator: It seems we missed the CTE reference propagation in the plan builder for some cases. We should double-check the PlanBuilder code in snowflake_plan instead of doing the fix here.

Collaborator: After some thought, let's update the referenced_ctes definition to mean the CTEs referenced anywhere in the plan tree, and propagate the CTE references for all nodes. In other words, let's deprecate this parameter (https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py#L538) and make sure referenced_ctes is propagated correctly by every plan builder. Please double-check all places that directly create a SnowflakePlan to confirm referenced_ctes is propagated correctly.

Collaborator: For the correctness of code generation, we can use the approach you have now, but please make sure it is clearly commented why things are done this way.

Collaborator: If the CTE reference is propagated correctly during plan building, you shouldn't need to re-propagate it here.

elif isinstance(
logical_plan,
@@ -197,6 +211,7 @@ def do_resolve_with_resolved_children(
resolved_plan = super().do_resolve_with_resolved_children(
logical_plan, resolved_children, df_aliased_col_name_to_real_col_name
)
resolved_plan.referenced_ctes = resolved_child.referenced_ctes

elif isinstance(logical_plan, Selectable):
# overwrite the Selectable resolving to make sure we are triggering
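For context, the to_selectable helper added to QueryGenerator above is what the child-replacement logic calls when it splices a new child under a SelectStatement; a simplified sketch of that call site (mirroring replace_child in compiler/utils.py below; error handling and the other parent-node cases omitted):

```python
# Simplified sketch of how replace_child uses QueryGenerator.to_selectable
# when the parent node is a SelectStatement.
def splice_new_child(parent_select_statement, new_child, query_generator):
    parent_select_statement.from_ = query_generator.to_selectable(new_child)
    # the cached merged projection complexity is stale after the swap
    parent_select_statement._merge_projection_complexity_with_subquery = False
```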
48 changes: 34 additions & 14 deletions src/snowflake/snowpark/_internal/compiler/utils.py
@@ -34,7 +34,10 @@
TableMerge,
TableUpdate,
)
from snowflake.snowpark._internal.analyzer.unary_plan_node import UnaryNode
from snowflake.snowpark._internal.analyzer.unary_plan_node import (
CreateViewCommand,
UnaryNode,
)
from snowflake.snowpark._internal.compiler.query_generator import (
QueryGenerator,
SnowflakeCreateTablePlanInfo,
@@ -115,14 +118,6 @@ def replace_child(
based on the parent node type.
"""

def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selectable:
"""Given a LogicalPlan, convert it to a Selectable."""
if isinstance(plan, Selectable):
return plan

snowflake_plan = query_generator.resolve(plan)
return SelectSnowflakePlan(snowflake_plan, analyzer=query_generator)

if not parent._is_valid_for_replacement:
raise ValueError(f"parent node {parent} is not valid for replacement.")

@@ -140,13 +135,13 @@ def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selectable:
replace_child(parent.source_plan, old_child, new_child, query_generator)

elif isinstance(parent, SelectStatement):
parent.from_ = to_selectable(new_child, query_generator)
parent.from_ = query_generator.to_selectable(new_child)
# once the subquery is updated, set _merge_projection_complexity_with_subquery to False to
# disable the projection complexity merge
parent._merge_projection_complexity_with_subquery = False

elif isinstance(parent, SetStatement):
new_child_as_selectable = to_selectable(new_child, query_generator)
new_child_as_selectable = query_generator.to_selectable(new_child)
parent._nodes = [
node if node != old_child else new_child_as_selectable
for node in parent._nodes
@@ -306,7 +301,28 @@ def get_snowflake_plan_queries(

plan_queries = plan.queries
post_action_queries = plan.post_actions
if len(plan.referenced_ctes) > 0:
# If the plan has referenced ctes, we need to add the cte definition before
# the final query. This is done for all source plan except for the following
# cases:
# - SnowflakeCreateTable
# - CreateViewCommand
# - TableUpdate
# - TableDelete
# - TableMerge
# - CopyIntoLocationNode
# because the generated_queries by QueryGenerator for these nodes already include the cte
# definition. Adding the cte definition before the query again will cause a syntax error.
if len(plan.referenced_ctes) > 0 and not isinstance(
    plan.source_plan,
    (
        SnowflakeCreateTable,
        CreateViewCommand,
        TableUpdate,
        TableDelete,
        TableMerge,
        CopyIntoLocationNode,
    ),
):
    # make a copy of the original query to avoid any update to the
    # original query object
    plan_queries = copy.deepcopy(plan.queries)

Collaborator: Add a comment here about why we are doing this.
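A toy illustration of the rule described in the comment above, assuming the referenced CTEs are available as a simple name-to-definition mapping (the real code works with WithQueryBlock nodes and Query objects):

```python
# Toy version of the CTE-attachment rule: prepend the WITH clause to the
# final query only when the generated SQL does not already embed it.
from typing import Dict


def attach_cte_definitions(
    final_query: str, cte_definitions: Dict[str, str], already_embedded: bool
) -> str:
    if not cte_definitions or already_embedded:
        # CREATE TABLE AS SELECT / CREATE VIEW / UPDATE / DELETE / MERGE /
        # COPY INTO already contain the WITH clause in their generated SQL;
        # prepending it again would produce invalid SQL.
        return final_query
    with_clause = "WITH " + ", ".join(
        f"{name} AS ({sql})" for name, sql in cte_definitions.items()
    )
    return f"{with_clause} {final_query}"


q = attach_cte_definitions(
    "SELECT * FROM SNOWPARK_TEMP_CTE_ABC",
    {"SNOWPARK_TEMP_CTE_ABC": "SELECT 1 AS A"},
    already_embedded=False,
)
assert q.startswith("WITH SNOWPARK_TEMP_CTE_ABC AS (")
```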
@@ -397,6 +413,9 @@ def get_name(node: Optional[LogicalPlan]) -> str:
name = f"{name} :: ({'| '.join(properties)})"

score = get_complexity_score(node)
num_ref_ctes = "nil"
if isinstance(node, (SnowflakePlan, Selectable)):
num_ref_ctes = len(node.referenced_ctes)
sql_text = ""
if isinstance(node, Selectable):
sql_text = node.sql_query
Expand All @@ -405,7 +424,7 @@ def get_name(node: Optional[LogicalPlan]) -> str:
sql_size = len(sql_text)
sql_preview = sql_text[:50]

return f"{name=}\n{score=}, {sql_size=}\n{sql_preview=}"
return f"{name=}\n{score=}, {num_ref_ctes=}, {sql_size=}\n{sql_preview=}"

g = graphviz.Graph(format="png")

@@ -415,7 +434,8 @@ def get_name(node: Optional[LogicalPlan]) -> str:
next_level = []
for node in curr_level:
node_id = hex(id(node))
g.node(node_id, get_stat(node))
color = "lightblue" if node._is_valid_for_replacement else "red"
g.node(node_id, get_stat(node), color=color)
if isinstance(node, (Selectable, SnowflakePlan)):
children = node.children_plan_nodes
else:
23 changes: 17 additions & 6 deletions tests/integ/scala/test_datatype_suite.py
@@ -600,11 +600,14 @@ def test_iceberg_nested_fields(
Utils.drop_table(structured_type_session, transformed_table_name)


@pytest.mark.skip(
reason="SNOW-1819531: Error in _contains_external_cte_ref when analyzing lqb"
@pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="local testing does not fully support structured types yet.",
run=False,
)
@pytest.mark.parametrize("cte_enabled", [True, False])
def test_struct_dtype_iceberg_lqb(
structured_type_session, local_testing_mode, structured_type_support
structured_type_session, local_testing_mode, structured_type_support, cte_enabled
):
if not (
structured_type_support
@@ -641,12 +644,14 @@ def test_struct_dtype_iceberg_lqb(
is_query_compilation_stage_enabled = (
structured_type_session._query_compilation_stage_enabled
)
is_cte_optimization_enabled = structured_type_session._cte_optimization_enabled
is_large_query_breakdown_enabled = (
structured_type_session._large_query_breakdown_enabled
)
original_bounds = structured_type_session._large_query_breakdown_complexity_bounds
try:
structured_type_session._query_compilation_stage_enabled = True
structured_type_session._cte_optimization_enabled = cte_enabled
structured_type_session._large_query_breakdown_enabled = True
structured_type_session._large_query_breakdown_complexity_bounds = (300, 600)

@@ -695,9 +700,14 @@
)

queries = union_df.queries
# assert that the queries are broken down into 2 queries and 1 post action
assert len(queries["queries"]) == 2, queries["queries"]
assert len(queries["post_actions"]) == 1
if cte_enabled:
# when CTE is enabled, WithQueryBlock makes pipeline breaker nodes ineligible
assert len(queries["queries"]) == 1
assert len(queries["post_actions"]) == 0
else:
# assert that the queries are broken down into 2 queries and 1 post action
assert len(queries["queries"]) == 2, queries["queries"]
assert len(queries["post_actions"]) == 1
final_df = structured_type_session.table(write_table)

# assert that
@@ -707,6 +717,7 @@
structured_type_session._query_compilation_stage_enabled = (
is_query_compilation_stage_enabled
)
structured_type_session._cte_optimization_enabled = is_cte_optimization_enabled
structured_type_session._large_query_breakdown_enabled = (
is_large_query_breakdown_enabled
)
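The test above follows a save/override/restore pattern for every session flag it touches; a condensed sketch of that pattern (flag names taken from the diff — they are internal session attributes and may change):

```python
# Save the current flag values, override them for the duration of the test
# body, and restore them in a finally block.
def run_with_lqb(session, cte_enabled: bool, body):
    saved = (
        session._query_compilation_stage_enabled,
        session._cte_optimization_enabled,
        session._large_query_breakdown_enabled,
        session._large_query_breakdown_complexity_bounds,
    )
    try:
        session._query_compilation_stage_enabled = True
        session._cte_optimization_enabled = cte_enabled
        session._large_query_breakdown_enabled = True
        session._large_query_breakdown_complexity_bounds = (300, 600)
        return body()
    finally:
        (
            session._query_compilation_stage_enabled,
            session._cte_optimization_enabled,
            session._large_query_breakdown_enabled,
            session._large_query_breakdown_complexity_bounds,
        ) = saved
```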
39 changes: 39 additions & 0 deletions tests/unit/compiler/test_large_query_breakdown.py
@@ -135,3 +135,42 @@ def test_pipeline_breaker_node(mock_session, mock_analyzer, node_generator, expected):
large_query_breakdown._is_node_pipeline_breaker(select_snowflake_plan)
is expected
), "SelectSnowflakePlan node is not detected as a pipeline breaker node"


@pytest.mark.parametrize(
"node_generator,expected",
[
(
lambda x: SelectStatement(
from_=empty_selectable, order_by=[empty_expression], analyzer=x
),
True,
),
],
)
def test_relaxed_pipeline_breaker_node(
mock_session, mock_analyzer, node_generator, expected
):
large_query_breakdown = LargeQueryBreakdown(
mock_session,
mock_analyzer,
[],
mock_session.large_query_breakdown_complexity_bounds,
)
node = node_generator(mock_analyzer)

assert (
large_query_breakdown._is_relaxed_pipeline_breaker(node) is expected
), f"Node {type(node)} is not detected as a pipeline breaker node"

resolved_node = mock_analyzer.resolve(node)
assert isinstance(resolved_node, SnowflakePlan)
assert (
large_query_breakdown._is_relaxed_pipeline_breaker(resolved_node) is expected
), f"Resolved node of {type(node)} is not detected as a pipeline breaker node"

select_snowflake_plan = SelectSnowflakePlan(resolved_node, analyzer=mock_analyzer)
assert (
large_query_breakdown._is_relaxed_pipeline_breaker(select_snowflake_plan)
is expected
), "SelectSnowflakePlan node is not detected as a pipeline breaker node"
5 changes: 5 additions & 0 deletions tests/unit/compiler/test_replace_child_and_update_node.py
@@ -3,6 +3,7 @@
#

import copy
from functools import partial
from unittest import mock

import pytest
@@ -67,6 +68,7 @@ def mock_snowflake_plan() -> SnowflakePlan:
with_query_block = WithQueryBlock(name="TEST_CTE", child=LogicalPlan())
fake_snowflake_plan.referenced_ctes = {with_query_block: 1}
fake_snowflake_plan._cumulative_node_complexity = {}
fake_snowflake_plan._is_valid_for_replacement = True
return fake_snowflake_plan


@@ -82,6 +84,9 @@ def mock_resolve(x):
fake_query_generator = mock.create_autospec(QueryGenerator)
fake_query_generator.resolve.side_effect = mock_resolve
fake_query_generator.session = mock_session
fake_query_generator.to_selectable = partial(
QueryGenerator.to_selectable, fake_query_generator
)
return fake_query_generator

