SNOW-1869388 add memoization to to_selectable (#2815)
sfc-gh-aalam authored Jan 3, 2025
1 parent e75b506 commit 10c612e
Showing 3 changed files with 31 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/snowflake/snowpark/_internal/compiler/query_generator.py
@@ -68,15 +68,24 @@ def __init__(
         # order of when the with query block is visited. The order is important to make sure the dependency
         # between the CTE definition is satisfied.
         self.resolved_with_query_block: Dict[str, Query] = {}
+        # This is a memoization dict for storing the selectable for a SnowflakePlan when to_selectable
+        # method is called with the same SnowflakePlan. This is used to de-duplicate nodes created during
+        # compilation process
+        self._to_selectable_memo_dict = {}
 
     def to_selectable(self, plan: LogicalPlan) -> Selectable:
         """Given a LogicalPlan, convert it to a Selectable."""
         if isinstance(plan, Selectable):
             return plan
 
+        plan_id = hex(id(plan))
+        if plan_id in self._to_selectable_memo_dict:
+            return self._to_selectable_memo_dict[plan_id]
+
         snowflake_plan = self.resolve(plan)
         selectable = SelectSnowflakePlan(snowflake_plan, analyzer=self)
-        selectable._is_valid_for_replacement = snowflake_plan._is_valid_for_replacement
+        selectable._is_valid_for_replacement = True
+        self._to_selectable_memo_dict[plan_id] = selectable
         return selectable
 
     def generate_queries(
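The change keys the memo on hex(id(plan)), i.e. the identity of the SnowflakePlan object, so converting the same plan twice returns the same Selectable instance instead of building a new SelectSnowflakePlan each time. Below is a minimal, self-contained sketch of that behavior; FakeGenerator and FakeSelectable are hypothetical stand-ins, not the real QueryGenerator or SelectSnowflakePlan classes.

class FakeSelectable:
    def __init__(self, plan):
        self.plan = plan

class FakeGenerator:
    def __init__(self):
        # same idea as _to_selectable_memo_dict in the commit
        self._to_selectable_memo_dict = {}

    def to_selectable(self, plan):
        plan_id = hex(id(plan))  # memo key is the plan object's identity
        if plan_id in self._to_selectable_memo_dict:
            return self._to_selectable_memo_dict[plan_id]
        # the real code resolves the plan and wraps it in SelectSnowflakePlan
        selectable = FakeSelectable(plan)
        self._to_selectable_memo_dict[plan_id] = selectable
        return selectable

gen = FakeGenerator()
plan = object()  # stand-in for a SnowflakePlan node
assert gen.to_selectable(plan) is gen.to_selectable(plan)  # one node, no duplicates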
20 changes: 20 additions & 0 deletions tests/integ/test_large_query_breakdown.py
@@ -779,6 +779,26 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df):
         assert len(queries["post_actions"]) == 0
 
 
+def test_to_selectable_memoization(session):
+    session.cte_optimization_enabled = True
+    sql_simplifier_enabled = session.sql_simplifier_enabled
+    if sql_simplifier_enabled:
+        set_bounds(session, 300, 520)
+    else:
+        set_bounds(session, 40, 55)
+    df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a", "b")
+    for i in range(7):
+        df = df.with_column("a", col("a") + i + col("a"))
+    df1 = df.select("a", "b", (col("a") + col("b")).as_("b"))
+    df2 = df.select("a", "b", (col("a") + col("b")).as_("c"))
+    df3 = df.select("a", "b", (col("a") + col("b")).as_("d"))
+    df5 = df1.union_all(df2).union_all(df3)
+    with SqlCounter(query_count=1, describe_count=0):
+        queries = df5.queries
+        assert len(queries["queries"]) == 2
+        assert len(queries["post_actions"]) == 1
+
+
 @sql_count_checker(query_count=0)
 def test_large_query_breakdown_enabled_parameter(session, caplog):
     with caplog.at_level(logging.WARNING):
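The test fans three branches (df1, df2, df3) out of one shared subtree (df), so compilation converts the same underlying plan several times; with memoization those conversions should all resolve to a single node. A rough, self-contained sketch of that effect, using a plain dict in place of _to_selectable_memo_dict (names here are illustrative only):

memo = {}  # plays the role of _to_selectable_memo_dict

def to_selectable(plan):
    key = hex(id(plan))
    if key not in memo:
        memo[key] = ("selectable-for", plan)  # stand-in for SelectSnowflakePlan
    return memo[key]

shared_plan = object()  # the plan behind df, referenced by df1, df2 and df3
nodes = [to_selectable(shared_plan) for _ in range(3)]
assert nodes[0] is nodes[1] is nodes[2]  # one shared node instead of three copies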
1 change: 1 addition & 0 deletions tests/unit/compiler/test_replace_child_and_update_node.py
@@ -87,6 +87,7 @@ def mock_resolve(x):
     fake_query_generator.to_selectable = partial(
         QueryGenerator.to_selectable, fake_query_generator
     )
+    fake_query_generator._to_selectable_memo_dict = {}
     return fake_query_generator


