Skip to content

Commit

Permalink
[SNOW-1844465] Avoid creating a CTE out of simple select start on top…
Browse files Browse the repository at this point in the history
… of a select entity (#2713)

<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.

Note that if a corresponding GitHub issue exists, you should still
include
   the Snowflake Jira issue number. For example, for GitHub issue
#1400, you should
   add "SNOW-1335071" here.
    --->

   Fixes SNOW-1844465

2. Fill out the following pre-review checklist:

- [x] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.
- [x] I acknowledge that I have ensured my changes to be thread-safe.
Follow the link for more information: [Thread-safe Developer
Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development)

3. Please describe how your code solves the related issue.

Today, the CTE optimization creates a CTE even out of a simple `select * from TABLE`.
This change improves the CTE optimization by excluding such usage.
  • Loading branch information
sfc-gh-yzou authored Dec 6, 2024
1 parent 7da2b12 commit 0c74413
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 23 deletions.
67 changes: 46 additions & 21 deletions src/snowflake/snowpark/_internal/compiler/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,41 @@ def find_duplicate_subtrees(
This function is used to only include nodes that should be converted to CTEs.
"""
id_count_map = defaultdict(int)
id_node_map = defaultdict(list)
id_parents_map = defaultdict(set)
id_complexity_map = defaultdict(int)

from snowflake.snowpark._internal.analyzer.select_statement import (
Selectable,
SelectStatement,
SelectableEntity,
SelectSnowflakePlan,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan

def is_simple_select_entity(node: "TreeNode") -> bool:
    """
    Tell whether ``node`` is just a plain select over a SelectableEntity
    (e.g. ``select * from TABLE``).

    Wrapper nodes are peeled off one layer at a time: a SnowflakePlan is
    replaced by its source plan (when that is a SnowflakePlan or a
    Selectable), and a SelectSnowflakePlan by the plan it wraps. This check
    only works with selectable when sql simplifier is enabled.
    """
    current = node
    while True:
        # A bare entity is the simplest possible match.
        if isinstance(current, SelectableEntity):
            return True

        # A select statement with no projection directly over an entity.
        if (
            isinstance(current, SelectStatement)
            and (current.projection is None)
            and isinstance(current.from_, SelectableEntity)
        ):
            return True

        # Unwrap a plan whose source is itself a plan or a selectable.
        if (
            isinstance(current, SnowflakePlan)
            and (current.source_plan is not None)
            and isinstance(current.source_plan, (SnowflakePlan, Selectable))
        ):
            current = current.source_plan
            continue

        # Unwrap a selectable that merely carries a snowflake plan.
        if isinstance(current, SelectSnowflakePlan):
            current = current.snowflake_plan
            continue

        # Nothing left to unwrap and no match found.
        return False

def traverse(root: "TreeNode") -> None:
"""
Expand All @@ -57,15 +89,7 @@ def traverse(root: "TreeNode") -> None:
while len(current_level) > 0:
next_level = []
for node in current_level:
id_count_map[node.encoded_node_id_with_query] += 1
if propagate_complexity_hist and (
node.encoded_node_id_with_query not in id_complexity_map
):
# if propagate_complexity_hist is true, and the complexity score is not
# recorded for the current node id, record the complexity
id_complexity_map[
node.encoded_node_id_with_query
] = get_complexity_score(node)
id_node_map[node.encoded_node_id_with_query].append(node)
for child in node.children_plan_nodes:
id_parents_map[child.encoded_node_id_with_query].add(
node.encoded_node_id_with_query
Expand All @@ -77,13 +101,15 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool:
# when a sql query is a select statement, its encoded_node_id_with_query
# contains _, which is used to separate the query id and node type name.
is_valid_candidate = "_" in encoded_node_id_with_query
if not is_valid_candidate:
if not is_valid_candidate or is_simple_select_entity(
id_node_map[encoded_node_id_with_query][0]
):
return False

is_duplicate_node = id_count_map[encoded_node_id_with_query] > 1
is_duplicate_node = len(id_node_map[encoded_node_id_with_query]) > 1
if is_duplicate_node:
is_any_parent_unique_node = any(
id_count_map[id] == 1
len(id_node_map[id]) == 1
for id in id_parents_map[encoded_node_id_with_query]
)
if is_any_parent_unique_node:
Expand All @@ -97,15 +123,15 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool:
traverse(root)
duplicated_node_ids = {
encoded_node_id_with_query
for encoded_node_id_with_query in id_count_map
for encoded_node_id_with_query in id_node_map
if is_duplicate_subtree(encoded_node_id_with_query)
}

if propagate_complexity_hist:
return (
duplicated_node_ids,
get_duplicated_node_complexity_distribution(
duplicated_node_ids, id_complexity_map, id_count_map
duplicated_node_ids, id_node_map
),
)
else:
Expand All @@ -114,8 +140,7 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool:

def get_duplicated_node_complexity_distribution(
duplicated_node_id_set: Set[str],
id_complexity_map: Dict[str, int],
id_count_map: Dict[str, int],
id_node_map: Dict[str, List["TreeNode"]],
) -> List[int]:
"""
Calculate the complexity distribution for the detected repeated node. The complexity are categorized as following:
Expand All @@ -131,8 +156,8 @@ def get_duplicated_node_complexity_distribution(
"""
node_complexity_dist = [0] * 7
for node_id in duplicated_node_id_set:
complexity_score = id_complexity_map[node_id]
repeated_count = id_count_map[node_id]
complexity_score = get_complexity_score(id_node_map[node_id][0])
repeated_count = len(id_node_map[node_id])
if complexity_score <= 10000:
node_complexity_dist[0] += repeated_count
elif 10000 < complexity_score <= 100000:
Expand All @@ -151,7 +176,7 @@ def get_duplicated_node_complexity_distribution(
return node_complexity_dist


def encode_query_id(node) -> Optional[str]:
def encode_query_id(node: "TreeNode") -> Optional[str]:
"""
Encode the query and its query parameter into an id using sha256.
Expand Down
24 changes: 22 additions & 2 deletions tests/integ/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,13 +700,14 @@ def test_table(session):
check_result(
session,
df_result,
expect_cte_optimized=True,
expect_cte_optimized=False if session.sql_simplifier_enabled else True,
query_count=1,
describe_count=0,
union_count=1,
join_count=0,
)
assert count_number_of_ctes(df_result.queries["queries"][-1]) == 1
if not session.sql_simplifier_enabled:
assert count_number_of_ctes(df_result.queries["queries"][-1]) == 1


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1005,6 +1006,25 @@ def test_time_series_aggregation_grouping(session, enable_sql_simplifier):
session.sql_simplifier_enabled = original_sql_simplifier_enabled


def test_table_select_cte(session):
    """
    Union two ``with_column`` projections built on the same temp table and
    check the CTE expectation: no CTE optimization when the sql simplifier
    is enabled, CTE optimization otherwise.
    """
    table_name = random_name_for_temp_object(TempObjectType.TABLE)
    source_df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
    source_df.write.save_as_table(table_name, table_type="temp")

    # Both branches of the union read from the very same table dataframe.
    table_df = session.table(table_name)
    plus_one = table_df.with_column("add_one", col("a") + 1)
    plus_two = table_df.with_column("add_two", col("a") + 2)
    df_result = plus_one.union(plus_two)

    check_result(
        session,
        df_result,
        expect_cte_optimized=not session.sql_simplifier_enabled,
        query_count=1,
        describe_count=0,
        union_count=1,
        join_count=0,
    )


@pytest.mark.skipif(
IS_IN_STORED_PROC, reason="SNOW-609328: support caplog in SP regression test"
)
Expand Down

0 comments on commit 0c74413

Please sign in to comment.