Skip to content

Commit

Permalink
SNOW-1632701: Add uuid to SnowflakePlan (#2233)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam authored Sep 6, 2024
1 parent 3df2d99 commit 64e433d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,17 @@ def __init__(
# In the placeholder query, subquery (child) is held by the ID of query plan
# It is used for optimization, by replacing a subquery with a CTE
self.placeholder_query = placeholder_query
# encode an id for CTE optimization
# encode an id for CTE optimization. This is generated based on the main
# query and the associated query parameters. We use this id for equality comparison
# to determine if two plans are the same.
self._id = encode_id(queries[-1].sql, queries[-1].params)
self.referenced_ctes: Set[str] = (
referenced_ctes.copy() if referenced_ctes else set()
)
self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None
# UUID for the plan to uniquely identify the SnowflakePlan object. We also use this
# UUID to track queries that are generated from the same plan.
self._uuid = str(uuid.uuid4())

def __eq__(self, other: "SnowflakePlan") -> bool:
if not isinstance(other, SnowflakePlan):
Expand All @@ -272,6 +277,10 @@ def __eq__(self, other: "SnowflakePlan") -> bool:
def __hash__(self) -> int:
    """Return a hash derived from the plan's encoded id.

    When ``_id`` is falsy, the hash is delegated to the parent class's
    ``__hash__`` implementation instead.
    """
    if not self._id:
        # No usable id to hash on: fall back to the inherited behavior.
        return super().__hash__()
    return hash(self._id)

@property
def uuid(self) -> str:
    """Return the UUID string stored on this plan as ``_uuid``."""
    plan_uuid: str = self._uuid
    return plan_uuid

@property
def execution_queries(self) -> Dict["PlanQueryType", List["Query"]]:
"""
Expand Down
3 changes: 3 additions & 0 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,9 @@ def get_result_set(
action_id = plan.session._generate_new_action_id()
plan_queries = plan.execution_queries
result, result_meta = None, None
statement_params = kwargs.get("_statement_params", None) or {}
statement_params["_PLAN_UUID"] = plan.uuid
kwargs["_statement_params"] = statement_params
try:
main_queries = plan_queries[PlanQueryType.QUERIES]
placeholders = {}
Expand Down
23 changes: 23 additions & 0 deletions tests/integ/test_large_query_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@


import logging
from unittest.mock import patch

import pytest

from snowflake.snowpark._internal.analyzer import analyzer
from snowflake.snowpark._internal.compiler import large_query_breakdown
from snowflake.snowpark.functions import col, lit, sum_distinct, when_matched
from snowflake.snowpark.row import Row
from tests.utils import Utils

pytestmark = [
Expand Down Expand Up @@ -373,6 +375,27 @@ def test_async_job_with_large_query_breakdown(session, large_query_df):
)


def test_add_parent_plan_uuid_to_statement_params(session, large_query_df):
    """Check that queries run for a plan carry that plan's UUID.

    ``session._conn.run_query`` is wrapped so every call is recorded while
    still being executed. After collecting ``large_query_df``, every recorded
    call except the first must have ``_PLAN_UUID`` in its
    ``_statement_params`` equal to the dataframe plan's ``uuid``.
    """
    set_bounds(300, 600)

    connection = session._conn
    with patch.object(
        connection, "run_query", wraps=connection.run_query
    ) as run_query_spy:
        rows = large_query_df.collect()
        Utils.check_answer(rows, [Row(1, 4954), Row(2, 4953)])

        expected_uuid = large_query_df._plan.uuid
        # Expected calls: current transaction, partition, main query, post action.
        assert run_query_spy.call_count == 4

        # The first recorded call is the transaction lookup; the remaining
        # calls are the ones checked for the plan UUID.
        transaction_call, *plan_calls = run_query_spy.call_args_list
        assert transaction_call.args[0] == "SELECT CURRENT_TRANSACTION()"
        for recorded in plan_calls:
            assert "_statement_params" in recorded.kwargs
            assert recorded.kwargs["_statement_params"]["_PLAN_UUID"] == expected_uuid


def test_complexity_bounds_affect_num_partitions(session, large_query_df):
"""Test complexity bounds affect number of partitions.
Also test that when partitions are added, drop table queries are added.
Expand Down

0 comments on commit 64e433d

Please sign in to comment.