diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 3764c61410d..9750deba3f9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -254,12 +254,17 @@ def __init__( # In the placeholder query, subquery (child) is held by the ID of query plan # It is used for optimization, by replacing a subquery with a CTE self.placeholder_query = placeholder_query - # encode an id for CTE optimization + # encode an id for CTE optimization. This is generated based on the main + # query and the associated query parameters. We use this id for equality comparison + # to determine if two plans are the same. self._id = encode_id(queries[-1].sql, queries[-1].params) self.referenced_ctes: Set[str] = ( referenced_ctes.copy() if referenced_ctes else set() ) self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None + # UUID for the plan to uniquely identify the SnowflakePlan object. We also use this + # UUID to track queries that are generated from the same plan. 
+ self._uuid = str(uuid.uuid4()) def __eq__(self, other: "SnowflakePlan") -> bool: if not isinstance(other, SnowflakePlan): @@ -272,6 +277,10 @@ def __eq__(self, other: "SnowflakePlan") -> bool: def __hash__(self) -> int: return hash(self._id) if self._id else super().__hash__() + @property + def uuid(self) -> str: + return self._uuid + @property def execution_queries(self) -> Dict["PlanQueryType", List["Query"]]: """ diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 498a1fe406d..56445fc31b2 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -575,6 +575,9 @@ def get_result_set( action_id = plan.session._generate_new_action_id() plan_queries = plan.execution_queries result, result_meta = None, None + statement_params = kwargs.get("_statement_params", None) or {} + statement_params["_PLAN_UUID"] = plan.uuid + kwargs["_statement_params"] = statement_params try: main_queries = plan_queries[PlanQueryType.QUERIES] placeholders = {} diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py index fb4d5517b98..e42a504a976 100644 --- a/tests/integ/test_large_query_breakdown.py +++ b/tests/integ/test_large_query_breakdown.py @@ -4,12 +4,14 @@ import logging +from unittest.mock import patch import pytest from snowflake.snowpark._internal.analyzer import analyzer from snowflake.snowpark._internal.compiler import large_query_breakdown from snowflake.snowpark.functions import col, lit, sum_distinct, when_matched +from snowflake.snowpark.row import Row from tests.utils import Utils pytestmark = [ @@ -373,6 +375,27 @@ def test_async_job_with_large_query_breakdown(session, large_query_df): ) +def test_add_parent_plan_uuid_to_statement_params(session, large_query_df): + set_bounds(300, 600) + + with patch.object( + session._conn, "run_query", wraps=session._conn.run_query + ) as 
patched_run_query: + result = large_query_df.collect() + Utils.check_answer(result, [Row(1, 4954), Row(2, 4953)]) + + plan = large_query_df._plan + # 1 for current transaction, 1 for partition, 1 for main query, 1 for post action + assert patched_run_query.call_count == 4 + + for i, call in enumerate(patched_run_query.call_args_list): + if i == 0: + assert call.args[0] == "SELECT CURRENT_TRANSACTION()" + else: + assert "_statement_params" in call.kwargs + assert call.kwargs["_statement_params"]["_PLAN_UUID"] == plan.uuid + + def test_complexity_bounds_affect_num_partitions(session, large_query_df): """Test complexity bounds affect number of partitions. Also test that when partitions are added, drop table queries are added.