From 10c612e55d89f30a6b8f945b591418d6c9756d93 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 3 Jan 2025 10:29:06 -0800 Subject: [PATCH] SNOW-1869388 add memoization to to_selectable (#2815) --- .../_internal/compiler/query_generator.py | 11 +++++++++- tests/integ/test_large_query_breakdown.py | 20 +++++++++++++++++++ .../test_replace_child_and_update_node.py | 1 + 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index c9e61e6c850..8d96fd114f8 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -68,15 +68,24 @@ def __init__( # order of when the with query block is visited. The order is important to make sure the dependency # between the CTE definition is satisfied. self.resolved_with_query_block: Dict[str, Query] = {} + # This is a memoization dict for storing the selectable for a SnowflakePlan when to_selectable + # method is called with the same SnowflakePlan. This is used to de-duplicate nodes created during + # compilation process + self._to_selectable_memo_dict = {} def to_selectable(self, plan: LogicalPlan) -> Selectable: """Given a LogicalPlan, convert it to a Selectable.""" if isinstance(plan, Selectable): return plan + plan_id = hex(id(plan)) + if plan_id in self._to_selectable_memo_dict: + return self._to_selectable_memo_dict[plan_id] + snowflake_plan = self.resolve(plan) selectable = SelectSnowflakePlan(snowflake_plan, analyzer=self) - selectable._is_valid_for_replacement = snowflake_plan._is_valid_for_replacement + selectable._is_valid_for_replacement = True + self._to_selectable_memo_dict[plan_id] = selectable return selectable def generate_queries( diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py index bb5142641f3..6724f0fa5cb 100644 --- a/tests/integ/test_large_query_breakdown.py +++ b/tests/integ/test_large_query_breakdown.py @@ -779,6 +779,26 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): assert len(queries["post_actions"]) == 0 +def test_to_selectable_memoization(session): + session.cte_optimization_enabled = True + sql_simplifier_enabled = session.sql_simplifier_enabled + if sql_simplifier_enabled: + set_bounds(session, 300, 520) + else: + set_bounds(session, 40, 55) + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a", "b") + for i in range(7): + df = df.with_column("a", col("a") + i + col("a")) + df1 = df.select("a", "b", (col("a") + col("b")).as_("b")) + df2 = df.select("a", "b", (col("a") + col("b")).as_("c")) + df3 = df.select("a", "b", (col("a") + col("b")).as_("d")) + df5 = df1.union_all(df2).union_all(df3) + with SqlCounter(query_count=1, describe_count=0): + queries = df5.queries + assert len(queries["queries"]) == 2 + assert len(queries["post_actions"]) == 1 + + @sql_count_checker(query_count=0) def test_large_query_breakdown_enabled_parameter(session, caplog): with caplog.at_level(logging.WARNING): diff --git a/tests/unit/compiler/test_replace_child_and_update_node.py b/tests/unit/compiler/test_replace_child_and_update_node.py index de235a16d90..f8164b8d6fc 100644 --- a/tests/unit/compiler/test_replace_child_and_update_node.py +++ b/tests/unit/compiler/test_replace_child_and_update_node.py @@ -87,6 +87,7 @@ def mock_resolve(x): fake_query_generator.to_selectable = partial( QueryGenerator.to_selectable, fake_query_generator ) + fake_query_generator._to_selectable_memo_dict = {} return fake_query_generator