From 56fb566a9fa89184e705263bde96c18d697697b4 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 11:44:26 -0700 Subject: [PATCH 01/62] init --- src/snowflake/snowpark/_internal/udf_utils.py | 16 ++++++++++------ src/snowflake/snowpark/session.py | 8 ++++++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index b79fcdcf9c9..b4be8cca9c3 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -980,8 +980,11 @@ def add_snowpark_package_to_sproc_packages( if packages is None: if session is None: packages = [this_package] - elif package_name not in session._packages: - packages = list(session._packages.values()) + [this_package] + else: + with session._lock: + session_packages = session._packages.copy() + if package_name not in session_packages: + packages = list(session_packages.values()) + [this_package] else: package_names = [p if isinstance(p, str) else p.__name__ for p in packages] if not any(p.startswith(package_name) for p in package_names): @@ -1247,10 +1250,11 @@ def create_python_udf_or_sp( comment: Optional[str] = None, native_app_params: Optional[Dict[str, Any]] = None, ) -> None: - if session is not None and session._runtime_version_from_requirement: - runtime_version = session._runtime_version_from_requirement - else: - runtime_version = f"{sys.version_info[0]}.{sys.version_info[1]}" + with session._lock: + if session is not None and session._runtime_version_from_requirement: + runtime_version = session._runtime_version_from_requirement + else: + runtime_version = f"{sys.version_info[0]}.{sys.version_info[1]}" if replace and if_not_exists: raise ValueError("options replace and if_not_exists are incompatible") diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index a04e381e985..5e6ad560592 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -13,6 +13,7 @@ import re import sys import tempfile +import threading import warnings from array import array from functools import reduce @@ -499,6 +500,8 @@ def __init__( if len(_active_sessions) >= 1 and is_in_stored_procedure(): raise SnowparkClientExceptionMessages.DONT_CREATE_SESSION_IN_SP() self._conn = conn + self._thread_store = threading.local() + self._lock = threading.RLock() self._query_tag = None self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {} self._packages: Dict[str, str] = {} @@ -3007,8 +3010,9 @@ def get_fully_qualified_name_if_possible(self, name: str) -> str: """ Returns the fully qualified object name if current database/schema exists, otherwise returns the object name """ - database = self.get_current_database() - schema = self.get_current_schema() + with self._lock: + database = self.get_current_database() + schema = self.get_current_schema() if database and schema: return f"{database}.{schema}.{name}" From 66003d19e785d363d2a001362bacde23ab08a343 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 13:28:45 -0700 Subject: [PATCH 02/62] make udf/sproc related files thread-safe --- src/snowflake/snowpark/_internal/udf_utils.py | 6 +- src/snowflake/snowpark/session.py | 81 +++++++++++-------- 2 files changed, 49 insertions(+), 38 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index b4be8cca9c3..4012ad9ff6e 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ 
b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1089,7 +1089,7 @@ def resolve_imports_and_packages( else session.get_session_stage(statement_params=statement_params) ) - if session: + all_urls = [] if imports: udf_level_imports = {} for udf_import in imports: @@ -1117,10 +1117,6 @@ def resolve_imports_and_packages( upload_and_import_stage, statement_params=statement_params, ) - else: - all_urls = [] - else: - all_urls = [] dest_prefix = get_udf_upload_prefix(udf_name) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 5e6ad560592..af8c06ee3a6 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -809,7 +809,8 @@ def get_imports(self) -> List[str]: Returns a list of imports added for user defined functions (UDFs). This list includes any Python or zip files that were added automatically by the library. """ - return list(self._import_paths.keys()) + with self._lock: + return list(self._import_paths.keys()) def add_import( self, @@ -890,7 +891,8 @@ def add_import( path, checksum, leading_path = self._resolve_import_path( path, import_path, chunk_size, whole_file_hash ) - self._import_paths[path] = (checksum, leading_path) + with self._lock: + self._import_paths[path] = (checksum, leading_path) def remove_import(self, path: str) -> None: """ @@ -917,10 +919,11 @@ def remove_import(self, path: str) -> None: if not trimmed_path.startswith(STAGE_PREFIX) else trimmed_path ) - if abs_path not in self._import_paths: - raise KeyError(f"{abs_path} is not found in the existing imports") - else: - self._import_paths.pop(abs_path) + with self._lock: + if abs_path not in self._import_paths: + raise KeyError(f"{abs_path} is not found in the existing imports") + else: + self._import_paths.pop(abs_path) def clear_imports(self) -> None: """ @@ -929,7 +932,8 @@ def clear_imports(self) -> None: if isinstance(self._conn, MockServerConnection): self.udf._clear_session_imports() self.sproc._clear_session_imports() - self._import_paths.clear() + with self._lock: + self._import_paths.clear() def _resolve_import_path( self, @@ -1020,7 +1024,8 @@ def _resolve_imports( upload_and_import_stage ) - import_paths = udf_level_import_paths or self._import_paths + with self._lock: + import_paths = udf_level_import_paths or self._import_paths.copy() for path, (prefix, leading_path) in import_paths.items(): # stage file if path.startswith(STAGE_PREFIX): @@ -1102,7 +1107,8 @@ def get_packages(self) -> Dict[str, str]: The key of this ``dict`` is the package name and the value of this ``dict`` is the corresponding requirement specifier. """ - return self._packages.copy() + with self._lock: + return self._packages.copy() def add_packages( self, *packages: Union[str, ModuleType, Iterable[Union[str, ModuleType]]] @@ -1193,16 +1199,18 @@ def remove_package(self, package: str) -> None: 0 """ package_name = pkg_resources.Requirement.parse(package).key - if package_name in self._packages: - self._packages.pop(package_name) - else: - raise ValueError(f"{package_name} is not in the package list") + with self._lock: + if package_name in self._packages: + self._packages.pop(package_name) + else: + raise ValueError(f"{package_name} is not in the package list") def clear_packages(self) -> None: """ Clears all third-party packages of a user-defined function (UDF). 
""" - self._packages.clear() + with self._lock: + self._packages.clear() def add_requirements(self, file_path: str) -> None: """ @@ -1550,7 +1558,8 @@ def _resolve_packages( if isinstance(self._conn, MockServerConnection): # in local testing we don't resolve the packages, we just return what is added errors = [] - result_dict = self._packages.copy() + with self._lock: + result_dict = self._packages.copy() for pkg_name, _, pkg_req in package_dict.values(): if pkg_name in result_dict and str(pkg_req) != result_dict[pkg_name]: errors.append( @@ -1566,7 +1575,8 @@ def _resolve_packages( elif len(errors) > 0: raise RuntimeError(errors) - self._packages.update(result_dict) + with self._lock: + self._packages.update(result_dict) return list(result_dict.values()) package_table = "information_schema.packages" @@ -1581,9 +1591,12 @@ def _resolve_packages( # 'python-dateutil': 'python-dateutil==2.8.2'} # Add to packages dictionary. Make a copy of existing packages # dictionary to avoid modifying it during intermediate steps. - result_dict = ( - existing_packages_dict.copy() if existing_packages_dict is not None else {} - ) + with self._lock: + result_dict = ( + existing_packages_dict.copy() + if existing_packages_dict is not None + else {} + ) # Retrieve list of dependencies that need to be added dependency_packages = self._get_dependency_packages( @@ -1616,8 +1629,9 @@ def _resolve_packages( if include_pandas: extra_modules.append("pandas") - if existing_packages_dict is not None: - existing_packages_dict.update(result_dict) + with self._lock: + if existing_packages_dict is not None: + existing_packages_dict.update(result_dict) return list(result_dict.values()) + self._get_req_identifiers_list( extra_modules, result_dict ) @@ -2300,18 +2314,19 @@ def get_session_stage( Therefore, if you switch database or schema during the session, the stage will not be re-created in the new database or schema, and still references the stage in the old database or schema. 
""" - if not self._session_stage: - full_qualified_stage_name = self.get_fully_qualified_name_if_possible( - random_name_for_temp_object(TempObjectType.STAGE) - ) - self._run_query( - f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \ - stage if not exists {full_qualified_stage_name}", - is_ddl_on_temp_object=True, - statement_params=statement_params, - ) - # set the value after running the query to ensure atomicity - self._session_stage = full_qualified_stage_name + with self._lock: + if not self._session_stage: + full_qualified_stage_name = self.get_fully_qualified_name_if_possible( + random_name_for_temp_object(TempObjectType.STAGE) + ) + self._run_query( + f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \ + stage if not exists {full_qualified_stage_name}", + is_ddl_on_temp_object=True, + statement_params=statement_params, + ) + # set the value after running the query to ensure atomicity + self._session_stage = full_qualified_stage_name return f"{STAGE_PREFIX}{self._session_stage}" def _write_modin_pandas_helper( From e75dde123dbcf33c8071236436495c0214be62fc Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 14:52:58 -0700 Subject: [PATCH 03/62] init --- .../snowpark/_internal/server_connection.py | 11 ++++++++++- src/snowflake/snowpark/session.py | 13 ++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 56445fc31b2..0ab0ae490fa 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -8,6 +8,7 @@ import inspect import os import sys +import threading import time from logging import getLogger from typing import ( @@ -154,6 +155,8 @@ def __init__( options: Dict[str, Union[int, str]], conn: Optional[SnowflakeConnection] = None, ) -> None: + self._lock = threading.RLock() + self._thread_stored = threading.local() self._lower_case_parameters = {k.lower(): v for k, v in options.items()} self._add_application_parameters() self._conn = conn if conn else connect(**self._lower_case_parameters) @@ -170,8 +173,8 @@ def __init__( if "password" in self._lower_case_parameters: self._lower_case_parameters["password"] = None - self._cursor = self._conn.cursor() self._telemetry_client = TelemetryClient(self._conn) + # TODO: protect _query_listener self._query_listener: Set[QueryHistory] = set() # The session in this case refers to a Snowflake session, not a # Snowpark session @@ -183,6 +186,12 @@ def __init__( "_skip_upload_on_content_match" in signature.parameters ) + @property + def _cursor(self) -> SnowflakeCursor: + if not hasattr(self._thread_stored, "cursor"): + self._thread_stored.cursor = self._conn.cursor() + return self._thread_stored.cursor + def _add_application_parameters(self) -> None: if PARAM_APPLICATION not in self._lower_case_parameters: # Mirrored from snowflake-connector-python/src/snowflake/connector/connection.py#L295 diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 607cd047f2b..1c436360fb4 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -540,9 +540,6 @@ def __init__( ) self._file = FileOperation(self) self._lineage = Lineage(self) - self._analyzer = ( - Analyzer(self) if isinstance(conn, ServerConnection) else MockAnalyzer(self) - ) self._sql_simplifier_enabled: bool = ( self._conn._get_client_side_session_parameter( 
_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING, True @@ -608,6 +605,16 @@ def _generate_new_action_id(self) -> int: self._last_action_id += 1 return self._last_action_id + @property + def _analyzer(self) -> Analyzer: + if not hasattr(self._thread_store, "analyzer"): + self._thread_store.analyzer = ( + Analyzer(self) + if isinstance(self._conn, ServerConnection) + else MockAnalyzer(self) + ) + return self._thread_store.analyzer + def close(self) -> None: """Close this session.""" if is_in_stored_procedure(): From 68a8c1c7186df5892a664bc297341a97e5461a3c Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 15:22:53 -0700 Subject: [PATCH 04/62] make query listener thread-safe --- .../snowpark/_internal/server_connection.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 0ab0ae490fa..7d51f1189c4 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -174,7 +174,6 @@ def __init__( if "password" in self._lower_case_parameters: self._lower_case_parameters["password"] = None self._telemetry_client = TelemetryClient(self._conn) - # TODO: protect _query_listener self._query_listener: Set[QueryHistory] = set() # The session in this case refers to a Snowflake session, not a # Snowpark session @@ -219,10 +218,12 @@ def _add_application_parameters(self) -> None: ] = get_version() def add_query_listener(self, listener: QueryHistory) -> None: - self._query_listener.add(listener) + with self._lock: + self._query_listener.add(listener) def remove_query_listener(self, listener: QueryHistory) -> None: - self._query_listener.remove(listener) + with self._lock: + self._query_listener.remove(listener) def close(self) -> None: if self._conn: @@ -369,8 +370,9 @@ def upload_stream( raise ex def notify_query_listeners(self, query_record: QueryRecord) -> None: - for listener in self._query_listener: - listener._add_query(query_record) + with self._lock: + for listener in self._query_listener: + listener._add_query(query_record) def execute_and_notify_query_listener( self, query: str, **kwargs: Any From 31a57343e348115f4944cb94d49f02d0c58b3288 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 16:21:57 -0700 Subject: [PATCH 05/62] Fix query_tag and last_action_id --- src/snowflake/snowpark/session.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 1c436360fb4..004f0c65361 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -602,8 +602,9 @@ def __str__(self): ) def _generate_new_action_id(self) -> int: - self._last_action_id += 1 - return self._last_action_id + with self._lock: + self._last_action_id += 1 + return self._last_action_id @property def _analyzer(self) -> Analyzer: @@ -805,7 +806,8 @@ def cancel_all(self) -> None: This does not affect any action methods called in the future. 
""" _logger.info("Canceling all running queries") - self._last_canceled_id = self._last_action_id + with self._lock: + self._last_canceled_id = self._last_action_id if not isinstance(self._conn, MockServerConnection): self._conn.run_query( f"select system$cancel_all_queries({self._session_id})" @@ -1910,11 +1912,12 @@ def query_tag(self) -> Optional[str]: @query_tag.setter def query_tag(self, tag: str) -> None: - if tag: - self._conn.run_query(f"alter session set query_tag = {str_to_sql(tag)}") - else: - self._conn.run_query("alter session unset query_tag") - self._query_tag = tag + with self._lock: + if tag: + self._conn.run_query(f"alter session set query_tag = {str_to_sql(tag)}") + else: + self._conn.run_query("alter session unset query_tag") + self._query_tag = tag def _get_remote_query_tag(self) -> None: """ From b4dadda2ad2392c6a5d0bbf898cbeb5139021b87 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 16:56:57 -0700 Subject: [PATCH 06/62] core updates done --- src/snowflake/snowpark/dataframe.py | 5 +++-- src/snowflake/snowpark/session.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 25ce987fd78..20cf07334f5 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -2027,12 +2027,13 @@ def _union_by_name_internal( ] names = right_project_list + not_found_attrs - if self._session.sql_simplifier_enabled and other._select_statement: + sql_simplifier_enabled = self._session.sql_simplifier_enabled + if sql_simplifier_enabled and other._select_statement: right_child = self._with_plan(other._select_statement.select(names)) else: right_child = self._with_plan(Project(names, other._plan)) - if self._session.sql_simplifier_enabled: + if sql_simplifier_enabled: df = self._with_plan( self._select_statement.set_operator( right_child._select_statement diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 004f0c65361..0a9fdd8a83d 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3020,8 +3020,9 @@ def get_fully_qualified_name_if_possible(self, name: str) -> str: """ Returns the fully qualified object name if current database/schema exists, otherwise returns the object name """ - database = self.get_current_database() - schema = self.get_current_schema() + with self._lock: + database = self.get_current_database() + schema = self.get_current_schema() if database and schema: return f"{database}.{schema}.{name}" From b8c64960456cdd77789b88b14378865c1c41de0e Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 12 Sep 2024 11:00:11 -0700 Subject: [PATCH 07/62] Add tests --- src/snowflake/snowpark/session.py | 25 +++---- tests/integ/test_multithreading.py | 109 +++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 12 deletions(-) create mode 100644 tests/integ/test_multithreading.py diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 0a9fdd8a83d..3a099560bc4 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -2310,18 +2310,19 @@ def get_session_stage( Therefore, if you switch database or schema during the session, the stage will not be re-created in the new database or schema, and still references the stage in the old database or schema. 
""" - if not self._session_stage: - full_qualified_stage_name = self.get_fully_qualified_name_if_possible( - random_name_for_temp_object(TempObjectType.STAGE) - ) - self._run_query( - f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \ - stage if not exists {full_qualified_stage_name}", - is_ddl_on_temp_object=True, - statement_params=statement_params, - ) - # set the value after running the query to ensure atomicity - self._session_stage = full_qualified_stage_name + with self._lock: + if not self._session_stage: + full_qualified_stage_name = self.get_fully_qualified_name_if_possible( + random_name_for_temp_object(TempObjectType.STAGE) + ) + self._run_query( + f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \ + stage if not exists {full_qualified_stage_name}", + is_ddl_on_temp_object=True, + statement_params=statement_params, + ) + # set the value after running the query to ensure atomicity + self._session_stage = full_qualified_stage_name return f"{STAGE_PREFIX}{self._session_stage}" def _write_modin_pandas_helper( diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py new file mode 100644 index 00000000000..d2a966217a6 --- /dev/null +++ b/tests/integ/test_multithreading.py @@ -0,0 +1,109 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +from concurrent.futures import ThreadPoolExecutor, as_completed +from unittest.mock import patch + +import pytest + +from snowflake.snowpark.functions import lit +from snowflake.snowpark.row import Row +from tests.utils import IS_IN_STORED_PROC, Utils + + +def test_concurrent_select_queries(session): + def run_select(session_, thread_id): + df = session_.sql(f"SELECT {thread_id} as A") + assert df.collect()[0][0] == thread_id + + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(run_select, session, i) + + +def test_concurrent_dataframe_operations(session): + try: + table_name = Utils.random_table_name() + data = [(i, 11 * i) for i in range(10)] + df = session.create_dataframe(data, ["A", "B"]) + df.write.save_as_table(table_name, table_type="temporary") + + def run_dataframe_operation(session_, thread_id): + df = session_.table(table_name) + df = df.filter(df.a == lit(thread_id)) + df = df.with_column("C", df.b + 100 * df.a) + df = df.rename(df.a, "D").limit(1) + return df + + dfs = [] + with ThreadPoolExecutor(max_workers=10) as executor: + df_futures = [ + executor.submit(run_dataframe_operation, session, i) for i in range(10) + ] + + for future in as_completed(df_futures): + dfs.append(future.result()) + + main_df = dfs[0] + for df in dfs[1:]: + main_df = main_df.union(df) + + Utils.check_answer( + main_df, [Row(D=i, B=11 * i, C=11 * i + 100 * i) for i in range(10)] + ) + + finally: + Utils.drop_table(session, table_name) + + +def test_query_listener(session): + def run_select(session_, thread_id): + session_.sql(f"SELECT {thread_id} as A").collect() + + with session.query_history() as history: + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(run_select, session, i) + + queries_sent = [query.sql_text for query in history.queries] + assert len(queries_sent) == 10 + for i in range(10): + assert f"SELECT {i} as A" in queries_sent + + +@pytest.mark.skipif( + IS_IN_STORED_PROC, reason="show parameters is not supported in stored procedure" +) +def test_query_tagging(session): + def set_query_tag(session_, thread_id): + session_.query_tag = f"tag_{thread_id}" 
+ + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(set_query_tag, session, i) + + actual_query_tag = session.sql("SHOW PARAMETERS LIKE 'QUERY_TAG'").collect()[0][1] + assert actual_query_tag == session.query_tag + + +def test_session_stage_created_once(session): + with patch.object( + session._conn, "run_query", wraps=session._conn.run_query + ) as patched_run_query: + with ThreadPoolExecutor(max_workers=10) as executor: + for _ in range(10): + executor.submit(session.get_session_stage) + + assert patched_run_query.call_count == 1 + + +def test_action_ids_are_unique(session): + with ThreadPoolExecutor(max_workers=10) as executor: + action_ids = set() + futures = [executor.submit(session._generate_new_action_id) for _ in range(10)] + + for future in as_completed(futures): + action_ids.add(future.result()) + + assert len(action_ids) == 10 From f39837e29ae79e1e26536cda2577349c45033017 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 12 Sep 2024 12:01:12 -0700 Subject: [PATCH 08/62] Fix local tests --- tests/integ/test_multithreading.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index d2a966217a6..164a4b7b590 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -57,6 +57,11 @@ def run_dataframe_operation(session_, thread_id): Utils.drop_table(session, table_name) +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="SQL query and query listeners are not supported", + run=False, +) def test_query_listener(session): def run_select(session_, thread_id): session_.sql(f"SELECT {thread_id} as A").collect() @@ -72,6 +77,11 @@ def run_select(session_, thread_id): assert f"SELECT {i} as A" in queries_sent +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="Query tag is a SQL feature", + run=False, +) @pytest.mark.skipif( IS_IN_STORED_PROC, reason="show parameters is not supported in stored procedure" ) @@ -87,6 +97,11 @@ def set_query_tag(session_, thread_id): assert actual_query_tag == session.query_tag +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="SQL query is not supported", + run=False, +) def test_session_stage_created_once(session): with patch.object( session._conn, "run_query", wraps=session._conn.run_query From 37c041936331de5c97fe32702f3cae8aff83ccd0 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 12 Sep 2024 15:06:48 -0700 Subject: [PATCH 09/62] add file IO tests --- tests/integ/conftest.py | 23 ++++++- .../integ/scala/test_file_operation_suite.py | 17 +---- tests/integ/test_multithreading.py | 69 ++++++++++++++++++- 3 files changed, 91 insertions(+), 18 deletions(-) diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index ec619605e66..319abb137f4 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -13,7 +13,13 @@ from snowflake.snowpark.exceptions import SnowparkSQLException from snowflake.snowpark.mock._connection import MockServerConnection from tests.parameters import CONNECTION_PARAMETERS -from tests.utils import TEST_SCHEMA, Utils, running_on_jenkins, running_on_public_ci +from tests.utils import ( + TEST_SCHEMA, + TestFiles, + Utils, + running_on_jenkins, + running_on_public_ci, +) def print_help() -> None: @@ -235,3 +241,18 @@ def temp_schema(connection, session, local_testing_mode) -> None: ) yield temp_schema_name cursor.execute(f"DROP SCHEMA IF 
EXISTS {temp_schema_name}") + + +@pytest.fixture(scope="module") +def temp_stage(session, resources_path, local_testing_mode): + tmp_stage_name = Utils.random_stage_name() + test_files = TestFiles(resources_path) + + if not local_testing_mode: + Utils.create_stage(session, tmp_stage_name, is_temporary=True) + Utils.upload_to_stage( + session, tmp_stage_name, test_files.test_file_parquet, compress=False + ) + yield tmp_stage_name + if not local_testing_mode: + Utils.drop_stage(session, tmp_stage_name) diff --git a/tests/integ/scala/test_file_operation_suite.py b/tests/integ/scala/test_file_operation_suite.py index 2dc424dde09..82a1722a729 100644 --- a/tests/integ/scala/test_file_operation_suite.py +++ b/tests/integ/scala/test_file_operation_suite.py @@ -14,7 +14,7 @@ SnowparkSQLException, SnowparkUploadFileException, ) -from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, TestFiles, Utils +from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, Utils def random_alphanumeric_name(): @@ -74,21 +74,6 @@ def path4(temp_source_directory): yield filename -@pytest.fixture(scope="module") -def temp_stage(session, resources_path, local_testing_mode): - tmp_stage_name = Utils.random_stage_name() - test_files = TestFiles(resources_path) - - if not local_testing_mode: - Utils.create_stage(session, tmp_stage_name, is_temporary=True) - Utils.upload_to_stage( - session, tmp_stage_name, test_files.test_file_parquet, compress=False - ) - yield tmp_stage_name - if not local_testing_mode: - Utils.drop_stage(session, tmp_stage_name) - - def test_put_with_one_file( session, temp_stage, path1, path2, path3, local_testing_mode ): diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 164a4b7b590..10fcc6ef70d 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -2,6 +2,9 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# +import hashlib +import os +import tempfile from concurrent.futures import ThreadPoolExecutor, as_completed from unittest.mock import patch @@ -9,7 +12,7 @@ from snowflake.snowpark.functions import lit from snowflake.snowpark.row import Row -from tests.utils import IS_IN_STORED_PROC, Utils +from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils def test_concurrent_select_queries(session): @@ -122,3 +125,67 @@ def test_action_ids_are_unique(session): action_ids.add(future.result()) assert len(action_ids) == 10 + + +@pytest.mark.parametrize("use_stream", [True, False]) +def test_file_io(session, resources_path, temp_stage, use_stream): + stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}" + stage_with_prefix = f"@{temp_stage}/{stage_prefix}/" + test_files = TestFiles(resources_path) + + resources_files = [ + test_files.test_file_csv, + test_files.test_file2_csv, + test_files.test_file_json, + test_files.test_file_csv_header, + test_files.test_file_csv_colon, + test_files.test_file_csv_quotes, + test_files.test_file_csv_special_format, + test_files.test_file_json_special_format, + test_files.test_file_csv_quotes_special, + test_files.test_concat_file1_csv, + test_files.test_concat_file2_csv, + ] + + def get_file_hash(fd): + return hashlib.md5(fd.read()).hexdigest() + + def put_and_get_file(upload_file_path, download_dir): + if use_stream: + with open(upload_file_path, "rb") as fd: + results = session.file.put_stream( + fd, stage_with_prefix, auto_compress=False, overwrite=False + ) + else: + results = session.file.put( + upload_file_path, + stage_with_prefix, + auto_compress=False, + overwrite=False, + ) + # assert file is uploaded successfully + assert len(results) == 1 + assert results[0].status == "UPLOADED" + + stage_file_name = f"{stage_with_prefix}{os.path.basename(upload_file_path)}" + if use_stream: + fd = session.file.get_stream(stage_file_name, download_dir) + with open(upload_file_path, "rb") as upload_fd: + assert get_file_hash(upload_fd) == get_file_hash(fd) + + else: + results = session.file.get(stage_file_name, download_dir) + # assert file is downloaded successfully + assert len(results) == 1 + assert results[0].status == "DOWNLOADED" + download_file_path = results[0].file + # assert two files are identical + with open(upload_file_path, "rb") as upload_fd, open( + download_file_path, "rb" + ) as download_fd: + assert get_file_hash(upload_fd) == get_file_hash(download_fd) + + with tempfile.TemporaryDirectory() as download_dir: + with ThreadPoolExecutor(max_workers=10) as executor: + for file_path in resources_files: + executor.submit(put_and_get_file, file_path, download_dir) From a083989df8663f87d1b5fc0fed62b75d5fa1b576 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 12 Sep 2024 17:20:04 -0700 Subject: [PATCH 10/62] make session._runtime_version_from_requirement safe --- src/snowflake/snowpark/_internal/udf_utils.py | 24 +++++++------------ src/snowflake/snowpark/stored_procedure.py | 11 ++++++--- src/snowflake/snowpark/udaf.py | 9 +++++-- src/snowflake/snowpark/udf.py | 9 +++++-- src/snowflake/snowpark/udtf.py | 9 +++++-- 5 files changed, 37 insertions(+), 25 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 4012ad9ff6e..2c36c8e6091 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1122,13 +1122,10 @@ def resolve_imports_and_packages( # Upload closure to stage if it is beyond inline closure size limit handler = 
inline_code = upload_file_stage_location = None - custom_python_runtime_version_allowed = False + # As cloudpickle is being used, we cannot allow a custom runtime + custom_python_runtime_version_allowed = not isinstance(func, Callable) if session is not None: if isinstance(func, Callable): - custom_python_runtime_version_allowed = ( - False # As cloudpickle is being used, we cannot allow a custom runtime - ) - # generate a random name for udf py file # and we compress it first then upload it udf_file_name_base = f"udf_py_{random_number()}" @@ -1173,7 +1170,6 @@ def resolve_imports_and_packages( upload_file_stage_location = None handler = _DEFAULT_HANDLER_NAME else: - custom_python_runtime_version_allowed = True udf_file_name = os.path.basename(func[0]) # for a compressed file, it might have multiple extensions # and we should remove all extensions @@ -1198,11 +1194,6 @@ def resolve_imports_and_packages( skip_upload_on_content_match=skip_upload_on_content_match, ) all_urls.append(upload_file_stage_location) - else: - if isinstance(func, Callable): - custom_python_runtime_version_allowed = False - else: - custom_python_runtime_version_allowed = True # build imports and packages string all_imports = ",".join( @@ -1245,12 +1236,13 @@ def create_python_udf_or_sp( statement_params: Optional[Dict[str, str]] = None, comment: Optional[str] = None, native_app_params: Optional[Dict[str, Any]] = None, + runtime_version: Optional[str] = None, ) -> None: - with session._lock: - if session is not None and session._runtime_version_from_requirement: - runtime_version = session._runtime_version_from_requirement - else: - runtime_version = f"{sys.version_info[0]}.{sys.version_info[1]}" + runtime_version = ( + f"{sys.version_info[0]}.{sys.version_info[1]}" + if not runtime_version + else runtime_version + ) if replace and if_not_exists: raise ValueError("options replace and if_not_exists are incompatible") diff --git a/src/snowflake/snowpark/stored_procedure.py b/src/snowflake/snowpark/stored_procedure.py index 92e63f6a76a..c2fcb024196 100644 --- a/src/snowflake/snowpark/stored_procedure.py +++ b/src/snowflake/snowpark/stored_procedure.py @@ -820,11 +820,15 @@ def _do_register_sp( force_inline_code=force_inline_code, ) - if (not custom_python_runtime_version_allowed) and (self._session is not None): - check_python_runtime_version( + runtime_version_from_requirement = None + if self._session is not None: + runtime_version_from_requirement = ( self._session._runtime_version_from_requirement ) + if not custom_python_runtime_version_allowed: + check_python_runtime_version(runtime_version_from_requirement) + anonymous_sp_sql = None if anonymous: anonymous_sp_sql = generate_anonymous_python_sp_sql( @@ -838,7 +842,7 @@ def _do_register_sp( raw_imports=imports, inline_python_code=code, strict=strict, - runtime_version=self._session._runtime_version_from_requirement, + runtime_version=runtime_version_from_requirement, external_access_integrations=external_access_integrations, secrets=secrets, native_app_params=native_app_params, @@ -870,6 +874,7 @@ def _do_register_sp( statement_params=statement_params, comment=comment, native_app_params=native_app_params, + runtime_version=runtime_version_from_requirement, ) # an exception might happen during registering a stored procedure # (e.g., a dependency might not be found on the stage), diff --git a/src/snowflake/snowpark/udaf.py b/src/snowflake/snowpark/udaf.py index fdc3d555281..889b5b62915 100644 --- a/src/snowflake/snowpark/udaf.py +++ b/src/snowflake/snowpark/udaf.py @@ 
-680,11 +680,15 @@ def _do_register_udaf( is_permanent=is_permanent, ) - if (not custom_python_runtime_version_allowed) and (self._session is not None): - check_python_runtime_version( + runtime_version_from_requirement = None + if self._session is not None: + runtime_version_from_requirement = ( self._session._runtime_version_from_requirement ) + if not custom_python_runtime_version_allowed: + check_python_runtime_version(runtime_version_from_requirement) + raised = False try: create_python_udf_or_sp( @@ -710,6 +714,7 @@ def _do_register_udaf( statement_params=statement_params, comment=comment, native_app_params=native_app_params, + runtime_version=runtime_version_from_requirement, ) # an exception might happen during registering a udaf # (e.g., a dependency might not be found on the stage), diff --git a/src/snowflake/snowpark/udf.py b/src/snowflake/snowpark/udf.py index 74278e48d30..b71a40263fd 100644 --- a/src/snowflake/snowpark/udf.py +++ b/src/snowflake/snowpark/udf.py @@ -873,11 +873,15 @@ def _do_register_udf( is_permanent=is_permanent, ) - if (not custom_python_runtime_version_allowed) and (self._session is not None): - check_python_runtime_version( + runtime_version_from_requirement = None + if self._session is not None: + runtime_version_from_requirement = ( self._session._runtime_version_from_requirement ) + if not custom_python_runtime_version_allowed: + check_python_runtime_version(runtime_version_from_requirement) + raised = False try: create_python_udf_or_sp( @@ -905,6 +909,7 @@ def _do_register_udf( statement_params=statement_params, comment=comment, native_app_params=native_app_params, + runtime_version=runtime_version_from_requirement, ) # an exception might happen during registering a udf # (e.g., a dependency might not be found on the stage), diff --git a/src/snowflake/snowpark/udtf.py b/src/snowflake/snowpark/udtf.py index 03b71aa5f75..856007cdb64 100644 --- a/src/snowflake/snowpark/udtf.py +++ b/src/snowflake/snowpark/udtf.py @@ -934,11 +934,15 @@ def _do_register_udtf( is_permanent=is_permanent, ) - if (not custom_python_runtime_version_allowed) and (self._session is not None): - check_python_runtime_version( + runtime_version_from_requirement = None + if self._session is not None: + runtime_version_from_requirement = ( self._session._runtime_version_from_requirement ) + if not custom_python_runtime_version_allowed: + check_python_runtime_version(runtime_version_from_requirement) + raised = False try: create_python_udf_or_sp( @@ -966,6 +970,7 @@ def _do_register_udtf( statement_params=statement_params, comment=comment, native_app_params=native_app_params, + runtime_version=runtime_version_from_requirement, ) # an exception might happen during registering a udtf # (e.g., a dependency might not be found on the stage), From 947d3847498e5749f759bff597d93dea40561b37 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 13 Sep 2024 10:35:04 -0700 Subject: [PATCH 11/62] add sp/udf concurrent tests --- src/snowflake/snowpark/session.py | 2 +- tests/integ/test_multithreading.py | 218 +++++++++++++++++++++++++++++ 2 files changed, 219 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 874ff5b72b6..a233ce26683 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -944,8 +944,8 @@ def clear_imports(self) -> None: with self._lock: self._import_paths.clear() + @staticmethod def _resolve_import_path( - self, path: str, import_path: Optional[str] = None, chunk_size: int = 8192, diff 
--git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 10fcc6ef70d..dd2c1da54e5 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -10,6 +10,18 @@ import pytest +from snowflake.snowpark.types import IntegerType + +try: + import dateutil + + # six is the dependency of dateutil + import six + + is_dateutil_available = True +except ImportError: + is_dateutil_available = False + from snowflake.snowpark.functions import lit from snowflake.snowpark.row import Row from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils @@ -189,3 +201,209 @@ def put_and_get_file(upload_file_path, download_dir): with ThreadPoolExecutor(max_workers=10) as executor: for file_path in resources_files: executor.submit(put_and_get_file, file_path, download_dir) + + +def test_concurrent_add_packages(session): + # this is a list of packages available in snowflake anaconda. If this + # test fails due to packages not being available, please update the list + package_list = { + "graphviz", + "numpy", + "pandas", + "scipy", + "scikit-learn", + "matplotlib", + } + + try: + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [ + executor.submit(session.add_packages, package) + for package in package_list + ] + + for future in as_completed(futures): + future.result() + + assert session.get_packages() == { + package: package for package in package_list + } + finally: + session.clear_packages() + + +def test_concurrent_remove_package(session): + def remove_package(session_, package_name): + try: + session_.remove_package(package_name) + return True + except ValueError: + return False + except Exception as e: + raise e + + try: + session.add_packages("numpy") + with ThreadPoolExecutor(max_workers=10) as executor: + + futures = [ + executor.submit(remove_package, session, "numpy") for _ in range(10) + ] + success_count, failure_count = 0, 0 + for future in as_completed(futures): + if future.result(): + success_count += 1 + else: + failure_count += 1 + + # assert that only one thread was able to remove the package + assert success_count == 1 + assert failure_count == 9 + finally: + session.clear_packages() + + +@pytest.mark.skipif(not is_dateutil_available, reason="dateutil is not available") +def test_concurrent_add_import(session, resources_path): + test_files = TestFiles(resources_path) + import_files = [ + test_files.test_udf_py_file, + os.path.relpath(test_files.test_udf_py_file), + test_files.test_udf_directory, + os.path.relpath(test_files.test_udf_directory), + six.__file__, + os.path.relpath(six.__file__), + os.path.dirname(dateutil.__file__), + ] + try: + with ThreadPoolExecutor(max_workers=10) as executor: + for file in import_files: + executor.submit( + session.add_import, + file, + ) + + assert set(session.get_imports()) == { + os.path.abspath(file) for file in import_files + } + finally: + session.clear_imports() + + +def test_concurrent_remove_import(session, resources_path): + test_files = TestFiles(resources_path) + + def remove_import(session_, import_file): + try: + session_.remove_import(import_file) + return True + except KeyError: + return False + except Exception as e: + raise e + + try: + session.add_import(test_files.test_udf_py_file) + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [ + executor.submit(remove_import, session, test_files.test_udf_py_file) + for _ in range(10) + ] + + success_count, failure_count = 0, 0 + for future in as_completed(futures): + if future.result(): + success_count += 1 + 
else: + failure_count += 1 + + # assert that only one thread was able to remove the import + assert success_count == 1 + assert failure_count == 9 + finally: + session.clear_imports() + + +def test_concurrent_sp_register(session, tmpdir): + try: + session.add_packages("snowflake-snowpark-python") + + def register_and_test_sp(session_, thread_id): + prefix = Utils.random_alphanumeric_str(10) + sp_file_path = os.path.join(tmpdir, f"{prefix}_add_{thread_id}.py") + sproc_body = f""" +from snowflake.snowpark import Session +from snowflake.snowpark.functions import ( + col, + lit +) +def add_{thread_id}(session_: Session, x: int) -> int: + return ( + session_.create_dataframe([[x, ]], schema=["x"]) + .select(col("x") + lit({thread_id})) + .collect()[0][0] + ) +""" + with open(sp_file_path, "w") as f: + f.write(sproc_body) + f.flush() + + add_sp_from_file = session_.sproc.register_from_file( + sp_file_path, f"add_{thread_id}" + ) + add_sp = session_.sproc.register( + lambda sess_, x: sess_.sql(f"select {x} + {thread_id}").collect()[0][0], + return_type=IntegerType(), + input_types=[IntegerType()], + ) + + assert add_sp_from_file(1) == thread_id + 1 + assert add_sp(1) == thread_id + 1 + + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(register_and_test_sp, session, i) + finally: + session.clear_packages() + + +def test_concurrent_udf_register(session, tmpdir): + df = session.range(-5, 5).to_df("a") + + def register_and_test_udf(session_, thread_id): + prefix = Utils.random_alphanumeric_str(10) + file_path = os.path.join(tmpdir, f"{prefix}_add_{thread_id}.py") + with open(file_path, "w") as f: + func = f""" +def add_{thread_id}(x: int) -> int: + return x + {thread_id} +""" + f.write(func) + f.flush() + add_i_udf_from_file = session_.udf.register_from_file( + file_path, f"add_{thread_id}" + ) + add_i_udf = session_.udf.register( + lambda x: x + thread_id, + return_type=IntegerType(), + input_types=[IntegerType()], + ) + + Utils.check_answer( + df.select(add_i_udf(df.a), add_i_udf_from_file(df.a)), + [(thread_id + i, thread_id + i) for i in range(-5, 5)], + ) + + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(register_and_test_udf, session, i) + + +@pytest.mark.parametrize("from_file", [True, False]) +def test_concurrent_udtf_register(session, from_file): + pass + + +@pytest.mark.parametrize("from_file", [True, False]) +def test_concurrent_udaf_register(session, from_file): + pass From fd51720ae43bc8112082c576707b35eddb7fd8a4 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 13 Sep 2024 10:47:34 -0700 Subject: [PATCH 12/62] fix broken test --- src/snowflake/snowpark/_internal/udf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 2c36c8e6091..07635c8de8a 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1076,6 +1076,7 @@ def resolve_imports_and_packages( ) ) + all_urls = [] if session is not None: import_only_stage = ( unwrap_stage_location_single_quote(stage_location) @@ -1089,7 +1090,6 @@ def resolve_imports_and_packages( else session.get_session_stage(statement_params=statement_params) ) - all_urls = [] if imports: udf_level_imports = {} for udf_import in imports: From 3077853b6e89bf19b9352fbee02a90c5c59c5023 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 13 Sep 2024 13:43:28 -0700 Subject: [PATCH 13/62] add 
udtf/udaf tests --- tests/integ/test_multithreading.py | 101 +++++++++++++++++++++++++++-- 1 file changed, 95 insertions(+), 6 deletions(-) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index dd2c1da54e5..9f8b5bef293 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -6,10 +6,12 @@ import os import tempfile from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Tuple # noqa: F401 from unittest.mock import patch import pytest +from snowflake.snowpark.session import Session from snowflake.snowpark.types import IntegerType try: @@ -399,11 +401,98 @@ def add_{thread_id}(x: int) -> int: executor.submit(register_and_test_udf, session, i) -@pytest.mark.parametrize("from_file", [True, False]) -def test_concurrent_udtf_register(session, from_file): - pass +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="UDTFs is not supported in local testing mode", + run=False, +) +def test_concurrent_udtf_register(session, tmpdir): + def register_and_test_udtf(session_, thread_id): + udtf_body = f""" +from typing import List, Tuple + +class UDTFEcho: + def process( + self, + num: int, + ) -> List[Tuple[int]]: + return [(num + {thread_id},)] +""" + prefix = Utils.random_alphanumeric_str(10) + file_path = os.path.join(tmpdir, f"{prefix}_udtf_echo_{thread_id}.py") + with open(file_path, "w") as f: + f.write(udtf_body) + f.flush() + + d = {} + exec(udtf_body, {**globals(), **locals()}, d) + echo_udtf_from_file = session_.udtf.register_from_file( + file_path, "UDTFEcho", output_schema=["num"] + ) + echo_udtf = session_.udtf.register(d["UDTFEcho"], output_schema=["num"]) + + df_local = session.table_function(echo_udtf(lit(1))) + df_from_file = session.table_function(echo_udtf_from_file(lit(1))) + assert df_local.collect() == [(thread_id + 1,)] + assert df_from_file.collect() == [(thread_id + 1,)] + + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(register_and_test_udtf, session, i) + + +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="UDAFs is not supported in local testing mode", + run=False, +) +def test_concurrent_udaf_register(session: Session, tmpdir): + df = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b") + def register_and_test_udaf(session_, thread_id): + udaf_body = f""" +class OffsetSumUDAFHandler: + def __init__(self) -> None: + self._sum = 0 -@pytest.mark.parametrize("from_file", [True, False]) -def test_concurrent_udaf_register(session, from_file): - pass + @property + def aggregate_state(self): + return self._sum + + def accumulate(self, input_value): + self._sum += input_value + + def merge(self, other_sum): + self._sum += other_sum + + def finish(self): + return self._sum + {thread_id} + """ + prefix = Utils.random_alphanumeric_str(10) + file_path = os.path.join(tmpdir, f"{prefix}_udaf_{thread_id}.py") + with open(file_path, "w") as f: + f.write(udaf_body) + f.flush() + d = {} + exec(udaf_body, {**globals(), **locals()}, d) + + offset_sum_udaf_from_file = session_.udaf.register_from_file( + file_path, + "OffsetSumUDAFHandler", + return_type=IntegerType(), + input_types=[IntegerType()], + ) + offset_sum_udaf = session_.udaf.register( + d["OffsetSumUDAFHandler"], + return_type=IntegerType(), + input_types=[IntegerType()], + ) + + Utils.check_answer( + df.agg(offset_sum_udaf_from_file(df.a)), [Row(6 + thread_id)] + ) + 
Utils.check_answer(df.agg(offset_sum_udaf(df.a)), [Row(6 + thread_id)]) + + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(register_and_test_udaf, session, i) From 65c3186437f17ad29af91096cec5be29424bd6f4 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 13 Sep 2024 13:51:53 -0700 Subject: [PATCH 14/62] fix broken test --- tests/unit/test_udf_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index 12a49443539..c23755c14a3 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -5,6 +5,7 @@ import logging import os import pickle +import threading from unittest import mock import pytest @@ -249,6 +250,7 @@ def test_add_snowpark_package_to_sproc_packages_to_session(): "random_package_one": "random_package_one", "random_package_two": "random_package_two", } + fake_session._lock = threading.RLock() result = add_snowpark_package_to_sproc_packages(session=fake_session, packages=None) major, minor, patch = VERSION From 94412cf0eee0ab028d268c77adc993e3802fccdf Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 13 Sep 2024 16:21:55 -0700 Subject: [PATCH 15/62] sql_simplifier, cte_optimization, eliminate_numeric, query_compilation_stage, plan_builder --- .../snowpark/_internal/analyzer/analyzer.py | 29 ++++++++- .../_internal/analyzer/snowflake_plan.py | 63 +++++++++++++------ .../_internal/compiler/plan_compiler.py | 28 +++++---- src/snowflake/snowpark/mock/_analyzer.py | 17 ++++- .../modin/plugin/_internal/generator_utils.py | 9 +-- src/snowflake/snowpark/session.py | 54 +++++++++------- 6 files changed, 137 insertions(+), 63 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 76e91b7da92..72d016bc5c1 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -166,6 +166,13 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None: self.generated_alias_maps = {} self.subquery_plans = [] self.alias_maps_to_use: Optional[Dict[uuid.UUID, str]] = None + self._eliminate_numeric_sql_value_cast_enabled: Optional[bool] = None + + @property + def eliminate_numeric_sql_value_cast_enabled(self) -> bool: + if self._eliminate_numeric_sql_value_cast_enabled is None: + return self.session.eliminate_numeric_sql_value_cast_enabled + return self._eliminate_numeric_sql_value_cast_enabled def analyze( self, @@ -264,7 +271,7 @@ def analyze( if isinstance(expr, MultipleExpression): block_expressions = [] for expression in expr.expressions: - if self.session.eliminate_numeric_sql_value_cast_enabled: + if self.eliminate_numeric_sql_value_cast_enabled: resolved_expr = self.to_sql_try_avoid_cast( expression, df_aliased_col_name_to_real_col_name, @@ -283,7 +290,7 @@ def analyze( if isinstance(expr, InExpression): in_values = [] for expression in expr.values: - if self.session.eliminate_numeric_sql_value_cast_enabled: + if self.eliminate_numeric_sql_value_cast_enabled: in_value = self.to_sql_try_avoid_cast( expression, df_aliased_col_name_to_real_col_name, @@ -678,7 +685,7 @@ def binary_operator_extractor( df_aliased_col_name_to_real_col_name, parse_local_name=False, ) -> str: - if self.session.eliminate_numeric_sql_value_cast_enabled: + if self.eliminate_numeric_sql_value_cast_enabled: left_sql_expr = self.to_sql_try_avoid_cast( expr.left, df_aliased_col_name_to_real_col_name, parse_local_name ) @@ -760,6 
+767,18 @@ def resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: self.subquery_plans = [] self.generated_alias_maps = {} + # To ensure that the context remain unchanged during resolving the plan, we + # read these values at the beginning and reset them at the end. + self.plan_builder._cte_optimization_enabled = ( + self.session.cte_optimization_enabled + ) + self.plan_builder._query_compilation_stage_enabled = ( + self.session._query_compilation_stage_enabled + ) + self._eliminate_numeric_sql_value_cast_enabled = ( + self.session.eliminate_numeric_sql_value_cast_enabled + ) + result = self.do_resolve(logical_plan) result.add_aliases(self.generated_alias_maps) @@ -767,6 +786,10 @@ def resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: if self.subquery_plans: result = result.with_subqueries(self.subquery_plans) + self.plan_builder._cte_optimization_enabled = None + self.plan_builder._query_compilation_stage_enabled = None + self._eliminate_numeric_sql_value_cast_enabled = None + return result def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 9750deba3f9..25de5ba5718 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -312,15 +312,14 @@ def children_plan_nodes(self) -> List[Union["Selectable", "SnowflakePlan"]]: else: return [] - def replace_repeated_subquery_with_cte(self) -> "SnowflakePlan": + def replace_repeated_subquery_with_cte( + self, cte_optimization_enabled: bool, query_compilation_stage_enabled: bool + ) -> "SnowflakePlan": # parameter protection # the common subquery elimination will be applied if cte_optimization is not enabled # and the new compilation stage is not enabled. When new compilation stage is enabled, # the common subquery elimination will be done through the new plan transformation. - if ( - not self.session._cte_optimization_enabled - or self.session._query_compilation_stage_enabled - ): + if not cte_optimization_enabled or query_compilation_stage_enabled: return self # if source_plan or placeholder_query is none, it must be a leaf node, @@ -452,7 +451,7 @@ def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]): self._cumulative_node_complexity = value def __copy__(self) -> "SnowflakePlan": - if self.session._cte_optimization_enabled: + if self.cte_optimization_enabled: return SnowflakePlan( copy.deepcopy(self.queries) if self.queries else [], self.schema_query, @@ -531,6 +530,21 @@ def __init__( # on the optimized plan. During the final query generation, no schema query is needed, # this helps reduces un-necessary overhead for the describing call. 
self._skip_schema_query = skip_schema_query + # TODO: describe + self._cte_optimization_enabled: Optional[bool] = None + self._query_compilation_stage_enabled: Optional[bool] = None + + @property + def cte_optimization_enabled(self) -> bool: + if self._cte_optimization_enabled is None: + return self.session.cte_optimization_enabled + return self._cte_optimization_enabled + + @property + def query_compilation_stage_enabled(self) -> bool: + if self._query_compilation_stage_enabled is None: + return self.session._query_compilation_stage_enabled + return self._query_compilation_stage_enabled @SnowflakePlan.Decorator.wrap_exception def build( @@ -566,7 +580,7 @@ def build( placeholder_query = ( sql_generator(select_child._id) - if self.session._cte_optimization_enabled and select_child._id is not None + if self.cte_optimization_enabled and select_child._id is not None else None ) @@ -605,7 +619,7 @@ def build_binary( placeholder_query = ( sql_generator(select_left._id, select_right._id) - if self.session._cte_optimization_enabled + if self.cte_optimization_enabled and select_left._id is not None and select_right._id is not None else None @@ -636,10 +650,7 @@ def build_binary( post_actions.append(copy.copy(post_action)) referenced_ctes: Set[str] = set() - if ( - self.session.cte_optimization_enabled - and self.session._query_compilation_stage_enabled - ): + if self.cte_optimization_enabled and self.query_compilation_stage_enabled: # When the cte optimization and the new compilation stage is enabled, # the referred cte tables are propagated from left and right can have # duplicated queries if there is a common CTE block referenced by @@ -928,7 +939,9 @@ def save_as_table( column_definition_with_hidden_columns, ) - child = child.replace_repeated_subquery_with_cte() + child = child.replace_repeated_subquery_with_cte( + self.cte_optimization_enabled, self.query_compilation_stage_enabled + ) def get_create_table_as_select_plan(child: SnowflakePlan, replace, error): return self.build( @@ -1116,7 +1129,9 @@ def create_or_replace_view( if not is_sql_select_statement(child.queries[0].sql.lower().strip()): raise SnowparkClientExceptionMessages.PLAN_CREATE_VIEWS_FROM_SELECT_ONLY() - child = child.replace_repeated_subquery_with_cte() + child = child.replace_repeated_subquery_with_cte( + self.cte_optimization_enabled, self.query_compilation_stage_enabled + ) return self.build( lambda x: create_or_replace_view_statement(name, x, is_temp, comment), child, @@ -1159,7 +1174,9 @@ def create_or_replace_dynamic_table( # should never reach here raise ValueError(f"Unknown create mode: {create_mode}") # pragma: no cover - child = child.replace_repeated_subquery_with_cte() + child = child.replace_repeated_subquery_with_cte( + self.cte_optimization_enabled, self.query_compilation_stage_enabled + ) return self.build( lambda x: create_or_replace_dynamic_table_statement( name=name, @@ -1462,7 +1479,9 @@ def copy_into_location( header: bool = False, **copy_options: Optional[Any], ) -> SnowflakePlan: - query = query.replace_repeated_subquery_with_cte() + query = query.replace_repeated_subquery_with_cte( + self.cte_optimization_enabled, self.query_compilation_stage_enabled + ) return self.build( lambda x: copy_into_location( query=x, @@ -1489,7 +1508,9 @@ def update( source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: if source_data: - source_data = source_data.replace_repeated_subquery_with_cte() + source_data = source_data.replace_repeated_subquery_with_cte( + self.cte_optimization_enabled, 
self.query_compilation_stage_enabled + ) return self.build( lambda x: update_statement( table_name, @@ -1520,7 +1541,9 @@ def delete( source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: if source_data: - source_data = source_data.replace_repeated_subquery_with_cte() + source_data = source_data.replace_repeated_subquery_with_cte( + self.cte_optimization_enabled, self.query_compilation_stage_enabled + ) return self.build( lambda x: delete_statement( table_name, @@ -1549,7 +1572,9 @@ def merge( clauses: List[str], source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: - source_data = source_data.replace_repeated_subquery_with_cte() + source_data = source_data.replace_repeated_subquery_with_cte( + self.cte_optimization_enabled, self.query_compilation_stage_enabled + ) return self.build( lambda x: merge_statement(table_name, x, join_expr, clauses), source_data, diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 211b66820ec..b036e3635c0 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -48,6 +48,14 @@ class PlanCompiler: def __init__(self, plan: SnowflakePlan) -> None: self._plan = plan + current_session = self._plan.session + self.cte_optimization_enabled = current_session.cte_optimization_enabled + self.large_query_breakdown_enabled = ( + current_session.large_query_breakdown_enabled + ) + self.query_compilation_stage_enabled = ( + current_session._query_compilation_stage_enabled + ) def should_start_query_compilation(self) -> bool: """ @@ -67,11 +75,8 @@ def should_start_query_compilation(self) -> bool: return ( not isinstance(current_session._conn, MockServerConnection) and (self._plan.source_plan is not None) - and current_session._query_compilation_stage_enabled - and ( - current_session.cte_optimization_enabled - or current_session.large_query_breakdown_enabled - ) + and self.query_compilation_stage_enabled + and (self.cte_optimization_enabled or self.large_query_breakdown_enabled) ) def compile(self) -> Dict[PlanQueryType, List[Query]]: @@ -91,7 +96,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: # 3. 
apply each optimizations if needed # CTE optimization cte_start_time = time.time() - if self._plan.session.cte_optimization_enabled: + if self.cte_optimization_enabled: repeated_subquery_eliminator = RepeatedSubqueryElimination( logical_plans, query_generator ) @@ -104,7 +109,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: ] # Large query breakdown - if self._plan.session.large_query_breakdown_enabled: + if self.large_query_breakdown_enabled: large_query_breakdown = LargeQueryBreakdown( self._plan.session, query_generator, logical_plans ) @@ -126,8 +131,8 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: total_time = time.time() - start_time session = self._plan.session summary_value = { - TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled, - TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled, + TelemetryField.CTE_OPTIMIZATION_ENABLED.value: self.cte_optimization_enabled, + TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: self.large_query_breakdown_enabled, CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: ( COMPLEXITY_SCORE_LOWER_BOUND, COMPLEXITY_SCORE_UPPER_BOUND, @@ -148,8 +153,9 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: return queries else: final_plan = self._plan - if self._plan.session.cte_optimization_enabled: - final_plan = final_plan.replace_repeated_subquery_with_cte() + final_plan = final_plan.replace_repeated_subquery_with_cte( + self.cte_optimization_enabled, self.query_compilation_stage_enabled + ) return { PlanQueryType.QUERIES: final_plan.queries, PlanQueryType.POST_ACTIONS: final_plan.post_actions, diff --git a/src/snowflake/snowpark/mock/_analyzer.py b/src/snowflake/snowpark/mock/_analyzer.py index 666654917ea..6f02724e442 100644 --- a/src/snowflake/snowpark/mock/_analyzer.py +++ b/src/snowflake/snowpark/mock/_analyzer.py @@ -153,6 +153,13 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None: self.subquery_plans = [] self.alias_maps_to_use = None self._conn = self.session._conn + self._eliminate_numeric_sql_value_cast_enabled: Optional[bool] = None + + @property + def eliminate_numeric_sql_value_cast_enabled(self) -> bool: + if self._eliminate_numeric_sql_value_cast_enabled is None: + return self.session.eliminate_numeric_sql_value_cast_enabled + return self._eliminate_numeric_sql_value_cast_enabled def analyze( self, @@ -226,7 +233,7 @@ def analyze( if isinstance(expr, MultipleExpression): block_expressions = [] for expression in expr.expressions: - if self.session.eliminate_numeric_sql_value_cast_enabled: + if self.eliminate_numeric_sql_value_cast_enabled: resolved_expr = self.to_sql_try_avoid_cast( expression, expr_to_alias, @@ -245,7 +252,7 @@ def analyze( if isinstance(expr, InExpression): in_values = [] for expression in expr.values: - if self.session.eliminate_numeric_sql_value_cast_enabled: + if self.eliminate_numeric_sql_value_cast_enabled: in_value = self.to_sql_try_avoid_cast( expression, expr_to_alias, @@ -553,7 +560,7 @@ def binary_operator_extractor( expr_to_alias: Dict[str, str], parse_local_name=False, ) -> str: - if self.session.eliminate_numeric_sql_value_cast_enabled: + if self.eliminate_numeric_sql_value_cast_enabled: left_sql_expr = self.to_sql_try_avoid_cast( expr.left, expr_to_alias, parse_local_name ) @@ -618,8 +625,12 @@ def resolve( self.subquery_plans = [] if expr_to_alias is None: expr_to_alias = {} + self._eliminate_numeric_sql_value_cast_enabled = ( + 
self.session.eliminate_numeric_sql_value_cast_enabled + ) result = self.do_resolve(logical_plan, expr_to_alias) + self._eliminate_numeric_sql_value_cast_enabled = None return result def do_resolve( diff --git a/src/snowflake/snowpark/modin/plugin/_internal/generator_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/generator_utils.py index e7ca753b04b..25cc6da5bc0 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/generator_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/generator_utils.py @@ -225,10 +225,11 @@ def generate_irregular_range( # TODO: SNOW-1646883 fix invalid identifier error when sql_simplifier_enabled is True session = get_active_session() - sql_simplifier_enabled = session.sql_simplifier_enabled - session.sql_simplifier_enabled = False - num_offsets = session.range(start=0, end=periods, step=1) - session.sql_simplifier_enabled = sql_simplifier_enabled + with session._lock: + sql_simplifier_enabled = session.sql_simplifier_enabled + session.sql_simplifier_enabled = False + num_offsets = session.range(start=0, end=periods, step=1) + session.sql_simplifier_enabled = sql_simplifier_enabled sf_date_or_time_part = _offset_name_to_sf_date_or_time_part(offset.name) dt_col = builtin("DATEADD")( sf_date_or_time_part, diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index a233ce26683..1866a6d629e 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -525,11 +525,6 @@ def __init__( self._udtf_registration = UDTFRegistration(self) self._udaf_registration = UDAFRegistration(self) - self._plan_builder = ( - SnowflakePlanBuilder(self) - if isinstance(self._conn, ServerConnection) - else MockSnowflakePlanBuilder(self) - ) self._last_action_id = 0 self._last_canceled_id = 0 self._use_scoped_temp_objects: bool = ( @@ -616,6 +611,16 @@ def _analyzer(self) -> Analyzer: ) return self._thread_store.analyzer + @property + def _plan_builder(self): + if not hasattr(self._thread_store, "plan_builder"): + self._thread_store.plan_builder = ( + SnowflakePlanBuilder(self) + if isinstance(self._conn, ServerConnection) + else MockSnowflakePlanBuilder(self) + ) + return self._thread_store.plan_builder + def close(self) -> None: """Close this session.""" if is_in_stored_procedure(): @@ -722,25 +727,27 @@ def custom_package_usage_config(self) -> Dict: @sql_simplifier_enabled.setter def sql_simplifier_enabled(self, value: bool) -> None: - self._conn._telemetry_client.send_sql_simplifier_telemetry( - self._session_id, value - ) - try: - self._conn._cursor.execute( - f"alter session set {_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING} = {value}" + with self._lock: + self._conn._telemetry_client.send_sql_simplifier_telemetry( + self._session_id, value ) - except Exception: - pass - self._sql_simplifier_enabled = value + try: + self._conn._cursor.execute( + f"alter session set {_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING} = {value}" + ) + except Exception: + pass + self._sql_simplifier_enabled = value @cte_optimization_enabled.setter @experimental_parameter(version="1.15.0") def cte_optimization_enabled(self, value: bool) -> None: - if value: - self._conn._telemetry_client.send_cte_optimization_telemetry( - self._session_id - ) - self._cte_optimization_enabled = value + with self._lock: + if value: + self._conn._telemetry_client.send_cte_optimization_telemetry( + self._session_id + ) + self._cte_optimization_enabled = value @eliminate_numeric_sql_value_cast_enabled.setter @experimental_parameter(version="1.20.0") @@ 
-748,10 +755,11 @@ def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: """Set the value for eliminate_numeric_sql_value_cast_enabled""" if value in [True, False]: - self._conn._telemetry_client.send_eliminate_numeric_sql_value_cast_telemetry( - self._session_id, value - ) - self._eliminate_numeric_sql_value_cast_enabled = value + with self._lock: + self._conn._telemetry_client.send_eliminate_numeric_sql_value_cast_telemetry( + self._session_id, value + ) + self._eliminate_numeric_sql_value_cast_enabled = value else: raise ValueError( "value for eliminate_numeric_sql_value_cast_enabled must be True or False!" From 638dd09e5f2e60a63b163551263289e6745c9fd2 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Mon, 16 Sep 2024 19:42:34 -0700 Subject: [PATCH 16/62] cover more configs --- src/snowflake/snowpark/session.py | 69 ++++++++++++++++++------------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 1866a6d629e..f4f9b07b30c 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -331,38 +331,44 @@ def __init__(self, session: "Session", conf: Dict[str, Any]) -> None: "use_constant_subquery_alias": True, "flatten_select_after_filter_and_orderby": True, } # For config that's temporary/to be removed soon + self._lock = self._session._lock for key, val in conf.items(): if self.is_mutable(key): self.set(key, val) def get(self, key: str, default=None) -> Any: - if hasattr(Session, key): - return getattr(self._session, key) - if hasattr(self._session._conn._conn, key): - return getattr(self._session._conn._conn, key) - return self._conf.get(key, default) + with self._lock: + if hasattr(Session, key): + return getattr(self._session, key) + if hasattr(self._session._conn._conn, key): + return getattr(self._session._conn._conn, key) + return self._conf.get(key, default) def is_mutable(self, key: str) -> bool: - if hasattr(Session, key) and isinstance(getattr(Session, key), property): - return getattr(Session, key).fset is not None - if hasattr(SnowflakeConnection, key) and isinstance( - getattr(SnowflakeConnection, key), property - ): - return getattr(SnowflakeConnection, key).fset is not None - return key in self._conf + with self._lock: + if hasattr(Session, key) and isinstance( + getattr(Session, key), property + ): + return getattr(Session, key).fset is not None + if hasattr(SnowflakeConnection, key) and isinstance( + getattr(SnowflakeConnection, key), property + ): + return getattr(SnowflakeConnection, key).fset is not None + return key in self._conf def set(self, key: str, value: Any) -> None: - if self.is_mutable(key): - if hasattr(Session, key): - setattr(self._session, key, value) - if hasattr(SnowflakeConnection, key): - setattr(self._session._conn._conn, key, value) - if key in self._conf: - self._conf[key] = value - else: - raise AttributeError( - f'Configuration "{key}" does not exist or is not mutable in runtime' - ) + with self._lock: + if self.is_mutable(key): + if hasattr(Session, key): + setattr(self._session, key, value) + if hasattr(SnowflakeConnection, key): + setattr(self._session._conn._conn, key, value) + if key in self._conf: + self._conf[key] = value + else: + raise AttributeError( + f'Configuration "{key}" does not exist or is not mutable in runtime' + ) class SessionBuilder: """ @@ -794,10 +800,11 @@ def large_query_breakdown_enabled(self, value: bool) -> None: """ if value in [True, False]: - 
self._conn._telemetry_client.send_large_query_breakdown_telemetry( - self._session_id, value - ) - self._large_query_breakdown_enabled = value + with self._lock: + self._conn._telemetry_client.send_large_query_breakdown_telemetry( + self._session_id, value + ) + self._large_query_breakdown_enabled = value else: raise ValueError( "value for large_query_breakdown_enabled must be True or False!" @@ -806,7 +813,10 @@ def large_query_breakdown_enabled(self, value: bool) -> None: @custom_package_usage_config.setter @experimental_parameter(version="1.6.0") def custom_package_usage_config(self, config: Dict) -> None: - self._custom_package_usage_config = {k.lower(): v for k, v in config.items()} + with self._lock: + self._custom_package_usage_config = { + k.lower(): v for k, v in config.items() + } def cancel_all(self) -> None: """ @@ -1405,7 +1415,8 @@ def _get_dependency_packages( statement_params=statement_params, ) - custom_package_usage_config = self._custom_package_usage_config.copy() + with self._lock: + custom_package_usage_config = self._custom_package_usage_config.copy() unsupported_packages: List[str] = [] for package, package_info in package_dict.items(): From 7ae2c33a76d194fdf4b375eabc975a8aec3b90e6 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 11:45:45 -0700 Subject: [PATCH 17/62] fix SnowflakePlan copy --- .../snowpark/_internal/analyzer/snowflake_plan.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 25de5ba5718..c0b1f21eb9a 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -354,8 +354,14 @@ def replace_repeated_subquery_with_cte( # create CTE query final_query = create_cte_query(self, duplicate_plan_set) + with self.session._lock: + # copy depends on the cte_optimization_enabled value. We should keep it + # consistent with the current context. 
+ original_cte_optimization = self.session.cte_optimization_enabled + self.session.cte_optimization_enabled = cte_optimization_enabled + plan = copy.copy(self) + self.session.cte_optimization_enabled = original_cte_optimization # all other parts of query are unchanged, but just replace the original query - plan = copy.copy(self) plan.queries[-1].sql = final_query return plan @@ -451,7 +457,7 @@ def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]): self._cumulative_node_complexity = value def __copy__(self) -> "SnowflakePlan": - if self.cte_optimization_enabled: + if self.session.cte_optimization_enabled: return SnowflakePlan( copy.deepcopy(self.queries) if self.queries else [], self.schema_query, From 1689ebffcdb60af1a61e833e75cb448f4c418328 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 11:47:57 -0700 Subject: [PATCH 18/62] minor update --- src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index c0b1f21eb9a..52fa940a782 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -457,7 +457,7 @@ def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]): self._cumulative_node_complexity = value def __copy__(self) -> "SnowflakePlan": - if self.session.cte_optimization_enabled: + if self.session._cte_optimization_enabled: return SnowflakePlan( copy.deepcopy(self.queries) if self.queries else [], self.schema_query, From 5e8a2d27e056d1e64428ae1245ff1735a6f2ff8a Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 11:49:52 -0700 Subject: [PATCH 19/62] add description --- src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 52fa940a782..a0679b3df97 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -536,7 +536,9 @@ def __init__( # on the optimized plan. During the final query generation, no schema query is needed, # this helps reduces un-necessary overhead for the describing call. self._skip_schema_query = skip_schema_query - # TODO: describe + # Value of cte_optimization_enabled and query_compilation_stage_enabled can change during + # the resolution step. We need to cache the values at the beginning of the resolve process and use + # the cached values during the plan build process.
self._cte_optimization_enabled: Optional[bool] = None self._query_compilation_stage_enabled: Optional[bool] = None From 1c83ef232d297d366b5ae468e07760af4f0c352e Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:02:28 -0700 Subject: [PATCH 20/62] use _package_lock to protect Session._packages --- src/snowflake/snowpark/_internal/udf_utils.py | 2 +- src/snowflake/snowpark/session.py | 121 +++++++++--------- 2 files changed, 63 insertions(+), 60 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 07635c8de8a..bf4ce0af9b8 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -981,7 +981,7 @@ def add_snowpark_package_to_sproc_packages( if session is None: packages = [this_package] else: - with session._lock: + with session._package_lock: session_packages = session._packages.copy() if package_name not in session_packages: packages = list(session_packages.values()) + [this_package] diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index a233ce26683..0cda2dc5185 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -502,6 +502,11 @@ def __init__( self._conn = conn self._thread_store = threading.local() self._lock = threading.RLock() + + # this lock is used to protect _packages. We use introduce a new lock because add_packages + # launches a query to snowflake to get all version of packages available in snowflake. This + # query can be slow and prevent other threads from moving on waiting for _lock. + self._package_lock = threading.RLock() self._query_tag = None self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {} self._packages: Dict[str, str] = {} @@ -1116,7 +1121,7 @@ def get_packages(self) -> Dict[str, str]: The key of this ``dict`` is the package name and the value of this ``dict`` is the corresponding requirement specifier. """ - with self._lock: + with self._package_lock: return self._packages.copy() def add_packages( @@ -1208,7 +1213,7 @@ def remove_package(self, package: str) -> None: 0 """ package_name = pkg_resources.Requirement.parse(package).key - with self._lock: + with self._package_lock: if package_name in self._packages: self._packages.pop(package_name) else: @@ -1218,7 +1223,7 @@ def clear_packages(self) -> None: """ Clears all third-party packages of a user-defined function (UDF). """ - with self._lock: + with self._package_lock: self._packages.clear() def add_requirements(self, file_path: str) -> None: @@ -1567,25 +1572,26 @@ def _resolve_packages( if isinstance(self._conn, MockServerConnection): # in local testing we don't resolve the packages, we just return what is added errors = [] - with self._lock: - result_dict = self._packages.copy() - for pkg_name, _, pkg_req in package_dict.values(): - if pkg_name in result_dict and str(pkg_req) != result_dict[pkg_name]: - errors.append( - ValueError( - f"Cannot add package '{str(pkg_req)}' because {result_dict[pkg_name]} " - "is already added." + with self._package_lock: + result_dict = self._packages + for pkg_name, _, pkg_req in package_dict.values(): + if ( + pkg_name in result_dict + and str(pkg_req) != result_dict[pkg_name] + ): + errors.append( + ValueError( + f"Cannot add package '{str(pkg_req)}' because {result_dict[pkg_name]} " + "is already added." 
+ ) ) - ) - else: - result_dict[pkg_name] = str(pkg_req) - if len(errors) == 1: - raise errors[0] - elif len(errors) > 0: - raise RuntimeError(errors) - - with self._lock: - self._packages.update(result_dict) + else: + result_dict[pkg_name] = str(pkg_req) + if len(errors) == 1: + raise errors[0] + elif len(errors) > 0: + raise RuntimeError(errors) + return list(result_dict.values()) package_table = "information_schema.packages" @@ -1600,50 +1606,47 @@ def _resolve_packages( # 'python-dateutil': 'python-dateutil==2.8.2'} # Add to packages dictionary. Make a copy of existing packages # dictionary to avoid modifying it during intermediate steps. - with self._lock: + with self._package_lock: result_dict = ( - existing_packages_dict.copy() - if existing_packages_dict is not None - else {} + existing_packages_dict if existing_packages_dict is not None else {} ) - # Retrieve list of dependencies that need to be added - dependency_packages = self._get_dependency_packages( - package_dict, - validate_package, - package_table, - result_dict, - statement_params=statement_params, - ) - - # Add dependency packages - for package in dependency_packages: - name = package.name - version = package.specs[0][1] if package.specs else None - - if name in result_dict: - if version is not None: - added_package_has_version = "==" in result_dict[name] - if added_package_has_version and result_dict[name] != str(package): - raise ValueError( - f"Cannot add dependency package '{name}=={version}' " - f"because {result_dict[name]} is already added." - ) + # Retrieve list of dependencies that need to be added + dependency_packages = self._get_dependency_packages( + package_dict, + validate_package, + package_table, + result_dict, + statement_params=statement_params, + ) + + # Add dependency packages + for package in dependency_packages: + name = package.name + version = package.specs[0][1] if package.specs else None + + if name in result_dict: + if version is not None: + added_package_has_version = "==" in result_dict[name] + if added_package_has_version and result_dict[name] != str( + package + ): + raise ValueError( + f"Cannot add dependency package '{name}=={version}' " + f"because {result_dict[name]} is already added." 
+ ) + result_dict[name] = str(package) + else: result_dict[name] = str(package) - else: - result_dict[name] = str(package) - # Always include cloudpickle - extra_modules = [cloudpickle] - if include_pandas: - extra_modules.append("pandas") + # Always include cloudpickle + extra_modules = [cloudpickle] + if include_pandas: + extra_modules.append("pandas") - with self._lock: - if existing_packages_dict is not None: - existing_packages_dict.update(result_dict) - return list(result_dict.values()) + self._get_req_identifiers_list( - extra_modules, result_dict - ) + return list(result_dict.values()) + self._get_req_identifiers_list( + extra_modules, result_dict + ) def _upload_unsupported_packages( self, From a6497610be5072357c54b5c64bb8df95c65ef165 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:14:21 -0700 Subject: [PATCH 21/62] undo refactor --- src/snowflake/snowpark/_internal/udf_utils.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index bf4ce0af9b8..58d698556b3 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -982,9 +982,8 @@ def add_snowpark_package_to_sproc_packages( packages = [this_package] else: with session._package_lock: - session_packages = session._packages.copy() - if package_name not in session_packages: - packages = list(session_packages.values()) + [this_package] + if package_name not in session._packages: + packages = list(session._packages.values()) + [this_package] else: package_names = [p if isinstance(p, str) else p.__name__ for p in packages] if not any(p.startswith(package_name) for p in package_names): @@ -1076,20 +1075,19 @@ def resolve_imports_and_packages( ) ) - all_urls = [] if session is not None: import_only_stage = ( unwrap_stage_location_single_quote(stage_location) if stage_location else session.get_session_stage(statement_params=statement_params) ) - upload_and_import_stage = ( import_only_stage if is_permanent else session.get_session_stage(statement_params=statement_params) ) + if session: if imports: udf_level_imports = {} for udf_import in imports: @@ -1117,15 +1115,22 @@ def resolve_imports_and_packages( upload_and_import_stage, statement_params=statement_params, ) + else: + all_urls = [] + else: + all_urls = [] dest_prefix = get_udf_upload_prefix(udf_name) # Upload closure to stage if it is beyond inline closure size limit handler = inline_code = upload_file_stage_location = None - # As cloudpickle is being used, we cannot allow a custom runtime - custom_python_runtime_version_allowed = not isinstance(func, Callable) + custom_python_runtime_version_allowed = False if session is not None: if isinstance(func, Callable): + custom_python_runtime_version_allowed = ( + False # As cloudpickle is being used, we cannot allow a custom runtime + ) + # generate a random name for udf py file # and we compress it first then upload it udf_file_name_base = f"udf_py_{random_number()}" @@ -1170,6 +1175,7 @@ def resolve_imports_and_packages( upload_file_stage_location = None handler = _DEFAULT_HANDLER_NAME else: + custom_python_runtime_version_allowed = True udf_file_name = os.path.basename(func[0]) # for a compressed file, it might have multiple extensions # and we should remove all extensions @@ -1194,6 +1200,11 @@ def resolve_imports_and_packages( skip_upload_on_content_match=skip_upload_on_content_match, ) all_urls.append(upload_file_stage_location) 
+ else: + if isinstance(func, Callable): + custom_python_runtime_version_allowed = False + else: + custom_python_runtime_version_allowed = True # build imports and packages string all_imports = ",".join( From f03d6186f84b4f5b593d1920d820e218347e3e3a Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:15:50 -0700 Subject: [PATCH 22/62] undo refactor --- src/snowflake/snowpark/_internal/udf_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 58d698556b3..25921fff821 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1081,6 +1081,7 @@ def resolve_imports_and_packages( if stage_location else session.get_session_stage(statement_params=statement_params) ) + upload_and_import_stage = ( import_only_stage if is_permanent From 5f398d5f00edb42ca3f19d656f656fdc91a9c99c Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:33:07 -0700 Subject: [PATCH 23/62] fix test --- tests/unit/test_udf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index c23755c14a3..09e389d0c24 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -250,7 +250,7 @@ def test_add_snowpark_package_to_sproc_packages_to_session(): "random_package_one": "random_package_one", "random_package_two": "random_package_two", } - fake_session._lock = threading.RLock() + fake_session._package_lock = threading.RLock() result = add_snowpark_package_to_sproc_packages(session=fake_session, packages=None) major, minor, patch = VERSION From 380708778a6678f9f2bdbbf9a6e95a3b5074bf19 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:36:20 -0700 Subject: [PATCH 24/62] fix test --- tests/unit/test_server_connection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/test_server_connection.py b/tests/unit/test_server_connection.py index 72ccb6f6c42..cf10dc9d29a 100644 --- a/tests/unit/test_server_connection.py +++ b/tests/unit/test_server_connection.py @@ -119,6 +119,7 @@ def test_get_result_set_exception(mock_server_connection): fake_session._last_canceled_id = 100 fake_session._conn = mock_server_connection fake_session._cte_optimization_enabled = False + fake_session._query_compilation_stage_enabled = False fake_plan = SnowflakePlan( queries=[Query("fake query 1"), Query("fake query 2")], schema_query="fake schema query", From df3263c20c50beabe3f861c3db502ff2c8830c34 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 12 Sep 2024 15:06:48 -0700 Subject: [PATCH 25/62] add file IO tests --- tests/integ/conftest.py | 23 ++++++- .../integ/scala/test_file_operation_suite.py | 17 +---- tests/integ/test_multithreading.py | 69 ++++++++++++++++++- 3 files changed, 91 insertions(+), 18 deletions(-) diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index ec619605e66..319abb137f4 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -13,7 +13,13 @@ from snowflake.snowpark.exceptions import SnowparkSQLException from snowflake.snowpark.mock._connection import MockServerConnection from tests.parameters import CONNECTION_PARAMETERS -from tests.utils import TEST_SCHEMA, Utils, running_on_jenkins, running_on_public_ci +from tests.utils import ( + TEST_SCHEMA, + TestFiles, + Utils, + running_on_jenkins, + running_on_public_ci, +) def print_help() -> None: @@ -235,3 +241,18 @@ def temp_schema(connection, 
session, local_testing_mode) -> None: ) yield temp_schema_name cursor.execute(f"DROP SCHEMA IF EXISTS {temp_schema_name}") + + +@pytest.fixture(scope="module") +def temp_stage(session, resources_path, local_testing_mode): + tmp_stage_name = Utils.random_stage_name() + test_files = TestFiles(resources_path) + + if not local_testing_mode: + Utils.create_stage(session, tmp_stage_name, is_temporary=True) + Utils.upload_to_stage( + session, tmp_stage_name, test_files.test_file_parquet, compress=False + ) + yield tmp_stage_name + if not local_testing_mode: + Utils.drop_stage(session, tmp_stage_name) diff --git a/tests/integ/scala/test_file_operation_suite.py b/tests/integ/scala/test_file_operation_suite.py index 2dc424dde09..82a1722a729 100644 --- a/tests/integ/scala/test_file_operation_suite.py +++ b/tests/integ/scala/test_file_operation_suite.py @@ -14,7 +14,7 @@ SnowparkSQLException, SnowparkUploadFileException, ) -from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, TestFiles, Utils +from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, Utils def random_alphanumeric_name(): @@ -74,21 +74,6 @@ def path4(temp_source_directory): yield filename -@pytest.fixture(scope="module") -def temp_stage(session, resources_path, local_testing_mode): - tmp_stage_name = Utils.random_stage_name() - test_files = TestFiles(resources_path) - - if not local_testing_mode: - Utils.create_stage(session, tmp_stage_name, is_temporary=True) - Utils.upload_to_stage( - session, tmp_stage_name, test_files.test_file_parquet, compress=False - ) - yield tmp_stage_name - if not local_testing_mode: - Utils.drop_stage(session, tmp_stage_name) - - def test_put_with_one_file( session, temp_stage, path1, path2, path3, local_testing_mode ): diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 164a4b7b590..10fcc6ef70d 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -2,6 +2,9 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# +import hashlib +import os +import tempfile from concurrent.futures import ThreadPoolExecutor, as_completed from unittest.mock import patch @@ -9,7 +12,7 @@ from snowflake.snowpark.functions import lit from snowflake.snowpark.row import Row -from tests.utils import IS_IN_STORED_PROC, Utils +from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils def test_concurrent_select_queries(session): @@ -122,3 +125,67 @@ def test_action_ids_are_unique(session): action_ids.add(future.result()) assert len(action_ids) == 10 + + +@pytest.mark.parametrize("use_stream", [True, False]) +def test_file_io(session, resources_path, temp_stage, use_stream): + stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}" + stage_with_prefix = f"@{temp_stage}/{stage_prefix}/" + test_files = TestFiles(resources_path) + + resources_files = [ + test_files.test_file_csv, + test_files.test_file2_csv, + test_files.test_file_json, + test_files.test_file_csv_header, + test_files.test_file_csv_colon, + test_files.test_file_csv_quotes, + test_files.test_file_csv_special_format, + test_files.test_file_json_special_format, + test_files.test_file_csv_quotes_special, + test_files.test_concat_file1_csv, + test_files.test_concat_file2_csv, + ] + + def get_file_hash(fd): + return hashlib.md5(fd.read()).hexdigest() + + def put_and_get_file(upload_file_path, download_dir): + if use_stream: + with open(upload_file_path, "rb") as fd: + results = session.file.put_stream( + fd, stage_with_prefix, auto_compress=False, overwrite=False + ) + else: + results = session.file.put( + upload_file_path, + stage_with_prefix, + auto_compress=False, + overwrite=False, + ) + # assert file is uploaded successfully + assert len(results) == 1 + assert results[0].status == "UPLOADED" + + stage_file_name = f"{stage_with_prefix}{os.path.basename(upload_file_path)}" + if use_stream: + fd = session.file.get_stream(stage_file_name, download_dir) + with open(upload_file_path, "rb") as upload_fd: + assert get_file_hash(upload_fd) == get_file_hash(fd) + + else: + results = session.file.get(stage_file_name, download_dir) + # assert file is downloaded successfully + assert len(results) == 1 + assert results[0].status == "DOWNLOADED" + download_file_path = results[0].file + # assert two files are identical + with open(upload_file_path, "rb") as upload_fd, open( + download_file_path, "rb" + ) as download_fd: + assert get_file_hash(upload_fd) == get_file_hash(download_fd) + + with tempfile.TemporaryDirectory() as download_dir: + with ThreadPoolExecutor(max_workers=10) as executor: + for file_path in resources_files: + executor.submit(put_and_get_file, file_path, download_dir) From a737f33a1abfcdc5c1457703c5a46f98307fc9f2 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 15:24:27 -0700 Subject: [PATCH 26/62] fix test --- tests/unit/compiler/test_large_query_breakdown.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/unit/compiler/test_large_query_breakdown.py b/tests/unit/compiler/test_large_query_breakdown.py index d040ca25f49..5c9e140694f 100644 --- a/tests/unit/compiler/test_large_query_breakdown.py +++ b/tests/unit/compiler/test_large_query_breakdown.py @@ -105,7 +105,12 @@ ], ) def test_pipeline_breaker_node(mock_session, mock_analyzer, node_generator, expected): - large_query_breakdown = LargeQueryBreakdown(mock_session, mock_analyzer, []) + large_query_breakdown = LargeQueryBreakdown( + mock_session, + mock_analyzer, + [], + mock_session.large_query_breakdown_complexity_bounds, + ) node = 
node_generator(mock_analyzer) assert ( From 8ca2730c5567862e3af61bfad98974efaf84e227 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 16:38:20 -0700 Subject: [PATCH 27/62] protect complexity bounds setter with lock --- src/snowflake/snowpark/session.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 4f42799dde9..9c0f629f46a 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -867,11 +867,12 @@ def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> Non raise ValueError( f"Expecting a tuple of lower and upper bound with the lower bound less than the upper bound. Got (lower, upper) = ({value[0], value[1]})" ) - self._conn._telemetry_client.send_large_query_breakdown_update_complexity_bounds( - self._session_id, value[0], value[1] - ) + with self._lock: + self._conn._telemetry_client.send_large_query_breakdown_update_complexity_bounds( + self._session_id, value[0], value[1] + ) - self._large_query_breakdown_complexity_bounds = value + self._large_query_breakdown_complexity_bounds = value @custom_package_usage_config.setter @experimental_parameter(version="1.6.0") From 81417a382724253667372b0174e776c97375bc00 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 19 Sep 2024 16:30:24 -0700 Subject: [PATCH 28/62] add config context --- .../snowpark/_internal/analyzer/analyzer.py | 202 +++++++++--------- .../_internal/analyzer/snowflake_plan.py | 92 +++++--- tests/integ/test_multithreading.py | 46 ++++ 3 files changed, 210 insertions(+), 130 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 5022cae4d02..68314a2e6b2 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -87,6 +87,7 @@ SelectTableFunction, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import ( + ConfigContext, SnowflakePlan, SnowflakePlanBuilder, ) @@ -166,25 +167,33 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None: self.generated_alias_maps = {} self.subquery_plans = [] self.alias_maps_to_use: Optional[Dict[uuid.UUID, str]] = None - self._eliminate_numeric_sql_value_cast_enabled: Optional[bool] = None - - @property - def eliminate_numeric_sql_value_cast_enabled(self) -> bool: - if self._eliminate_numeric_sql_value_cast_enabled is None: - return self.session.eliminate_numeric_sql_value_cast_enabled - return self._eliminate_numeric_sql_value_cast_enabled + self.config_context: ConfigContext = ConfigContext(session) + # Point this plan builder to the same config context as the analyzer + self.plan_builder.config_context = self.config_context def analyze( self, expr: Union[Expression, NamedExpression], df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], parse_local_name=False, + ) -> str: + # Set the config context for analysis step + with self.config_context: + return self.do_analyze( + expr, df_aliased_col_name_to_real_col_name, parse_local_name + ) + + def do_analyze( + self, + expr: Union[Expression, NamedExpression], + df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], + parse_local_name=False, ) -> str: if isinstance(expr, GroupingSetsExpression): return grouping_set_expression( [ [ - self.analyze( + self.do_analyze( a, df_aliased_col_name_to_real_col_name, parse_local_name ) for a in arg @@ -195,23 +204,23 @@ 
def analyze( if isinstance(expr, Like): return like_expression( - self.analyze( + self.do_analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), - self.analyze( + self.do_analyze( expr.pattern, df_aliased_col_name_to_real_col_name, parse_local_name ), ) if isinstance(expr, RegExp): return regexp_expression( - self.analyze( + self.do_analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), - self.analyze( + self.do_analyze( expr.pattern, df_aliased_col_name_to_real_col_name, parse_local_name ), - self.analyze( + self.do_analyze( expr.parameters, df_aliased_col_name_to_real_col_name, parse_local_name, @@ -225,7 +234,7 @@ def analyze( expr.collation_spec.upper() if parse_local_name else expr.collation_spec ) return collate_expression( - self.analyze( + self.do_analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), collation_spec, @@ -236,7 +245,7 @@ def analyze( if parse_local_name and isinstance(field, str): field = field.upper() return subfield_expression( - self.analyze( + self.do_analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), field, @@ -246,12 +255,12 @@ def analyze( return case_when_expression( [ ( - self.analyze( + self.do_analyze( condition, df_aliased_col_name_to_real_col_name, parse_local_name, ), - self.analyze( + self.do_analyze( value, df_aliased_col_name_to_real_col_name, parse_local_name, @@ -259,7 +268,7 @@ def analyze( ) for condition, value in expr.branches ], - self.analyze( + self.do_analyze( expr.else_value, df_aliased_col_name_to_real_col_name, parse_local_name, @@ -271,14 +280,14 @@ def analyze( if isinstance(expr, MultipleExpression): block_expressions = [] for expression in expr.expressions: - if self.eliminate_numeric_sql_value_cast_enabled: + if self.config_context.eliminate_numeric_sql_value_cast_enabled: resolved_expr = self.to_sql_try_avoid_cast( expression, df_aliased_col_name_to_real_col_name, parse_local_name, ) else: - resolved_expr = self.analyze( + resolved_expr = self.do_analyze( expression, df_aliased_col_name_to_real_col_name, parse_local_name, @@ -290,14 +299,14 @@ def analyze( if isinstance(expr, InExpression): in_values = [] for expression in expr.values: - if self.eliminate_numeric_sql_value_cast_enabled: + if self.config_context.eliminate_numeric_sql_value_cast_enabled: in_value = self.to_sql_try_avoid_cast( expression, df_aliased_col_name_to_real_col_name, parse_local_name, ) else: - in_value = self.analyze( + in_value = self.do_analyze( expression, df_aliased_col_name_to_real_col_name, parse_local_name, @@ -305,7 +314,7 @@ def analyze( in_values.append(in_value) return in_expression( - self.analyze( + self.do_analyze( expr.columns, df_aliased_col_name_to_real_col_name, parse_local_name ), in_values, @@ -316,12 +325,12 @@ def analyze( if isinstance(expr, WindowExpression): return window_expression( - self.analyze( + self.do_analyze( expr.window_function, df_aliased_col_name_to_real_col_name, parse_local_name, ), - self.analyze( + self.do_analyze( expr.window_spec, df_aliased_col_name_to_real_col_name, parse_local_name, @@ -330,18 +339,18 @@ def analyze( if isinstance(expr, WindowSpecDefinition): return window_spec_expression( [ - self.analyze( + self.do_analyze( x, df_aliased_col_name_to_real_col_name, parse_local_name ) for x in expr.partition_spec ], [ - self.analyze( + self.do_analyze( x, df_aliased_col_name_to_real_col_name, parse_local_name ) for x in expr.order_spec ], - self.analyze( + self.do_analyze( expr.frame_spec, 
df_aliased_col_name_to_real_col_name, parse_local_name, @@ -422,7 +431,7 @@ def analyze( # This case is hit by df.col("*") return ",".join( [ - self.analyze(e, df_aliased_col_name_to_real_col_name) + self.do_analyze(e, df_aliased_col_name_to_real_col_name) for e in expr.expressions ] ) @@ -436,7 +445,7 @@ def analyze( return function_expression( func_name, [ - self.analyze( + self.do_analyze( x, df_aliased_col_name_to_real_col_name, parse_local_name ) for x in expr.children @@ -457,7 +466,7 @@ def analyze( return table_function_partition_spec( expr.over, [ - self.analyze( + self.do_analyze( x, df_aliased_col_name_to_real_col_name, parse_local_name ) for x in expr.partition_spec @@ -465,7 +474,7 @@ def analyze( if expr.partition_spec else [], [ - self.analyze( + self.do_analyze( x, df_aliased_col_name_to_real_col_name, parse_local_name ) for x in expr.order_spec @@ -481,7 +490,7 @@ def analyze( if isinstance(expr, SortOrder): return order_expression( - self.analyze( + self.do_analyze( expr.child, df_aliased_col_name_to_real_col_name, parse_local_name ), expr.direction.sql, @@ -494,11 +503,11 @@ def analyze( if isinstance(expr, WithinGroup): return within_group_expression( - self.analyze( + self.do_analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), [ - self.analyze(e, df_aliased_col_name_to_real_col_name) + self.do_analyze(e, df_aliased_col_name_to_real_col_name) for e in expr.order_by_cols ], ) @@ -510,42 +519,42 @@ def analyze( if isinstance(expr, InsertMergeExpression): return insert_merge_statement( - self.analyze(expr.condition, df_aliased_col_name_to_real_col_name) + self.do_analyze(expr.condition, df_aliased_col_name_to_real_col_name) if expr.condition else None, [ - self.analyze(k, df_aliased_col_name_to_real_col_name) + self.do_analyze(k, df_aliased_col_name_to_real_col_name) for k in expr.keys ], [ - self.analyze(v, df_aliased_col_name_to_real_col_name) + self.do_analyze(v, df_aliased_col_name_to_real_col_name) for v in expr.values ], ) if isinstance(expr, UpdateMergeExpression): return update_merge_statement( - self.analyze(expr.condition, df_aliased_col_name_to_real_col_name) + self.do_analyze(expr.condition, df_aliased_col_name_to_real_col_name) if expr.condition else None, { - self.analyze(k, df_aliased_col_name_to_real_col_name): self.analyze( - v, df_aliased_col_name_to_real_col_name - ) + self.do_analyze( + k, df_aliased_col_name_to_real_col_name + ): self.do_analyze(v, df_aliased_col_name_to_real_col_name) for k, v in expr.assignments.items() }, ) if isinstance(expr, DeleteMergeExpression): return delete_merge_statement( - self.analyze(expr.condition, df_aliased_col_name_to_real_col_name) + self.do_analyze(expr.condition, df_aliased_col_name_to_real_col_name) if expr.condition else None ) if isinstance(expr, ListAgg): return list_agg( - self.analyze( + self.do_analyze( expr.col, df_aliased_col_name_to_real_col_name, parse_local_name ), str_to_sql(expr.delimiter), @@ -555,7 +564,7 @@ def analyze( if isinstance(expr, ColumnSum): return column_sum( [ - self.analyze( + self.do_analyze( col, df_aliased_col_name_to_real_col_name, parse_local_name ) for col in expr.exprs @@ -565,11 +574,11 @@ def analyze( if isinstance(expr, RankRelatedFunctionExpression): return rank_related_function_expression( expr.sql, - self.analyze( + self.do_analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), expr.offset, - self.analyze( + self.do_analyze( expr.default, df_aliased_col_name_to_real_col_name, parse_local_name ) if expr.default @@ -589,7 
+598,7 @@ def table_function_expression_extractor( ) -> str: if isinstance(expr, FlattenFunction): return flatten_expression( - self.analyze( + self.do_analyze( expr.input, df_aliased_col_name_to_real_col_name, parse_local_name ), expr.path, @@ -601,7 +610,7 @@ def table_function_expression_extractor( sql = function_expression( expr.func_name, [ - self.analyze( + self.do_analyze( x, df_aliased_col_name_to_real_col_name, parse_local_name ) for x in expr.args @@ -624,7 +633,7 @@ def table_function_expression_extractor( "NamedArgumentsTableFunction, GeneratorTableFunction, or FlattenFunction." ) partition_spec_sql = ( - self.analyze(expr.partition_spec, df_aliased_col_name_to_real_col_name) + self.do_analyze(expr.partition_spec, df_aliased_col_name_to_real_col_name) if expr.partition_spec else "" ) @@ -650,13 +659,13 @@ def unary_expression_extractor( if v == expr.child.name: df_alias_dict[k] = quoted_name return alias_expression( - self.analyze( + self.do_analyze( expr.child, df_aliased_col_name_to_real_col_name, parse_local_name ), quoted_name, ) if isinstance(expr, UnresolvedAlias): - expr_str = self.analyze( + expr_str = self.do_analyze( expr.child, df_aliased_col_name_to_real_col_name, parse_local_name ) if parse_local_name: @@ -664,7 +673,7 @@ def unary_expression_extractor( return expr_str elif isinstance(expr, Cast): return cast_expression( - self.analyze( + self.do_analyze( expr.child, df_aliased_col_name_to_real_col_name, parse_local_name ), expr.to, @@ -672,7 +681,7 @@ def unary_expression_extractor( ) else: return unary_expression( - self.analyze( + self.do_analyze( expr.child, df_aliased_col_name_to_real_col_name, parse_local_name ), expr.sql_operator, @@ -685,7 +694,7 @@ def binary_operator_extractor( df_aliased_col_name_to_real_col_name, parse_local_name=False, ) -> str: - if self.eliminate_numeric_sql_value_cast_enabled: + if self.config_context.eliminate_numeric_sql_value_cast_enabled: left_sql_expr = self.to_sql_try_avoid_cast( expr.left, df_aliased_col_name_to_real_col_name, parse_local_name ) @@ -695,10 +704,10 @@ def binary_operator_extractor( parse_local_name, ) else: - left_sql_expr = self.analyze( + left_sql_expr = self.do_analyze( expr.left, df_aliased_col_name_to_real_col_name, parse_local_name ) - right_sql_expr = self.analyze( + right_sql_expr = self.do_analyze( expr.right, df_aliased_col_name_to_real_col_name, parse_local_name ) if isinstance(expr, BinaryArithmeticExpression): @@ -720,7 +729,7 @@ def binary_operator_extractor( def grouping_extractor( self, expr: GroupingSet, df_aliased_col_name_to_real_col_name ) -> str: - return self.analyze( + return self.do_analyze( FunctionExpression( expr.pretty_name.upper(), [c.child if isinstance(c, Alias) else c for c in expr.children], @@ -759,7 +768,7 @@ def to_sql_try_avoid_cast( ): return str(expr.value).upper() else: - return self.analyze( + return self.do_analyze( expr, df_aliased_col_name_to_real_col_name, parse_local_name ) @@ -767,29 +776,14 @@ def resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: self.subquery_plans = [] self.generated_alias_maps = {} - # To ensure that the context remain unchanged during resolving the plan, we - # read these values at the beginning and reset them at the end. 
- self.plan_builder._cte_optimization_enabled = ( - self.session.cte_optimization_enabled - ) - self.plan_builder._query_compilation_stage_enabled = ( - self.session._query_compilation_stage_enabled - ) - self._eliminate_numeric_sql_value_cast_enabled = ( - self.session.eliminate_numeric_sql_value_cast_enabled - ) - - result = self.do_resolve(logical_plan) + with self.config_context: + result = self.do_resolve(logical_plan) result.add_aliases(self.generated_alias_maps) if self.subquery_plans: result = result.with_subqueries(self.subquery_plans) - self.plan_builder._cte_optimization_enabled = None - self.plan_builder._query_compilation_stage_enabled = None - self._eliminate_numeric_sql_value_cast_enabled = None - return result def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: @@ -845,7 +839,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, TableFunctionJoin): return self.plan_builder.join_table_function( - self.analyze( + self.do_analyze( logical_plan.table_function, df_aliased_col_name_to_real_col_name ), resolved_children[logical_plan.children[0]], @@ -857,7 +851,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, TableFunctionRelation): return self.plan_builder.from_table_function( - self.analyze( + self.do_analyze( logical_plan.table_function, df_aliased_col_name_to_real_col_name ), logical_plan, @@ -865,7 +859,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Lateral): return self.plan_builder.lateral( - self.analyze( + self.do_analyze( logical_plan.table_function, df_aliased_col_name_to_real_col_name ), resolved_children[logical_plan.children[0]], @@ -881,7 +875,7 @@ def do_resolve_with_resolved_children( for expr in logical_plan.grouping_expressions ], [ - self.analyze(expr, df_aliased_col_name_to_real_col_name) + self.do_analyze(expr, df_aliased_col_name_to_real_col_name) for expr in logical_plan.aggregate_expressions ], resolved_children[logical_plan.child], @@ -892,7 +886,9 @@ def do_resolve_with_resolved_children( return self.plan_builder.project( list( map( - lambda x: self.analyze(x, df_aliased_col_name_to_real_col_name), + lambda x: self.do_analyze( + x, df_aliased_col_name_to_real_col_name + ), logical_plan.project_list, ) ), @@ -902,7 +898,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Filter): return self.plan_builder.filter( - self.analyze( + self.do_analyze( logical_plan.condition, df_aliased_col_name_to_real_col_name ), resolved_children[logical_plan.child], @@ -920,14 +916,14 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Join): join_condition = ( - self.analyze( + self.do_analyze( logical_plan.join_condition, df_aliased_col_name_to_real_col_name ) if logical_plan.join_condition else "" ) match_condition = ( - self.analyze( + self.do_analyze( logical_plan.match_condition, df_aliased_col_name_to_real_col_name ) if logical_plan.match_condition @@ -946,7 +942,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Sort): return self.plan_builder.sort( [ - self.analyze(x, df_aliased_col_name_to_real_col_name) + self.do_analyze(x, df_aliased_col_name_to_real_col_name) for x in logical_plan.order ], resolved_children[logical_plan.child], @@ -1010,7 +1006,7 @@ def do_resolve_with_resolved_children( mode=logical_plan.mode, table_type=logical_plan.table_type, clustering_keys=[ - self.analyze(x, df_aliased_col_name_to_real_col_name) + self.do_analyze(x, df_aliased_col_name_to_real_col_name) for x in logical_plan.clustering_exprs ], 
comment=logical_plan.comment, @@ -1064,7 +1060,7 @@ def do_resolve_with_resolved_children( ] child = self.plan_builder.project( [ - self.analyze(col, df_aliased_col_name_to_real_col_name) + self.do_analyze(col, df_aliased_col_name_to_real_col_name) for col in project_exprs ], resolved_children[logical_plan.child], @@ -1080,25 +1076,25 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan.pivot_values, List): pivot_values = [ - self.analyze(pv, df_aliased_col_name_to_real_col_name) + self.do_analyze(pv, df_aliased_col_name_to_real_col_name) for pv in logical_plan.pivot_values ] elif isinstance(logical_plan.pivot_values, ScalarSubquery): - pivot_values = self.analyze( + pivot_values = self.do_analyze( logical_plan.pivot_values, df_aliased_col_name_to_real_col_name ) else: pivot_values = None pivot_plan = self.plan_builder.pivot( - self.analyze( + self.do_analyze( logical_plan.pivot_column, df_aliased_col_name_to_real_col_name ), pivot_values, - self.analyze( + self.do_analyze( logical_plan.aggregates[0], df_aliased_col_name_to_real_col_name ), - self.analyze( + self.do_analyze( logical_plan.default_on_null, df_aliased_col_name_to_real_col_name ) if logical_plan.default_on_null @@ -1124,7 +1120,7 @@ def do_resolve_with_resolved_children( logical_plan.value_column, logical_plan.name_column, [ - self.analyze(c, df_aliased_col_name_to_real_col_name) + self.do_analyze(c, df_aliased_col_name_to_real_col_name) for c in logical_plan.column_list ], resolved_children[logical_plan.child], @@ -1166,7 +1162,7 @@ def do_resolve_with_resolved_children( refresh_mode=logical_plan.refresh_mode, initialize=logical_plan.initialize, clustering_keys=[ - self.analyze(x, df_aliased_col_name_to_real_col_name) + self.do_analyze(x, df_aliased_col_name_to_real_col_name) for x in logical_plan.clustering_exprs ], is_transient=logical_plan.is_transient, @@ -1198,7 +1194,7 @@ def do_resolve_with_resolved_children( validation_mode=logical_plan.validation_mode, column_names=logical_plan.column_names, transformations=[ - self.analyze(x, df_aliased_col_name_to_real_col_name) + self.do_analyze(x, df_aliased_col_name_to_real_col_name) for x in logical_plan.transformations ] if logical_plan.transformations @@ -1213,7 +1209,7 @@ def do_resolve_with_resolved_children( query=resolved_children[logical_plan.child], stage_location=logical_plan.stage_location, source_plan=logical_plan, - partition_by=self.analyze( + partition_by=self.do_analyze( logical_plan.partition_by, df_aliased_col_name_to_real_col_name ) if logical_plan.partition_by @@ -1229,12 +1225,12 @@ def do_resolve_with_resolved_children( return self.plan_builder.update( logical_plan.table_name, { - self.analyze(k, df_aliased_col_name_to_real_col_name): self.analyze( - v, df_aliased_col_name_to_real_col_name - ) + self.do_analyze( + k, df_aliased_col_name_to_real_col_name + ): self.do_analyze(v, df_aliased_col_name_to_real_col_name) for k, v in logical_plan.assignments.items() }, - self.analyze( + self.do_analyze( logical_plan.condition, df_aliased_col_name_to_real_col_name ) if logical_plan.condition @@ -1248,7 +1244,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, TableDelete): return self.plan_builder.delete( logical_plan.table_name, - self.analyze( + self.do_analyze( logical_plan.condition, df_aliased_col_name_to_real_col_name ) if logical_plan.condition @@ -1266,11 +1262,11 @@ def do_resolve_with_resolved_children( resolved_children[logical_plan.source] if logical_plan.source else logical_plan.source, - self.analyze( + 
self.do_analyze( logical_plan.join_expr, df_aliased_col_name_to_real_col_name ), [ - self.analyze(c, df_aliased_col_name_to_real_col_name) + self.do_analyze(c, df_aliased_col_name_to_real_col_name) for c in logical_plan.clauses ], logical_plan, diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index a0679b3df97..3ca92c006df 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -523,6 +523,49 @@ def add_aliases(self, to_add: Dict) -> None: self.expr_to_alias = {**self.expr_to_alias, **to_add} +class ConfigContext: + """Class to manage the snapshot of configuration settings in the context of plan + building, analysis and resolution. Inside an active context, the configuration will + be read from the session object and reset when the context is exited. When no active + context is present, the configuration will be read from the session object directly. + + Supported configs are stored in the `configs` attribute which is a dict of ConfigContext + """ + + def __init__(self, session) -> None: + self.session = session + self.configs = { + "cte_optimization_enabled", + "_query_compilation_stage_enabled", + "eliminate_numeric_sql_value_cast_enabled", + } + self.reset() + + def __getattr__(self, name: str) -> Any: + if name in self.configs: + return getattr(self, name) or getattr(self.session, name) + return AttributeError(f"ConfigContext has no attribute {name}") + + def __enter__(self) -> "ConfigContext": + self.create_snapshot() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.reset() + + def create_snapshot(self) -> "ConfigContext": + """Reads the configuration settings from the session object and stores them in the + context object. + """ + for name in self.configs: + setattr(self, name, getattr(self.session, name)) + return self + + def reset(self) -> None: + for name in self.configs: + setattr(self, name, None) + + class SnowflakePlanBuilder: def __init__( self, @@ -536,23 +579,7 @@ def __init__( # on the optimized plan. During the final query generation, no schema query is needed, # this helps reduces un-necessary overhead for the describing call. self._skip_schema_query = skip_schema_query - # Value of cte_optimization_enabled and query_compilation_stage_enabled can change during - # resolution step. We need to cache the value at the beginning of resolve process and use - # the cached value during the plan build process. 
- self._cte_optimization_enabled: Optional[bool] = None - self._query_compilation_stage_enabled: Optional[bool] = None - - @property - def cte_optimization_enabled(self) -> bool: - if self._cte_optimization_enabled is None: - return self.session.cte_optimization_enabled - return self._cte_optimization_enabled - - @property - def query_compilation_stage_enabled(self) -> bool: - if self._query_compilation_stage_enabled is None: - return self.session._query_compilation_stage_enabled - return self._query_compilation_stage_enabled + self.config_context: ConfigContext = ConfigContext(session) @SnowflakePlan.Decorator.wrap_exception def build( @@ -588,7 +615,8 @@ def build( placeholder_query = ( sql_generator(select_child._id) - if self.cte_optimization_enabled and select_child._id is not None + if self.config_context.cte_optimization_enabled + and select_child._id is not None else None ) @@ -627,7 +655,7 @@ def build_binary( placeholder_query = ( sql_generator(select_left._id, select_right._id) - if self.cte_optimization_enabled + if self.config_context.cte_optimization_enabled and select_left._id is not None and select_right._id is not None else None @@ -658,7 +686,10 @@ def build_binary( post_actions.append(copy.copy(post_action)) referenced_ctes: Set[str] = set() - if self.cte_optimization_enabled and self.query_compilation_stage_enabled: + if ( + self.config_context.cte_optimization_enabled + and self.config_context._query_compilation_stage_enabled + ): # When the cte optimization and the new compilation stage is enabled, # the referred cte tables are propagated from left and right can have # duplicated queries if there is a common CTE block referenced by @@ -948,7 +979,8 @@ def save_as_table( ) child = child.replace_repeated_subquery_with_cte( - self.cte_optimization_enabled, self.query_compilation_stage_enabled + self.config_context.cte_optimization_enabled, + self.config_context._query_compilation_stage_enabled, ) def get_create_table_as_select_plan(child: SnowflakePlan, replace, error): @@ -1138,7 +1170,8 @@ def create_or_replace_view( raise SnowparkClientExceptionMessages.PLAN_CREATE_VIEWS_FROM_SELECT_ONLY() child = child.replace_repeated_subquery_with_cte( - self.cte_optimization_enabled, self.query_compilation_stage_enabled + self.config_context.cte_optimization_enabled, + self.config_context._query_compilation_stage_enabled, ) return self.build( lambda x: create_or_replace_view_statement(name, x, is_temp, comment), @@ -1183,7 +1216,8 @@ def create_or_replace_dynamic_table( raise ValueError(f"Unknown create mode: {create_mode}") # pragma: no cover child = child.replace_repeated_subquery_with_cte( - self.cte_optimization_enabled, self.query_compilation_stage_enabled + self.config_context.cte_optimization_enabled, + self.config_context._query_compilation_stage_enabled, ) return self.build( lambda x: create_or_replace_dynamic_table_statement( @@ -1488,7 +1522,8 @@ def copy_into_location( **copy_options: Optional[Any], ) -> SnowflakePlan: query = query.replace_repeated_subquery_with_cte( - self.cte_optimization_enabled, self.query_compilation_stage_enabled + self.config_context.cte_optimization_enabled, + self.config_context._query_compilation_stage_enabled, ) return self.build( lambda x: copy_into_location( @@ -1517,7 +1552,8 @@ def update( ) -> SnowflakePlan: if source_data: source_data = source_data.replace_repeated_subquery_with_cte( - self.cte_optimization_enabled, self.query_compilation_stage_enabled + self.config_context.cte_optimization_enabled, + 
self.config_context._query_compilation_stage_enabled, ) return self.build( lambda x: update_statement( @@ -1550,7 +1586,8 @@ def delete( ) -> SnowflakePlan: if source_data: source_data = source_data.replace_repeated_subquery_with_cte( - self.cte_optimization_enabled, self.query_compilation_stage_enabled + self.config_context.cte_optimization_enabled, + self.config_context._query_compilation_stage_enabled, ) return self.build( lambda x: delete_statement( @@ -1581,7 +1618,8 @@ def merge( source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: source_data = source_data.replace_repeated_subquery_with_cte( - self.cte_optimization_enabled, self.query_compilation_stage_enabled + self.config_context.cte_optimization_enabled, + self.config_context._query_compilation_stage_enabled, ) return self.build( lambda x: merge_statement(table_name, x, join_expr, clauses), diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 9f8b5bef293..1efd3449b5f 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -11,6 +11,7 @@ import pytest +from snowflake.snowpark._internal.analyzer.snowflake_plan import ConfigContext from snowflake.snowpark.session import Session from snowflake.snowpark.types import IntegerType @@ -496,3 +497,48 @@ def finish(self): with ThreadPoolExecutor(max_workers=10) as executor: for i in range(10): executor.submit(register_and_test_udaf, session, i) + + +def test_config_context(session): + config_context = ConfigContext(session) + for name in config_context.configs: + assert hasattr(session, name) + + try: + original_cte_optimization = session.cte_optimization_enabled + original_eliminate_numeric_sql_value_cast_enabled = ( + session.eliminate_numeric_sql_value_cast_enabled + ) + original_query_compilation_stage_enabled = ( + session._query_compilation_stage_enabled + ) + with config_context: + session.cte_optimization_enabled = not original_cte_optimization + session.eliminate_numeric_sql_value_cast_enabled = ( + not original_eliminate_numeric_sql_value_cast_enabled + ) + session._query_compilation_stage_enabled = ( + not original_query_compilation_stage_enabled + ) + + assert config_context.cte_optimization_enabled == original_cte_optimization + assert ( + config_context.eliminate_numeric_sql_value_cast_enabled + == original_eliminate_numeric_sql_value_cast_enabled + ) + assert ( + config_context._query_compilation_stage_enabled + == original_query_compilation_stage_enabled + ) + + assert config_context.cte_optimization_enabled is None + assert config_context.eliminate_numeric_sql_value_cast_enabled is None + assert config_context._query_compilation_stage_enabled is None + finally: + session.cte_optimization_enabled = original_cte_optimization + session.eliminate_numeric_sql_value_cast_enabled = ( + original_eliminate_numeric_sql_value_cast_enabled + ) + session._query_compilation_stage_enabled = ( + original_query_compilation_stage_enabled + ) From e340567e3fd5eb8360a6174b8109ae21975337d9 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 19 Sep 2024 16:56:24 -0700 Subject: [PATCH 29/62] add tests --- .../_internal/analyzer/snowflake_plan.py | 25 ++++++++++--------- tests/integ/test_multithreading.py | 19 +++++++++++--- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 3ca92c006df..5210d244ea3 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ 
b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -525,11 +525,13 @@ def add_aliases(self, to_add: Dict) -> None: class ConfigContext: """Class to manage the snapshot of configuration settings in the context of plan - building, analysis and resolution. Inside an active context, the configuration will - be read from the session object and reset when the context is exited. When no active - context is present, the configuration will be read from the session object directly. + building, analysis and resolution. - Supported configs are stored in the `configs` attribute which is a dict of ConfigContext + Behavior: + - Inside an active context, the configuration will be read from the session + object and reset when the context is exited. + - When no active context is present, the configuration will be read from the + session object directly. """ def __init__(self, session) -> None: @@ -539,21 +541,20 @@ def __init__(self, session) -> None: "_query_compilation_stage_enabled", "eliminate_numeric_sql_value_cast_enabled", } - self.reset() def __getattr__(self, name: str) -> Any: if name in self.configs: - return getattr(self, name) or getattr(self.session, name) - return AttributeError(f"ConfigContext has no attribute {name}") + return getattr(self.session, name) + raise AttributeError(f"ConfigContext has no attribute {name}") def __enter__(self) -> "ConfigContext": - self.create_snapshot() + self._create_snapshot() return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self.reset() + self._reset() - def create_snapshot(self) -> "ConfigContext": + def _create_snapshot(self) -> "ConfigContext": """Reads the configuration settings from the session object and stores them in the context object. """ @@ -561,9 +562,9 @@ def create_snapshot(self) -> "ConfigContext": setattr(self, name, getattr(self.session, name)) return self - def reset(self) -> None: + def _reset(self) -> None: for name in self.configs: - setattr(self, name, None) + delattr(self, name) class SnowflakePlanBuilder: diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 1efd3449b5f..a0ce197364c 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -531,9 +531,22 @@ def test_config_context(session): == original_query_compilation_stage_enabled ) - assert config_context.cte_optimization_enabled is None - assert config_context.eliminate_numeric_sql_value_cast_enabled is None - assert config_context._query_compilation_stage_enabled is None + assert ( + config_context.cte_optimization_enabled == session.cte_optimization_enabled + ) + assert ( + config_context.eliminate_numeric_sql_value_cast_enabled + == session.eliminate_numeric_sql_value_cast_enabled + ) + assert ( + config_context._query_compilation_stage_enabled + == session._query_compilation_stage_enabled + ) + + with pytest.raises( + AttributeError, match="ConfigContext has no attribute no_such_config" + ): + config_context.no_such_config finally: session.cte_optimization_enabled = original_cte_optimization session.eliminate_numeric_sql_value_cast_enabled = ( From 30952bbeb70d381c8476445c163769480c680335 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 19 Sep 2024 17:02:55 -0700 Subject: [PATCH 30/62] update documentation --- src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 
5210d244ea3..88fe62c23b8 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -555,14 +555,15 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: self._reset() def _create_snapshot(self) -> "ConfigContext": - """Reads the configuration settings from the session object and stores them in the - context object. + """Reads the configuration attributes from the session object and stores them + in the context object. """ for name in self.configs: setattr(self, name, getattr(self.session, name)) return self def _reset(self) -> None: + """Removes the configuration attributes from the context object.""" for name in self.configs: delattr(self, name) From 03f25b558d43e32d358556412cac22a34deb63c0 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 19 Sep 2024 17:19:12 -0700 Subject: [PATCH 31/62] use config context in plan compiler --- .../_internal/analyzer/snowflake_plan.py | 4 +- .../_internal/compiler/plan_compiler.py | 36 +++-- src/snowflake/snowpark/mock/_analyzer.py | 153 ++++++++++-------- 3 files changed, 110 insertions(+), 83 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 88fe62c23b8..c36cc427dbd 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -537,9 +537,11 @@ class ConfigContext: def __init__(self, session) -> None: self.session = session self.configs = { - "cte_optimization_enabled", "_query_compilation_stage_enabled", + "cte_optimization_enabled", "eliminate_numeric_sql_value_cast_enabled", + "large_query_breakdown_complexity_bounds", + "large_query_breakdown_enabled", } def __getattr__(self, name: str) -> Any: diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 1c9fa3c08ba..279c4d9ed96 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -10,6 +10,7 @@ get_complexity_score, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import ( + ConfigContext, PlanQueryType, Query, SnowflakePlan, @@ -47,14 +48,7 @@ class PlanCompiler: def __init__(self, plan: SnowflakePlan) -> None: self._plan = plan current_session = self._plan.session - self.cte_optimization_enabled = current_session.cte_optimization_enabled - self.large_query_breakdown_enabled = ( - current_session.large_query_breakdown_enabled - ) - self.query_compilation_stage_enabled = ( - current_session._query_compilation_stage_enabled - ) - self.complexity_bounds = current_session.large_query_breakdown_complexity_bounds + self.config_context = ConfigContext(current_session) def should_start_query_compilation(self) -> bool: """ @@ -74,11 +68,18 @@ def should_start_query_compilation(self) -> bool: return ( not isinstance(current_session._conn, MockServerConnection) and (self._plan.source_plan is not None) - and self.query_compilation_stage_enabled - and (self.cte_optimization_enabled or self.large_query_breakdown_enabled) + and self.config_context._query_compilation_stage_enabled + and ( + self.config_context.cte_optimization_enabled + or self.config_context.large_query_breakdown_enabled + ) ) def compile(self) -> Dict[PlanQueryType, List[Query]]: + with self.config_context: + return self._compile() + + def _compile(self) -> Dict[PlanQueryType, List[Query]]: if 
self.should_start_query_compilation(): # preparation for compilation # 1. make a copy of the original plan @@ -95,7 +96,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: # 3. apply each optimizations if needed # CTE optimization cte_start_time = time.time() - if self.cte_optimization_enabled: + if self.config_context.cte_optimization_enabled: repeated_subquery_eliminator = RepeatedSubqueryElimination( logical_plans, query_generator ) @@ -108,12 +109,12 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: ] # Large query breakdown - if self.large_query_breakdown_enabled: + if self.config_context.large_query_breakdown_enabled: large_query_breakdown = LargeQueryBreakdown( self._plan.session, query_generator, logical_plans, - self.complexity_bounds, + self.config_context.large_query_breakdown_complexity_bounds, ) logical_plans = large_query_breakdown.apply() @@ -133,9 +134,9 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: total_time = time.time() - start_time session = self._plan.session summary_value = { - TelemetryField.CTE_OPTIMIZATION_ENABLED.value: self.cte_optimization_enabled, - TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: self.large_query_breakdown_enabled, - CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: self.complexity_bounds, + TelemetryField.CTE_OPTIMIZATION_ENABLED.value: self.config_context.cte_optimization_enabled, + TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: self.config_context.large_query_breakdown_enabled, + CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: self.config_context.large_query_breakdown_complexity_bounds, CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time, @@ -153,7 +154,8 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: else: final_plan = self._plan final_plan = final_plan.replace_repeated_subquery_with_cte( - self.cte_optimization_enabled, self.query_compilation_stage_enabled + self.config_context.cte_optimization_enabled, + self.config_context._query_compilation_stage_enabled, ) return { PlanQueryType.QUERIES: final_plan.queries, diff --git a/src/snowflake/snowpark/mock/_analyzer.py b/src/snowflake/snowpark/mock/_analyzer.py index 6f02724e442..710ffed4db9 100644 --- a/src/snowflake/snowpark/mock/_analyzer.py +++ b/src/snowflake/snowpark/mock/_analyzer.py @@ -75,7 +75,10 @@ GroupingSet, GroupingSetsExpression, ) -from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan +from snowflake.snowpark._internal.analyzer.snowflake_plan import ( + ConfigContext, + SnowflakePlan, +) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( CopyIntoLocationNode, CopyIntoTableNode, @@ -153,13 +156,7 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None: self.subquery_plans = [] self.alias_maps_to_use = None self._conn = self.session._conn - self._eliminate_numeric_sql_value_cast_enabled: Optional[bool] = None - - @property - def eliminate_numeric_sql_value_cast_enabled(self) -> bool: - if self._eliminate_numeric_sql_value_cast_enabled is None: - return self.session.eliminate_numeric_sql_value_cast_enabled - return self._eliminate_numeric_sql_value_cast_enabled + self.config_context = ConfigContext(session) def analyze( self, @@ -167,6 +164,21 @@ def analyze( expr_to_alias: Optional[Dict[str, str]] = None, parse_local_name=False, keep_alias=True, + ) -> 
Union[str, List[str]]: + with self.config_context: + return self.do_analyze( + expr, + expr_to_alias, + parse_local_name, + keep_alias, + ) + + def do_analyze( + self, + expr: Union[Expression, NamedExpression], + expr_to_alias: Optional[Dict[str, str]] = None, + parse_local_name=False, + keep_alias=True, ) -> Union[str, List[str]]: """ Args: @@ -187,15 +199,15 @@ def analyze( if isinstance(expr, Like): return like_expression( - self.analyze(expr.expr, expr_to_alias, parse_local_name), - self.analyze(expr.pattern, expr_to_alias, parse_local_name), + self.do_analyze(expr.expr, expr_to_alias, parse_local_name), + self.do_analyze(expr.pattern, expr_to_alias, parse_local_name), ) if isinstance(expr, RegExp): return regexp_expression( - self.analyze(expr.expr, expr_to_alias, parse_local_name), - self.analyze(expr.pattern, expr_to_alias, parse_local_name), - self.analyze(expr.parameters, expr_to_alias, parse_local_name) + self.do_analyze(expr.expr, expr_to_alias, parse_local_name), + self.do_analyze(expr.pattern, expr_to_alias, parse_local_name), + self.do_analyze(expr.parameters, expr_to_alias, parse_local_name) if expr.parameters is not None else None, ) @@ -205,7 +217,8 @@ def analyze( expr.collation_spec.upper() if parse_local_name else expr.collation_spec ) return collate_expression( - self.analyze(expr.expr, expr_to_alias, parse_local_name), collation_spec + self.do_analyze(expr.expr, expr_to_alias, parse_local_name), + collation_spec, ) if isinstance(expr, (SubfieldString, SubfieldInt)): @@ -213,19 +226,19 @@ def analyze( if parse_local_name and isinstance(field, str): field = field.upper() return subfield_expression( - self.analyze(expr.expr, expr_to_alias, parse_local_name), field + self.do_analyze(expr.expr, expr_to_alias, parse_local_name), field ) if isinstance(expr, CaseWhen): return case_when_expression( [ ( - self.analyze(condition, expr_to_alias, parse_local_name), - self.analyze(value, expr_to_alias, parse_local_name), + self.do_analyze(condition, expr_to_alias, parse_local_name), + self.do_analyze(value, expr_to_alias, parse_local_name), ) for condition, value in expr.branches ], - self.analyze(expr.else_value, expr_to_alias, parse_local_name) + self.do_analyze(expr.else_value, expr_to_alias, parse_local_name) if expr.else_value else "NULL", ) @@ -233,14 +246,14 @@ def analyze( if isinstance(expr, MultipleExpression): block_expressions = [] for expression in expr.expressions: - if self.eliminate_numeric_sql_value_cast_enabled: + if self.config_context.eliminate_numeric_sql_value_cast_enabled: resolved_expr = self.to_sql_try_avoid_cast( expression, expr_to_alias, parse_local_name, ) else: - resolved_expr = self.analyze( + resolved_expr = self.do_analyze( expression, expr_to_alias, parse_local_name, @@ -252,14 +265,14 @@ def analyze( if isinstance(expr, InExpression): in_values = [] for expression in expr.values: - if self.eliminate_numeric_sql_value_cast_enabled: + if self.config_context.eliminate_numeric_sql_value_cast_enabled: in_value = self.to_sql_try_avoid_cast( expression, expr_to_alias, parse_local_name, ) else: - in_value = self.analyze( + in_value = self.do_analyze( expression, expr_to_alias, parse_local_name, @@ -267,7 +280,7 @@ def analyze( in_values.append(in_value) return in_expression( - self.analyze(expr.columns, expr_to_alias, parse_local_name), + self.do_analyze(expr.columns, expr_to_alias, parse_local_name), in_values, ) @@ -279,11 +292,11 @@ def analyze( if isinstance(expr, WindowExpression): return window_expression( - self.analyze( + self.do_analyze( 
expr.window_function, parse_local_name=parse_local_name, ), - self.analyze( + self.do_analyze( expr.window_spec, parse_local_name=parse_local_name, ), @@ -292,14 +305,14 @@ def analyze( if isinstance(expr, WindowSpecDefinition): return window_spec_expression( [ - self.analyze(x, parse_local_name=parse_local_name) + self.do_analyze(x, parse_local_name=parse_local_name) for x in expr.partition_spec ], [ - self.analyze(x, parse_local_name=parse_local_name) + self.do_analyze(x, parse_local_name=parse_local_name) for x in expr.order_spec ], - self.analyze( + self.do_analyze( expr.frame_spec, parse_local_name=parse_local_name, ), @@ -355,7 +368,7 @@ def analyze( if not expr.expressions: return "*" else: - return [self.analyze(e, expr_to_alias) for e in expr.expressions] + return [self.do_analyze(e, expr_to_alias) for e in expr.expressions] if isinstance(expr, SnowflakeUDF): if expr.api_call_source is not None: @@ -366,7 +379,7 @@ def analyze( return function_expression( func_name, [ - self.analyze(x, expr_to_alias, parse_local_name) + self.do_analyze(x, expr_to_alias, parse_local_name) for x in expr.children ], False, @@ -379,13 +392,13 @@ def analyze( return table_function_partition_spec( expr.over, [ - self.analyze(x, expr_to_alias, parse_local_name) + self.do_analyze(x, expr_to_alias, parse_local_name) for x in expr.partition_spec ] if expr.partition_spec else [], [ - self.analyze(x, expr_to_alias, parse_local_name) + self.do_analyze(x, expr_to_alias, parse_local_name) for x in expr.order_spec ] if expr.order_spec @@ -402,7 +415,7 @@ def analyze( if isinstance(expr, SortOrder): return order_expression( - self.analyze(expr.child, expr_to_alias, parse_local_name), + self.do_analyze(expr.child, expr_to_alias, parse_local_name), expr.direction.sql, expr.null_ordering.sql, ) @@ -413,8 +426,8 @@ def analyze( if isinstance(expr, WithinGroup): return within_group_expression( - self.analyze(expr.expr, expr_to_alias, parse_local_name), - [self.analyze(e, expr_to_alias) for e in expr.order_by_cols], + self.do_analyze(expr.expr, expr_to_alias, parse_local_name), + [self.do_analyze(e, expr_to_alias) for e in expr.order_by_cols], ) if isinstance(expr, BinaryExpression): @@ -426,28 +439,34 @@ def analyze( if isinstance(expr, InsertMergeExpression): return insert_merge_statement( - self.analyze(expr.condition, expr_to_alias) if expr.condition else None, - [self.analyze(k, expr_to_alias) for k in expr.keys], - [self.analyze(v, expr_to_alias) for v in expr.values], + self.do_analyze(expr.condition, expr_to_alias) + if expr.condition + else None, + [self.do_analyze(k, expr_to_alias) for k in expr.keys], + [self.do_analyze(v, expr_to_alias) for v in expr.values], ) if isinstance(expr, UpdateMergeExpression): return update_merge_statement( - self.analyze(expr.condition, expr_to_alias) if expr.condition else None, + self.do_analyze(expr.condition, expr_to_alias) + if expr.condition + else None, { - self.analyze(k, expr_to_alias): self.analyze(v, expr_to_alias) + self.do_analyze(k, expr_to_alias): self.do_analyze(v, expr_to_alias) for k, v in expr.assignments.items() }, ) if isinstance(expr, DeleteMergeExpression): return delete_merge_statement( - self.analyze(expr.condition, expr_to_alias) if expr.condition else None + self.do_analyze(expr.condition, expr_to_alias) + if expr.condition + else None ) if isinstance(expr, ListAgg): return list_agg( - self.analyze(expr.col, expr_to_alias, parse_local_name), + self.do_analyze(expr.col, expr_to_alias, parse_local_name), str_to_sql(expr.delimiter), expr.is_distinct, ) @@ 
-455,7 +474,7 @@ def analyze( if isinstance(expr, ColumnSum): return column_sum( [ - self.analyze(col, expr_to_alias, parse_local_name) + self.do_analyze(col, expr_to_alias, parse_local_name) for col in expr.exprs ] ) @@ -463,9 +482,9 @@ def analyze( if isinstance(expr, RankRelatedFunctionExpression): return rank_related_function_expression( expr.sql, - self.analyze(expr.expr, expr_to_alias, parse_local_name), + self.do_analyze(expr.expr, expr_to_alias, parse_local_name), expr.offset, - self.analyze(expr.default, expr_to_alias, parse_local_name) + self.do_analyze(expr.default, expr_to_alias, parse_local_name) if expr.default else None, expr.ignore_nulls, @@ -483,7 +502,7 @@ def table_function_expression_extractor( ) -> str: if isinstance(expr, FlattenFunction): return flatten_expression( - self.analyze(expr.input, expr_to_alias, parse_local_name), + self.do_analyze(expr.input, expr_to_alias, parse_local_name), expr.path, expr.outer, expr.recursive, @@ -492,14 +511,17 @@ def table_function_expression_extractor( elif isinstance(expr, PosArgumentsTableFunction): sql = function_expression( expr.func_name, - [self.analyze(x, expr_to_alias, parse_local_name) for x in expr.args], + [ + self.do_analyze(x, expr_to_alias, parse_local_name) + for x in expr.args + ], False, ) elif isinstance(expr, (NamedArgumentsTableFunction, GeneratorTableFunction)): sql = named_arguments_function( expr.func_name, { - key: self.analyze(value, expr_to_alias, parse_local_name) + key: self.do_analyze(value, expr_to_alias, parse_local_name) for key, value in expr.args.items() }, ) @@ -509,7 +531,7 @@ def table_function_expression_extractor( "NamedArgumentsTableFunction, GeneratorTableFunction, or FlattenFunction." ) partition_spec_sql = ( - self.analyze(expr.partition_spec, expr_to_alias) + self.do_analyze(expr.partition_spec, expr_to_alias) if expr.partition_spec else "" ) @@ -530,26 +552,27 @@ def unary_expression_extractor( if v == expr.child.name: expr_to_alias[k] = quoted_name alias_exp = alias_expression( - self.analyze(expr.child, expr_to_alias, parse_local_name), quoted_name + self.do_analyze(expr.child, expr_to_alias, parse_local_name), + quoted_name, ) expr_str = alias_exp if keep_alias else expr.name or keep_alias expr_str = expr_str.upper() if parse_local_name else expr_str return expr_str if isinstance(expr, UnresolvedAlias): - expr_str = self.analyze(expr.child, expr_to_alias, parse_local_name) + expr_str = self.do_analyze(expr.child, expr_to_alias, parse_local_name) if parse_local_name: expr_str = expr_str.upper() return quote_name(expr_str.strip()) elif isinstance(expr, Cast): return cast_expression( - self.analyze(expr.child, expr_to_alias, parse_local_name), + self.do_analyze(expr.child, expr_to_alias, parse_local_name), expr.to, expr.try_, ) else: return unary_expression( - self.analyze(expr.child, expr_to_alias, parse_local_name), + self.do_analyze(expr.child, expr_to_alias, parse_local_name), expr.sql_operator, expr.operator_first, ) @@ -560,7 +583,7 @@ def binary_operator_extractor( expr_to_alias: Dict[str, str], parse_local_name=False, ) -> str: - if self.eliminate_numeric_sql_value_cast_enabled: + if self.config_context.eliminate_numeric_sql_value_cast_enabled: left_sql_expr = self.to_sql_try_avoid_cast( expr.left, expr_to_alias, parse_local_name ) @@ -570,8 +593,10 @@ def binary_operator_extractor( parse_local_name, ) else: - left_sql_expr = self.analyze(expr.left, expr_to_alias, parse_local_name) - right_sql_expr = self.analyze(expr.right, expr_to_alias, parse_local_name) + left_sql_expr = 
self.do_analyze(expr.left, expr_to_alias, parse_local_name) + right_sql_expr = self.do_analyze( + expr.right, expr_to_alias, parse_local_name + ) operator = expr.sql_operator.lower() if isinstance(expr, BinaryArithmeticExpression): @@ -593,7 +618,7 @@ def binary_operator_extractor( def grouping_extractor( self, expr: GroupingSet, expr_to_alias: Dict[str, str] ) -> str: - return self.analyze( + return self.do_analyze( FunctionExpression( expr.pretty_name.upper(), [c.child if isinstance(c, Alias) else c for c in expr.children], @@ -617,7 +642,7 @@ def to_sql_try_avoid_cast( if isinstance(expr, Literal) and isinstance(expr.datatype, _NumericType): return numeric_to_sql_without_cast(expr.value, expr.datatype) else: - return self.analyze(expr, expr_to_alias, parse_local_name) + return self.do_analyze(expr, expr_to_alias, parse_local_name) def resolve( self, logical_plan: LogicalPlan, expr_to_alias: Optional[Dict[str, str]] = None @@ -625,12 +650,10 @@ def resolve( self.subquery_plans = [] if expr_to_alias is None: expr_to_alias = {} - self._eliminate_numeric_sql_value_cast_enabled = ( - self.session.eliminate_numeric_sql_value_cast_enabled - ) - result = self.do_resolve(logical_plan, expr_to_alias) - self._eliminate_numeric_sql_value_cast_enabled = None + with self.config_context: + result = self.do_resolve(logical_plan, expr_to_alias) + return result def do_resolve( @@ -711,7 +734,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Sort): return self.plan_builder.sort( - list(map(self.analyze, logical_plan.order)), + list(map(self.do_analyze, logical_plan.order)), resolved_children[logical_plan.child], logical_plan, ) @@ -779,7 +802,7 @@ def do_resolve_with_resolved_children( query=resolved_children[logical_plan.child], stage_location=logical_plan.stage_location, source_plan=logical_plan, - partition_by=self.analyze(logical_plan.partition_by, expr_to_alias) + partition_by=self.do_analyze(logical_plan.partition_by, expr_to_alias) if logical_plan.partition_by else None, file_format_name=logical_plan.file_format_name, From 6deb4029fa3a3608cb64d4acfb6bc7405c5e6993 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 19 Sep 2024 17:20:57 -0700 Subject: [PATCH 32/62] add comments --- tests/integ/test_multithreading.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index a0ce197364c..de7ef7c5099 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -501,6 +501,8 @@ def finish(self): def test_config_context(session): config_context = ConfigContext(session) + + # Check if all context configs are present in the session for name in config_context.configs: assert hasattr(session, name) @@ -513,6 +515,7 @@ def test_config_context(session): session._query_compilation_stage_enabled ) with config_context: + # Active context session.cte_optimization_enabled = not original_cte_optimization session.eliminate_numeric_sql_value_cast_enabled = ( not original_eliminate_numeric_sql_value_cast_enabled @@ -531,6 +534,7 @@ def test_config_context(session): == original_query_compilation_stage_enabled ) + # Context is deactivated assert ( config_context.cte_optimization_enabled == session.cte_optimization_enabled ) From 8e1dfe06b276dd2a79e8919a703e9069ad4f957b Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 19 Sep 2024 17:23:58 -0700 Subject: [PATCH 33/62] minor refactor --- src/snowflake/snowpark/_internal/compiler/plan_compiler.py | 3 +-- 1 file changed, 1 insertion(+), 2 
deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 279c4d9ed96..171720101d4 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -47,8 +47,7 @@ class PlanCompiler: def __init__(self, plan: SnowflakePlan) -> None: self._plan = plan - current_session = self._plan.session - self.config_context = ConfigContext(current_session) + self.config_context = ConfigContext(self._plan.session) def should_start_query_compilation(self) -> bool: """ From 10bfeb4e6bf1ad259cd05f805cf63564d329b787 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 19 Sep 2024 17:41:23 -0700 Subject: [PATCH 34/62] fix test --- src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py | 5 ++++- src/snowflake/snowpark/mock/_analyzer.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index c36cc427dbd..987d2091582 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -567,7 +567,10 @@ def _create_snapshot(self) -> "ConfigContext": def _reset(self) -> None: """Removes the configuration attributes from the context object.""" for name in self.configs: - delattr(self, name) + try: + delattr(self, name) + except AttributeError: + pass class SnowflakePlanBuilder: diff --git a/src/snowflake/snowpark/mock/_analyzer.py b/src/snowflake/snowpark/mock/_analyzer.py index 710ffed4db9..bd85953830f 100644 --- a/src/snowflake/snowpark/mock/_analyzer.py +++ b/src/snowflake/snowpark/mock/_analyzer.py @@ -157,6 +157,8 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None: self.alias_maps_to_use = None self._conn = self.session._conn self.config_context = ConfigContext(session) + # Point this plan builder to the same config context as the analyzer + self.plan_builder.config_context = self.config_context def analyze( self, From 879940a03090fb4a02e7778608a96cc9b098523f Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 19 Sep 2024 17:47:11 -0700 Subject: [PATCH 35/62] update documentation --- .../snowpark/_internal/analyzer/snowflake_plan.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 987d2091582..ca6b97f57e7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -524,12 +524,12 @@ def add_aliases(self, to_add: Dict) -> None: class ConfigContext: - """Class to manage the snapshot of configuration settings in the context of plan - building, analysis and resolution. + """Class to manage reading of configuration values from session in the context of + plan building, analysis and resolution. Behavior: - - Inside an active context, the configuration will be read from the session - object and reset when the context is exited. + - Inside an active context, the configuration will be read based on snapshot taken + at the context creation stage and reset when the context is exited. - When no active context is present, the configuration will be read from the session object directly. 
""" From 5aad2d93d2137ea5a0432e28d8dac614f09c62d6 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 25 Sep 2024 15:24:43 -0700 Subject: [PATCH 36/62] simplify context config --- .../snowpark/_internal/analyzer/analyzer.py | 385 +++++++++++++----- .../_internal/analyzer/config_context.py | 37 ++ .../_internal/analyzer/snowflake_plan.py | 97 ++--- .../_internal/compiler/plan_compiler.py | 9 +- .../_internal/compiler/query_generator.py | 14 +- src/snowflake/snowpark/mock/_analyzer.py | 251 ++++++++---- tests/integ/test_multithreading.py | 47 +-- 7 files changed, 560 insertions(+), 280 deletions(-) create mode 100644 src/snowflake/snowpark/_internal/analyzer/config_context.py diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 68314a2e6b2..4de2700f2d0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -47,6 +47,7 @@ BinaryExpression, ) from snowflake.snowpark._internal.analyzer.binary_plan_node import Join, SetOperation +from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.datatype_mapper import ( numeric_to_sql_without_cast, str_to_sql, @@ -87,7 +88,6 @@ SelectTableFunction, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import ( - ConfigContext, SnowflakePlan, SnowflakePlanBuilder, ) @@ -167,9 +167,6 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None: self.generated_alias_maps = {} self.subquery_plans = [] self.alias_maps_to_use: Optional[Dict[uuid.UUID, str]] = None - self.config_context: ConfigContext = ConfigContext(session) - # Point this plan builder to the same config context as the analyzer - self.plan_builder.config_context = self.config_context def analyze( self, @@ -178,15 +175,16 @@ def analyze( parse_local_name=False, ) -> str: # Set the config context for analysis step - with self.config_context: - return self.do_analyze( - expr, df_aliased_col_name_to_real_col_name, parse_local_name - ) + config_context = ConfigContext(self.session) + return self.do_analyze( + expr, df_aliased_col_name_to_real_col_name, config_context, parse_local_name + ) def do_analyze( self, expr: Union[Expression, NamedExpression], df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], + config_context: ConfigContext, parse_local_name=False, ) -> str: if isinstance(expr, GroupingSetsExpression): @@ -194,7 +192,10 @@ def do_analyze( [ [ self.do_analyze( - a, df_aliased_col_name_to_real_col_name, parse_local_name + a, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) for a in arg ] @@ -205,24 +206,37 @@ def do_analyze( if isinstance(expr, Like): return like_expression( self.do_analyze( - expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name + expr.expr, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), self.do_analyze( - expr.pattern, df_aliased_col_name_to_real_col_name, parse_local_name + expr.pattern, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), ) if isinstance(expr, RegExp): return regexp_expression( self.do_analyze( - expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name + expr.expr, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), self.do_analyze( - expr.pattern, df_aliased_col_name_to_real_col_name, parse_local_name + expr.pattern, + df_aliased_col_name_to_real_col_name, + 
config_context, + parse_local_name, ), self.do_analyze( expr.parameters, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, ) if expr.parameters is not None @@ -235,7 +249,10 @@ def do_analyze( ) return collate_expression( self.do_analyze( - expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name + expr.expr, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), collation_spec, ) @@ -246,7 +263,10 @@ def do_analyze( field = field.upper() return subfield_expression( self.do_analyze( - expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name + expr.expr, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), field, ) @@ -258,11 +278,13 @@ def do_analyze( self.do_analyze( condition, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, ), self.do_analyze( value, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, ), ) @@ -271,6 +293,7 @@ def do_analyze( self.do_analyze( expr.else_value, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, ) if expr.else_value @@ -280,16 +303,18 @@ def do_analyze( if isinstance(expr, MultipleExpression): block_expressions = [] for expression in expr.expressions: - if self.config_context.eliminate_numeric_sql_value_cast_enabled: + if config_context.eliminate_numeric_sql_value_cast_enabled: resolved_expr = self.to_sql_try_avoid_cast( expression, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, ) else: resolved_expr = self.do_analyze( expression, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, ) @@ -299,40 +324,49 @@ def do_analyze( if isinstance(expr, InExpression): in_values = [] for expression in expr.values: - if self.config_context.eliminate_numeric_sql_value_cast_enabled: + if config_context.eliminate_numeric_sql_value_cast_enabled: in_value = self.to_sql_try_avoid_cast( expression, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, ) else: in_value = self.do_analyze( expression, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, ) in_values.append(in_value) return in_expression( self.do_analyze( - expr.columns, df_aliased_col_name_to_real_col_name, parse_local_name + expr.columns, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), in_values, ) if isinstance(expr, GroupingSet): - return self.grouping_extractor(expr, df_aliased_col_name_to_real_col_name) + return self.grouping_extractor( + expr, df_aliased_col_name_to_real_col_name, config_context + ) if isinstance(expr, WindowExpression): return window_expression( self.do_analyze( expr.window_function, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, ), self.do_analyze( expr.window_spec, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, ), ) @@ -340,19 +374,26 @@ def do_analyze( return window_spec_expression( [ self.do_analyze( - x, df_aliased_col_name_to_real_col_name, parse_local_name + x, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) for x in expr.partition_spec ], [ self.do_analyze( - x, df_aliased_col_name_to_real_col_name, parse_local_name + x, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) for x in expr.order_spec ], self.do_analyze( expr.frame_spec, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, ), ) @@ -361,12 +402,12 @@ def do_analyze( expr.frame_type.sql, 
self.window_frame_boundary( self.to_sql_try_avoid_cast( - expr.lower, df_aliased_col_name_to_real_col_name + expr.lower, df_aliased_col_name_to_real_col_name, config_context ) ), self.window_frame_boundary( self.to_sql_try_avoid_cast( - expr.upper, df_aliased_col_name_to_real_col_name + expr.upper, df_aliased_col_name_to_real_col_name, config_context ) ), ) @@ -410,7 +451,9 @@ def do_analyze( return function_expression( func_name, [ - self.to_sql_try_avoid_cast(c, df_aliased_col_name_to_real_col_name) + self.to_sql_try_avoid_cast( + c, df_aliased_col_name_to_real_col_name, config_context + ) for c in expr.children ], expr.is_distinct, @@ -431,7 +474,9 @@ def do_analyze( # This case is hit by df.col("*") return ",".join( [ - self.do_analyze(e, df_aliased_col_name_to_real_col_name) + self.do_analyze( + e, df_aliased_col_name_to_real_col_name, config_context + ) for e in expr.expressions ] ) @@ -446,7 +491,10 @@ def do_analyze( func_name, [ self.do_analyze( - x, df_aliased_col_name_to_real_col_name, parse_local_name + x, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) for x in expr.children ], @@ -459,7 +507,7 @@ def do_analyze( expr.api_call_source, TelemetryField.FUNC_CAT_USAGE.value ) return self.table_function_expression_extractor( - expr, df_aliased_col_name_to_real_col_name + expr, df_aliased_col_name_to_real_col_name, config_context ) if isinstance(expr, TableFunctionPartitionSpecDefinition): @@ -467,7 +515,10 @@ def do_analyze( expr.over, [ self.do_analyze( - x, df_aliased_col_name_to_real_col_name, parse_local_name + x, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) for x in expr.partition_spec ] @@ -475,7 +526,10 @@ def do_analyze( else [], [ self.do_analyze( - x, df_aliased_col_name_to_real_col_name, parse_local_name + x, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) for x in expr.order_spec ] @@ -485,13 +539,19 @@ def do_analyze( if isinstance(expr, UnaryExpression): return self.unary_expression_extractor( - expr, df_aliased_col_name_to_real_col_name, parse_local_name + expr, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) if isinstance(expr, SortOrder): return order_expression( self.do_analyze( - expr.child, df_aliased_col_name_to_real_col_name, parse_local_name + expr.child, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), expr.direction.sql, expr.null_ordering.sql, @@ -504,50 +564,70 @@ def do_analyze( if isinstance(expr, WithinGroup): return within_group_expression( self.do_analyze( - expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name + expr.expr, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), [ - self.do_analyze(e, df_aliased_col_name_to_real_col_name) + self.do_analyze( + e, df_aliased_col_name_to_real_col_name, config_context + ) for e in expr.order_by_cols ], ) if isinstance(expr, BinaryExpression): return self.binary_operator_extractor( - expr, df_aliased_col_name_to_real_col_name, parse_local_name + expr, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) if isinstance(expr, InsertMergeExpression): return insert_merge_statement( - self.do_analyze(expr.condition, df_aliased_col_name_to_real_col_name) + self.do_analyze( + expr.condition, df_aliased_col_name_to_real_col_name, config_context + ) if expr.condition else None, [ - self.do_analyze(k, df_aliased_col_name_to_real_col_name) + self.do_analyze( + k, 
df_aliased_col_name_to_real_col_name, config_context + ) for k in expr.keys ], [ - self.do_analyze(v, df_aliased_col_name_to_real_col_name) + self.do_analyze( + v, df_aliased_col_name_to_real_col_name, config_context + ) for v in expr.values ], ) if isinstance(expr, UpdateMergeExpression): return update_merge_statement( - self.do_analyze(expr.condition, df_aliased_col_name_to_real_col_name) + self.do_analyze( + expr.condition, df_aliased_col_name_to_real_col_name, config_context + ) if expr.condition else None, { self.do_analyze( - k, df_aliased_col_name_to_real_col_name - ): self.do_analyze(v, df_aliased_col_name_to_real_col_name) + k, df_aliased_col_name_to_real_col_name, config_context + ): self.do_analyze( + v, df_aliased_col_name_to_real_col_name, config_context + ) for k, v in expr.assignments.items() }, ) if isinstance(expr, DeleteMergeExpression): return delete_merge_statement( - self.do_analyze(expr.condition, df_aliased_col_name_to_real_col_name) + self.do_analyze( + expr.condition, df_aliased_col_name_to_real_col_name, config_context + ) if expr.condition else None ) @@ -555,7 +635,10 @@ def do_analyze( if isinstance(expr, ListAgg): return list_agg( self.do_analyze( - expr.col, df_aliased_col_name_to_real_col_name, parse_local_name + expr.col, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), str_to_sql(expr.delimiter), expr.is_distinct, @@ -565,7 +648,10 @@ def do_analyze( return column_sum( [ self.do_analyze( - col, df_aliased_col_name_to_real_col_name, parse_local_name + col, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) for col in expr.exprs ] @@ -575,11 +661,17 @@ def do_analyze( return rank_related_function_expression( expr.sql, self.do_analyze( - expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name + expr.expr, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), expr.offset, self.do_analyze( - expr.default, df_aliased_col_name_to_real_col_name, parse_local_name + expr.default, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) if expr.default else None, @@ -594,12 +686,16 @@ def table_function_expression_extractor( self, expr: TableFunctionExpression, df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], + config_context: ConfigContext, parse_local_name=False, ) -> str: if isinstance(expr, FlattenFunction): return flatten_expression( self.do_analyze( - expr.input, df_aliased_col_name_to_real_col_name, parse_local_name + expr.input, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), expr.path, expr.outer, @@ -611,7 +707,10 @@ def table_function_expression_extractor( expr.func_name, [ self.do_analyze( - x, df_aliased_col_name_to_real_col_name, parse_local_name + x, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) for x in expr.args ], @@ -622,7 +721,10 @@ def table_function_expression_extractor( expr.func_name, { key: self.to_sql_try_avoid_cast( - value, df_aliased_col_name_to_real_col_name, parse_local_name + value, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) for key, value in expr.args.items() }, @@ -633,7 +735,11 @@ def table_function_expression_extractor( "NamedArgumentsTableFunction, GeneratorTableFunction, or FlattenFunction." 
) partition_spec_sql = ( - self.do_analyze(expr.partition_spec, df_aliased_col_name_to_real_col_name) + self.do_analyze( + expr.partition_spec, + df_aliased_col_name_to_real_col_name, + config_context, + ) if expr.partition_spec else "" ) @@ -643,6 +749,7 @@ def unary_expression_extractor( self, expr: UnaryExpression, df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], + config_context: ConfigContext, parse_local_name=False, ) -> str: if isinstance(expr, Alias): @@ -660,13 +767,19 @@ def unary_expression_extractor( df_alias_dict[k] = quoted_name return alias_expression( self.do_analyze( - expr.child, df_aliased_col_name_to_real_col_name, parse_local_name + expr.child, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), quoted_name, ) if isinstance(expr, UnresolvedAlias): expr_str = self.do_analyze( - expr.child, df_aliased_col_name_to_real_col_name, parse_local_name + expr.child, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) if parse_local_name: expr_str = expr_str.upper() @@ -674,7 +787,10 @@ def unary_expression_extractor( elif isinstance(expr, Cast): return cast_expression( self.do_analyze( - expr.child, df_aliased_col_name_to_real_col_name, parse_local_name + expr.child, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), expr.to, expr.try_, @@ -682,7 +798,10 @@ def unary_expression_extractor( else: return unary_expression( self.do_analyze( - expr.child, df_aliased_col_name_to_real_col_name, parse_local_name + expr.child, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ), expr.sql_operator, expr.operator_first, @@ -692,23 +811,34 @@ def binary_operator_extractor( self, expr: BinaryExpression, df_aliased_col_name_to_real_col_name, + config_context: ConfigContext, parse_local_name=False, ) -> str: - if self.config_context.eliminate_numeric_sql_value_cast_enabled: + if config_context.eliminate_numeric_sql_value_cast_enabled: left_sql_expr = self.to_sql_try_avoid_cast( - expr.left, df_aliased_col_name_to_real_col_name, parse_local_name + expr.left, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) right_sql_expr = self.to_sql_try_avoid_cast( expr.right, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, ) else: left_sql_expr = self.do_analyze( - expr.left, df_aliased_col_name_to_real_col_name, parse_local_name + expr.left, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) right_sql_expr = self.do_analyze( - expr.right, df_aliased_col_name_to_real_col_name, parse_local_name + expr.right, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) if isinstance(expr, BinaryArithmeticExpression): return binary_arithmetic_expression( @@ -727,7 +857,7 @@ def binary_operator_extractor( ) def grouping_extractor( - self, expr: GroupingSet, df_aliased_col_name_to_real_col_name + self, expr: GroupingSet, df_aliased_col_name_to_real_col_name, config_context ) -> str: return self.do_analyze( FunctionExpression( @@ -736,6 +866,7 @@ def grouping_extractor( False, ), df_aliased_col_name_to_real_col_name, + config_context, ) def window_frame_boundary(self, offset: str) -> str: @@ -749,6 +880,7 @@ def to_sql_try_avoid_cast( self, expr: Expression, df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], + config_context: ConfigContext, parse_local_name: bool = False, ) -> str: """ @@ -769,15 +901,22 @@ def to_sql_try_avoid_cast( return 
str(expr.value).upper() else: return self.do_analyze( - expr, df_aliased_col_name_to_real_col_name, parse_local_name + expr, + df_aliased_col_name_to_real_col_name, + config_context, + parse_local_name, ) - def resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: + def resolve( + self, logical_plan: LogicalPlan, config_context: Optional[ConfigContext] = None + ) -> SnowflakePlan: self.subquery_plans = [] self.generated_alias_maps = {} + if config_context is None: + config_context = ConfigContext(self.session) + self.plan_builder.config_context = config_context - with self.config_context: - result = self.do_resolve(logical_plan) + result = self.do_resolve(logical_plan, config_context) result.add_aliases(self.generated_alias_maps) @@ -786,14 +925,16 @@ def resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: return result - def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: + def do_resolve( + self, logical_plan: LogicalPlan, config_context: ConfigContext + ) -> SnowflakePlan: resolved_children = {} df_aliased_col_name_to_real_col_name: DefaultDict[ str, Dict[str, str] ] = defaultdict(dict) for c in logical_plan.children: # post-order traversal of the tree - resolved = self.resolve(c) + resolved = self.resolve(c, config_context) df_aliased_col_name_to_real_col_name.update( resolved.df_aliased_col_name_to_real_col_name ) @@ -821,7 +962,10 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: self.alias_maps_to_use = use_maps res = self.do_resolve_with_resolved_children( - logical_plan, resolved_children, df_aliased_col_name_to_real_col_name + logical_plan, + resolved_children, + df_aliased_col_name_to_real_col_name, + config_context, ) res.df_aliased_col_name_to_real_col_name.update( df_aliased_col_name_to_real_col_name @@ -833,6 +977,7 @@ def do_resolve_with_resolved_children( logical_plan: LogicalPlan, resolved_children: Dict[LogicalPlan, SnowflakePlan], df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], + config_context: ConfigContext, ) -> SnowflakePlan: if isinstance(logical_plan, SnowflakePlan): return logical_plan @@ -840,7 +985,9 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, TableFunctionJoin): return self.plan_builder.join_table_function( self.do_analyze( - logical_plan.table_function, df_aliased_col_name_to_real_col_name + logical_plan.table_function, + df_aliased_col_name_to_real_col_name, + config_context, ), resolved_children[logical_plan.children[0]], logical_plan, @@ -852,7 +999,9 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, TableFunctionRelation): return self.plan_builder.from_table_function( self.do_analyze( - logical_plan.table_function, df_aliased_col_name_to_real_col_name + logical_plan.table_function, + df_aliased_col_name_to_real_col_name, + config_context, ), logical_plan, ) @@ -860,7 +1009,9 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Lateral): return self.plan_builder.lateral( self.do_analyze( - logical_plan.table_function, df_aliased_col_name_to_real_col_name + logical_plan.table_function, + df_aliased_col_name_to_real_col_name, + config_context, ), resolved_children[logical_plan.children[0]], logical_plan, @@ -870,12 +1021,14 @@ def do_resolve_with_resolved_children( return self.plan_builder.aggregate( [ self.to_sql_try_avoid_cast( - expr, df_aliased_col_name_to_real_col_name + expr, df_aliased_col_name_to_real_col_name, config_context ) for expr in logical_plan.grouping_expressions ], [ - self.do_analyze(expr, 
df_aliased_col_name_to_real_col_name) + self.do_analyze( + expr, df_aliased_col_name_to_real_col_name, config_context + ) for expr in logical_plan.aggregate_expressions ], resolved_children[logical_plan.child], @@ -887,7 +1040,7 @@ def do_resolve_with_resolved_children( list( map( lambda x: self.do_analyze( - x, df_aliased_col_name_to_real_col_name + x, df_aliased_col_name_to_real_col_name, config_context ), logical_plan.project_list, ) @@ -899,7 +1052,9 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Filter): return self.plan_builder.filter( self.do_analyze( - logical_plan.condition, df_aliased_col_name_to_real_col_name + logical_plan.condition, + df_aliased_col_name_to_real_col_name, + config_context, ), resolved_children[logical_plan.child], logical_plan, @@ -917,14 +1072,18 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Join): join_condition = ( self.do_analyze( - logical_plan.join_condition, df_aliased_col_name_to_real_col_name + logical_plan.join_condition, + df_aliased_col_name_to_real_col_name, + config_context, ) if logical_plan.join_condition else "" ) match_condition = ( self.do_analyze( - logical_plan.match_condition, df_aliased_col_name_to_real_col_name + logical_plan.match_condition, + df_aliased_col_name_to_real_col_name, + config_context, ) if logical_plan.match_condition else "" @@ -942,7 +1101,9 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Sort): return self.plan_builder.sort( [ - self.do_analyze(x, df_aliased_col_name_to_real_col_name) + self.do_analyze( + x, df_aliased_col_name_to_real_col_name, config_context + ) for x in logical_plan.order ], resolved_children[logical_plan.child], @@ -1006,7 +1167,9 @@ def do_resolve_with_resolved_children( mode=logical_plan.mode, table_type=logical_plan.table_type, clustering_keys=[ - self.do_analyze(x, df_aliased_col_name_to_real_col_name) + self.do_analyze( + x, df_aliased_col_name_to_real_col_name, config_context + ) for x in logical_plan.clustering_exprs ], comment=logical_plan.comment, @@ -1029,10 +1192,14 @@ def do_resolve_with_resolved_children( ) and isinstance(logical_plan.child.source_plan, Sort) return self.plan_builder.limit( self.to_sql_try_avoid_cast( - logical_plan.limit_expr, df_aliased_col_name_to_real_col_name + logical_plan.limit_expr, + df_aliased_col_name_to_real_col_name, + config_context, ), self.to_sql_try_avoid_cast( - logical_plan.offset_expr, df_aliased_col_name_to_real_col_name + logical_plan.offset_expr, + df_aliased_col_name_to_real_col_name, + config_context, ), resolved_children[logical_plan.child], on_top_of_order_by, @@ -1060,7 +1227,9 @@ def do_resolve_with_resolved_children( ] child = self.plan_builder.project( [ - self.do_analyze(col, df_aliased_col_name_to_real_col_name) + self.do_analyze( + col, df_aliased_col_name_to_real_col_name, config_context + ) for col in project_exprs ], resolved_children[logical_plan.child], @@ -1076,26 +1245,36 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan.pivot_values, List): pivot_values = [ - self.do_analyze(pv, df_aliased_col_name_to_real_col_name) + self.do_analyze( + pv, df_aliased_col_name_to_real_col_name, config_context + ) for pv in logical_plan.pivot_values ] elif isinstance(logical_plan.pivot_values, ScalarSubquery): pivot_values = self.do_analyze( - logical_plan.pivot_values, df_aliased_col_name_to_real_col_name + logical_plan.pivot_values, + df_aliased_col_name_to_real_col_name, + config_context, ) else: pivot_values = None pivot_plan = self.plan_builder.pivot( 
self.do_analyze( - logical_plan.pivot_column, df_aliased_col_name_to_real_col_name + logical_plan.pivot_column, + df_aliased_col_name_to_real_col_name, + config_context, ), pivot_values, self.do_analyze( - logical_plan.aggregates[0], df_aliased_col_name_to_real_col_name + logical_plan.aggregates[0], + df_aliased_col_name_to_real_col_name, + config_context, ), self.do_analyze( - logical_plan.default_on_null, df_aliased_col_name_to_real_col_name + logical_plan.default_on_null, + df_aliased_col_name_to_real_col_name, + config_context, ) if logical_plan.default_on_null else None, @@ -1120,7 +1299,9 @@ def do_resolve_with_resolved_children( logical_plan.value_column, logical_plan.name_column, [ - self.do_analyze(c, df_aliased_col_name_to_real_col_name) + self.do_analyze( + c, df_aliased_col_name_to_real_col_name, config_context + ) for c in logical_plan.column_list ], resolved_children[logical_plan.child], @@ -1162,7 +1343,9 @@ def do_resolve_with_resolved_children( refresh_mode=logical_plan.refresh_mode, initialize=logical_plan.initialize, clustering_keys=[ - self.do_analyze(x, df_aliased_col_name_to_real_col_name) + self.do_analyze( + x, df_aliased_col_name_to_real_col_name, config_context + ) for x in logical_plan.clustering_exprs ], is_transient=logical_plan.is_transient, @@ -1194,7 +1377,9 @@ def do_resolve_with_resolved_children( validation_mode=logical_plan.validation_mode, column_names=logical_plan.column_names, transformations=[ - self.do_analyze(x, df_aliased_col_name_to_real_col_name) + self.do_analyze( + x, df_aliased_col_name_to_real_col_name, config_context + ) for x in logical_plan.transformations ] if logical_plan.transformations @@ -1210,7 +1395,9 @@ def do_resolve_with_resolved_children( stage_location=logical_plan.stage_location, source_plan=logical_plan, partition_by=self.do_analyze( - logical_plan.partition_by, df_aliased_col_name_to_real_col_name + logical_plan.partition_by, + df_aliased_col_name_to_real_col_name, + config_context, ) if logical_plan.partition_by else None, @@ -1226,12 +1413,16 @@ def do_resolve_with_resolved_children( logical_plan.table_name, { self.do_analyze( - k, df_aliased_col_name_to_real_col_name - ): self.do_analyze(v, df_aliased_col_name_to_real_col_name) + k, df_aliased_col_name_to_real_col_name, config_context + ): self.do_analyze( + v, df_aliased_col_name_to_real_col_name, config_context + ) for k, v in logical_plan.assignments.items() }, self.do_analyze( - logical_plan.condition, df_aliased_col_name_to_real_col_name + logical_plan.condition, + df_aliased_col_name_to_real_col_name, + config_context, ) if logical_plan.condition else None, @@ -1245,7 +1436,9 @@ def do_resolve_with_resolved_children( return self.plan_builder.delete( logical_plan.table_name, self.do_analyze( - logical_plan.condition, df_aliased_col_name_to_real_col_name + logical_plan.condition, + df_aliased_col_name_to_real_col_name, + config_context, ) if logical_plan.condition else None, @@ -1263,10 +1456,14 @@ def do_resolve_with_resolved_children( if logical_plan.source else logical_plan.source, self.do_analyze( - logical_plan.join_expr, df_aliased_col_name_to_real_col_name + logical_plan.join_expr, + df_aliased_col_name_to_real_col_name, + config_context, ), [ - self.do_analyze(c, df_aliased_col_name_to_real_col_name) + self.do_analyze( + c, df_aliased_col_name_to_real_col_name, config_context + ) for c in logical_plan.clauses ], logical_plan, diff --git a/src/snowflake/snowpark/_internal/analyzer/config_context.py b/src/snowflake/snowpark/_internal/analyzer/config_context.py 
new file mode 100644 index 00000000000..c45d335d8a2 --- /dev/null +++ b/src/snowflake/snowpark/_internal/analyzer/config_context.py @@ -0,0 +1,37 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +from typing import Any + + +class ConfigContext: + """Class to help reading a snapshot of configuration attributes from a session object. + + On instantiation, this object stores the configuration from the session object + and returns the stored configuration attributes when requested. + """ + + def __init__(self, session) -> None: + self.session = session + self.configs = { + "_query_compilation_stage_enabled", + "cte_optimization_enabled", + "eliminate_numeric_sql_value_cast_enabled", + "large_query_breakdown_complexity_bounds", + "large_query_breakdown_enabled", + } + self._create_snapshot() + + def __getattr__(self, name: str) -> Any: + if name in self.configs: + return getattr(self.session, name) + raise AttributeError(f"ConfigContext has no attribute {name}") + + def _create_snapshot(self) -> "ConfigContext": + """Reads the configuration attributes from the session object and stores them + in the context object. + """ + for name in self.configs: + setattr(self, name, getattr(self.session, name)) + return self diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index ca6b97f57e7..34fad4a05ad 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -23,6 +23,7 @@ Union, ) +from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, sum_node_complexities, @@ -313,13 +314,16 @@ def children_plan_nodes(self) -> List[Union["Selectable", "SnowflakePlan"]]: return [] def replace_repeated_subquery_with_cte( - self, cte_optimization_enabled: bool, query_compilation_stage_enabled: bool + self, config_context: ConfigContext ) -> "SnowflakePlan": # parameter protection # the common subquery elimination will be applied if cte_optimization is not enabled # and the new compilation stage is not enabled. When new compilation stage is enabled, # the common subquery elimination will be done through the new plan transformation. - if not cte_optimization_enabled or query_compilation_stage_enabled: + if ( + not config_context.cte_optimization_enabled + or config_context.query_compilation_stage_enabled + ): return self # if source_plan or placeholder_query is none, it must be a leaf node, @@ -358,7 +362,9 @@ def replace_repeated_subquery_with_cte( # copy depends on the cte_optimization_enabled value. We should keep it # consistent with the current context. original_cte_optimization = self.session.cte_optimization_enabled - self.session.cte_optimization_enabled = cte_optimization_enabled + self.session.cte_optimization_enabled = ( + config_context.cte_optimization_enabled + ) plan = copy.copy(self) self.session.cte_optimization_enabled = original_cte_optimization # all other parts of query are unchanged, but just replace the original query @@ -523,56 +529,6 @@ def add_aliases(self, to_add: Dict) -> None: self.expr_to_alias = {**self.expr_to_alias, **to_add} -class ConfigContext: - """Class to manage reading of configuration values from session in the context of - plan building, analysis and resolution. 
- - Behavior: - - Inside an active context, the configuration will be read based on snapshot taken - at the context creation stage and reset when the context is exited. - - When no active context is present, the configuration will be read from the - session object directly. - """ - - def __init__(self, session) -> None: - self.session = session - self.configs = { - "_query_compilation_stage_enabled", - "cte_optimization_enabled", - "eliminate_numeric_sql_value_cast_enabled", - "large_query_breakdown_complexity_bounds", - "large_query_breakdown_enabled", - } - - def __getattr__(self, name: str) -> Any: - if name in self.configs: - return getattr(self.session, name) - raise AttributeError(f"ConfigContext has no attribute {name}") - - def __enter__(self) -> "ConfigContext": - self._create_snapshot() - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self._reset() - - def _create_snapshot(self) -> "ConfigContext": - """Reads the configuration attributes from the session object and stores them - in the context object. - """ - for name in self.configs: - setattr(self, name, getattr(self.session, name)) - return self - - def _reset(self) -> None: - """Removes the configuration attributes from the context object.""" - for name in self.configs: - try: - delattr(self, name) - except AttributeError: - pass - - class SnowflakePlanBuilder: def __init__( self, @@ -586,7 +542,8 @@ def __init__( # on the optimized plan. During the final query generation, no schema query is needed, # this helps reduces un-necessary overhead for the describing call. self._skip_schema_query = skip_schema_query - self.config_context: ConfigContext = ConfigContext(session) + # TODO: SNOW-1541096 remove after old cte implementation is removed + self.config_context: Optional[ConfigContext] = None @SnowflakePlan.Decorator.wrap_exception def build( @@ -620,10 +577,10 @@ def build( ), "No schema query is available in child SnowflakePlan" new_schema_query = schema_query or sql_generator(child.schema_query) + config_context = self.config_context or ConfigContext(self.session) placeholder_query = ( sql_generator(select_child._id) - if self.config_context.cte_optimization_enabled - and select_child._id is not None + if config_context.cte_optimization_enabled and select_child._id is not None else None ) @@ -660,9 +617,10 @@ def build_binary( right_schema_query = schema_value_statement(select_right.attributes) schema_query = sql_generator(left_schema_query, right_schema_query) + config_context = self.config_context or ConfigContext(self.session) placeholder_query = ( sql_generator(select_left._id, select_right._id) - if self.config_context.cte_optimization_enabled + if config_context.cte_optimization_enabled and select_left._id is not None and select_right._id is not None else None @@ -694,8 +652,8 @@ def build_binary( referenced_ctes: Set[str] = set() if ( - self.config_context.cte_optimization_enabled - and self.config_context._query_compilation_stage_enabled + config_context.cte_optimization_enabled + and config_context._query_compilation_stage_enabled ): # When the cte optimization and the new compilation stage is enabled, # the referred cte tables are propagated from left and right can have @@ -986,8 +944,7 @@ def save_as_table( ) child = child.replace_repeated_subquery_with_cte( - self.config_context.cte_optimization_enabled, - self.config_context._query_compilation_stage_enabled, + self.config_context or ConfigContext(self.session) ) def get_create_table_as_select_plan(child: SnowflakePlan, replace, error): 
@@ -1177,8 +1134,7 @@ def create_or_replace_view( raise SnowparkClientExceptionMessages.PLAN_CREATE_VIEWS_FROM_SELECT_ONLY() child = child.replace_repeated_subquery_with_cte( - self.config_context.cte_optimization_enabled, - self.config_context._query_compilation_stage_enabled, + self.config_context or ConfigContext(self.session) ) return self.build( lambda x: create_or_replace_view_statement(name, x, is_temp, comment), @@ -1223,8 +1179,7 @@ def create_or_replace_dynamic_table( raise ValueError(f"Unknown create mode: {create_mode}") # pragma: no cover child = child.replace_repeated_subquery_with_cte( - self.config_context.cte_optimization_enabled, - self.config_context._query_compilation_stage_enabled, + self.config_context or ConfigContext(self.session) ) return self.build( lambda x: create_or_replace_dynamic_table_statement( @@ -1529,8 +1484,7 @@ def copy_into_location( **copy_options: Optional[Any], ) -> SnowflakePlan: query = query.replace_repeated_subquery_with_cte( - self.config_context.cte_optimization_enabled, - self.config_context._query_compilation_stage_enabled, + self.config_context or ConfigContext(self.session) ) return self.build( lambda x: copy_into_location( @@ -1559,8 +1513,7 @@ def update( ) -> SnowflakePlan: if source_data: source_data = source_data.replace_repeated_subquery_with_cte( - self.config_context.cte_optimization_enabled, - self.config_context._query_compilation_stage_enabled, + self.config_context or ConfigContext(self.session) ) return self.build( lambda x: update_statement( @@ -1593,8 +1546,7 @@ def delete( ) -> SnowflakePlan: if source_data: source_data = source_data.replace_repeated_subquery_with_cte( - self.config_context.cte_optimization_enabled, - self.config_context._query_compilation_stage_enabled, + self.config_context or ConfigContext(self.session) ) return self.build( lambda x: delete_statement( @@ -1625,8 +1577,7 @@ def merge( source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: source_data = source_data.replace_repeated_subquery_with_cte( - self.config_context.cte_optimization_enabled, - self.config_context._query_compilation_stage_enabled, + self.config_context or ConfigContext(self.session) ) return self.build( lambda x: merge_statement(table_name, x, join_expr, clauses), diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 171720101d4..71b6ce16b97 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -6,11 +6,11 @@ import time from typing import Dict, List +from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( get_complexity_score, ) from snowflake.snowpark._internal.analyzer.snowflake_plan import ( - ConfigContext, PlanQueryType, Query, SnowflakePlan, @@ -75,10 +75,6 @@ def should_start_query_compilation(self) -> bool: ) def compile(self) -> Dict[PlanQueryType, List[Query]]: - with self.config_context: - return self._compile() - - def _compile(self) -> Dict[PlanQueryType, List[Query]]: if self.should_start_query_compilation(): # preparation for compilation # 1. 
make a copy of the original plan @@ -153,8 +149,7 @@ def _compile(self) -> Dict[PlanQueryType, List[Query]]: else: final_plan = self._plan final_plan = final_plan.replace_repeated_subquery_with_cte( - self.config_context.cte_optimization_enabled, - self.config_context._query_compilation_stage_enabled, + self.config_context ) return { PlanQueryType.QUERIES: final_plan.queries, diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index 2cde864c062..dd6e56bce76 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -8,6 +8,7 @@ from snowflake.snowpark._internal.analyzer.expression import Attribute from snowflake.snowpark._internal.analyzer.select_statement import Selectable from snowflake.snowpark._internal.analyzer.snowflake_plan import ( + ConfigContext, CreateViewCommand, PlanQueryType, Query, @@ -110,13 +111,14 @@ def do_resolve_with_resolved_children( logical_plan: LogicalPlan, resolved_children: Dict[LogicalPlan, SnowflakePlan], df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], + config_context: ConfigContext, ) -> SnowflakePlan: if isinstance(logical_plan, SnowflakePlan): if logical_plan.queries is None: assert logical_plan.source_plan is not None # when encounter a SnowflakePlan with no queries, try to re-resolve # the source plan to construct the result - res = self.do_resolve(logical_plan.source_plan) + res = self.do_resolve(logical_plan.source_plan, config_context) resolved_children[logical_plan] = res resolved_plan = res else: @@ -204,7 +206,10 @@ def do_resolve_with_resolved_children( copied_resolved_child.queries = final_queries[PlanQueryType.QUERIES] resolved_children[logical_plan.children[0]] = copied_resolved_child resolved_plan = super().do_resolve_with_resolved_children( - logical_plan, resolved_children, df_aliased_col_name_to_real_col_name + logical_plan, + resolved_children, + df_aliased_col_name_to_real_col_name, + config_context, ) elif isinstance(logical_plan, Selectable): @@ -228,7 +233,10 @@ def do_resolve_with_resolved_children( else: resolved_plan = super().do_resolve_with_resolved_children( - logical_plan, resolved_children, df_aliased_col_name_to_real_col_name + logical_plan, + resolved_children, + df_aliased_col_name_to_real_col_name, + config_context, ) resolved_plan._is_valid_for_replacement = True diff --git a/src/snowflake/snowpark/mock/_analyzer.py b/src/snowflake/snowpark/mock/_analyzer.py index bd85953830f..f26fa786126 100644 --- a/src/snowflake/snowpark/mock/_analyzer.py +++ b/src/snowflake/snowpark/mock/_analyzer.py @@ -44,6 +44,7 @@ BinaryExpression, ) from snowflake.snowpark._internal.analyzer.binary_plan_node import Join, SetOperation +from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.datatype_mapper import ( numeric_to_sql_without_cast, str_to_sql, @@ -75,10 +76,7 @@ GroupingSet, GroupingSetsExpression, ) -from snowflake.snowpark._internal.analyzer.snowflake_plan import ( - ConfigContext, - SnowflakePlan, -) +from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( CopyIntoLocationNode, CopyIntoTableNode, @@ -156,9 +154,6 @@ def __init__(self, session: "snowflake.snowpark.session.Session") -> None: self.subquery_plans = [] self.alias_maps_to_use = None self._conn = self.session._conn - 
self.config_context = ConfigContext(session) - # Point this plan builder to the same config context as the analyzer - self.plan_builder.config_context = self.config_context def analyze( self, @@ -167,17 +162,19 @@ def analyze( parse_local_name=False, keep_alias=True, ) -> Union[str, List[str]]: - with self.config_context: - return self.do_analyze( - expr, - expr_to_alias, - parse_local_name, - keep_alias, - ) + config_context = ConfigContext(self.session) + return self.do_analyze( + expr, + config_context, + expr_to_alias, + parse_local_name, + keep_alias, + ) def do_analyze( self, expr: Union[Expression, NamedExpression], + config_context: ConfigContext, expr_to_alias: Optional[Dict[str, str]] = None, parse_local_name=False, keep_alias=True, @@ -201,15 +198,25 @@ def do_analyze( if isinstance(expr, Like): return like_expression( - self.do_analyze(expr.expr, expr_to_alias, parse_local_name), - self.do_analyze(expr.pattern, expr_to_alias, parse_local_name), + self.do_analyze( + expr.expr, config_context, expr_to_alias, parse_local_name + ), + self.do_analyze( + expr.pattern, config_context, expr_to_alias, parse_local_name + ), ) if isinstance(expr, RegExp): return regexp_expression( - self.do_analyze(expr.expr, expr_to_alias, parse_local_name), - self.do_analyze(expr.pattern, expr_to_alias, parse_local_name), - self.do_analyze(expr.parameters, expr_to_alias, parse_local_name) + self.do_analyze( + expr.expr, config_context, expr_to_alias, parse_local_name + ), + self.do_analyze( + expr.pattern, config_context, expr_to_alias, parse_local_name + ), + self.do_analyze( + expr.parameters, config_context, expr_to_alias, parse_local_name + ) if expr.parameters is not None else None, ) @@ -219,7 +226,9 @@ def do_analyze( expr.collation_spec.upper() if parse_local_name else expr.collation_spec ) return collate_expression( - self.do_analyze(expr.expr, expr_to_alias, parse_local_name), + self.do_analyze( + expr.expr, config_context, expr_to_alias, parse_local_name + ), collation_spec, ) @@ -228,19 +237,28 @@ def do_analyze( if parse_local_name and isinstance(field, str): field = field.upper() return subfield_expression( - self.do_analyze(expr.expr, expr_to_alias, parse_local_name), field + self.do_analyze( + expr.expr, config_context, expr_to_alias, parse_local_name + ), + field, ) if isinstance(expr, CaseWhen): return case_when_expression( [ ( - self.do_analyze(condition, expr_to_alias, parse_local_name), - self.do_analyze(value, expr_to_alias, parse_local_name), + self.do_analyze( + condition, config_context, expr_to_alias, parse_local_name + ), + self.do_analyze( + value, config_context, expr_to_alias, parse_local_name + ), ) for condition, value in expr.branches ], - self.do_analyze(expr.else_value, expr_to_alias, parse_local_name) + self.do_analyze( + expr.else_value, config_context, expr_to_alias, parse_local_name + ) if expr.else_value else "NULL", ) @@ -248,15 +266,17 @@ def do_analyze( if isinstance(expr, MultipleExpression): block_expressions = [] for expression in expr.expressions: - if self.config_context.eliminate_numeric_sql_value_cast_enabled: + if config_context.eliminate_numeric_sql_value_cast_enabled: resolved_expr = self.to_sql_try_avoid_cast( expression, expr_to_alias, + config_context, parse_local_name, ) else: resolved_expr = self.do_analyze( expression, + config_context, expr_to_alias, parse_local_name, ) @@ -267,22 +287,26 @@ def do_analyze( if isinstance(expr, InExpression): in_values = [] for expression in expr.values: - if 
self.config_context.eliminate_numeric_sql_value_cast_enabled: + if config_context.eliminate_numeric_sql_value_cast_enabled: in_value = self.to_sql_try_avoid_cast( expression, expr_to_alias, + config_context, parse_local_name, ) else: in_value = self.do_analyze( expression, + config_context, expr_to_alias, parse_local_name, ) in_values.append(in_value) return in_expression( - self.do_analyze(expr.columns, expr_to_alias, parse_local_name), + self.do_analyze( + expr.columns, config_context, expr_to_alias, parse_local_name + ), in_values, ) @@ -296,10 +320,12 @@ def do_analyze( return window_expression( self.do_analyze( expr.window_function, + config_context, parse_local_name=parse_local_name, ), self.do_analyze( expr.window_spec, + config_context, parse_local_name=parse_local_name, ), ) @@ -307,15 +333,20 @@ def do_analyze( if isinstance(expr, WindowSpecDefinition): return window_spec_expression( [ - self.do_analyze(x, parse_local_name=parse_local_name) + self.do_analyze( + x, config_context, parse_local_name=parse_local_name + ) for x in expr.partition_spec ], [ - self.do_analyze(x, parse_local_name=parse_local_name) + self.do_analyze( + x, config_context, parse_local_name=parse_local_name + ) for x in expr.order_spec ], self.do_analyze( expr.frame_spec, + config_context, parse_local_name=parse_local_name, ), ) @@ -323,8 +354,12 @@ def do_analyze( if isinstance(expr, SpecifiedWindowFrame): return specified_window_frame_expression( expr.frame_type.sql, - self.window_frame_boundary(self.to_sql_try_avoid_cast(expr.lower, {})), - self.window_frame_boundary(self.to_sql_try_avoid_cast(expr.upper, {})), + self.window_frame_boundary( + self.to_sql_try_avoid_cast(expr.lower, {}, config_context) + ), + self.window_frame_boundary( + self.to_sql_try_avoid_cast(expr.upper, {}, config_context) + ), ) if isinstance(expr, UnspecifiedFrame): @@ -354,7 +389,7 @@ def do_analyze( children = [] for c in expr.children: - extracted = self.to_sql_try_avoid_cast(c, expr_to_alias) + extracted = self.to_sql_try_avoid_cast(c, expr_to_alias, config_context) if isinstance(extracted, list): children.extend(extracted) else: @@ -370,7 +405,10 @@ def do_analyze( if not expr.expressions: return "*" else: - return [self.do_analyze(e, expr_to_alias) for e in expr.expressions] + return [ + self.do_analyze(e, config_context, expr_to_alias) + for e in expr.expressions + ] if isinstance(expr, SnowflakeUDF): if expr.api_call_source is not None: @@ -381,26 +419,28 @@ def do_analyze( return function_expression( func_name, [ - self.do_analyze(x, expr_to_alias, parse_local_name) + self.do_analyze(x, config_context, expr_to_alias, parse_local_name) for x in expr.children ], False, ) if isinstance(expr, TableFunctionExpression): - return self.table_function_expression_extractor(expr, expr_to_alias) + return self.table_function_expression_extractor( + expr, expr_to_alias, config_context + ) if isinstance(expr, TableFunctionPartitionSpecDefinition): return table_function_partition_spec( expr.over, [ - self.do_analyze(x, expr_to_alias, parse_local_name) + self.do_analyze(x, config_context, expr_to_alias, parse_local_name) for x in expr.partition_spec ] if expr.partition_spec else [], [ - self.do_analyze(x, expr_to_alias, parse_local_name) + self.do_analyze(x, config_context, expr_to_alias, parse_local_name) for x in expr.order_spec ] if expr.order_spec @@ -411,13 +451,16 @@ def do_analyze( return self.unary_expression_extractor( expr, expr_to_alias, + config_context, parse_local_name, keep_alias=keep_alias, ) if isinstance(expr, SortOrder): 
return order_expression( - self.do_analyze(expr.child, expr_to_alias, parse_local_name), + self.do_analyze( + expr.child, config_context, expr_to_alias, parse_local_name + ), expr.direction.sql, expr.null_ordering.sql, ) @@ -428,47 +471,60 @@ def do_analyze( if isinstance(expr, WithinGroup): return within_group_expression( - self.do_analyze(expr.expr, expr_to_alias, parse_local_name), - [self.do_analyze(e, expr_to_alias) for e in expr.order_by_cols], + self.do_analyze( + expr.expr, config_context, expr_to_alias, parse_local_name + ), + [ + self.do_analyze(e, config_context, expr_to_alias) + for e in expr.order_by_cols + ], ) if isinstance(expr, BinaryExpression): return self.binary_operator_extractor( expr, expr_to_alias, + config_context, parse_local_name, ) if isinstance(expr, InsertMergeExpression): return insert_merge_statement( - self.do_analyze(expr.condition, expr_to_alias) + self.do_analyze(expr.condition, config_context, expr_to_alias) if expr.condition else None, - [self.do_analyze(k, expr_to_alias) for k in expr.keys], - [self.do_analyze(v, expr_to_alias) for v in expr.values], + [self.do_analyze(k, config_context, expr_to_alias) for k in expr.keys], + [ + self.do_analyze(v, config_context, expr_to_alias) + for v in expr.values + ], ) if isinstance(expr, UpdateMergeExpression): return update_merge_statement( - self.do_analyze(expr.condition, expr_to_alias) + self.do_analyze(expr.condition, config_context, expr_to_alias) if expr.condition else None, { - self.do_analyze(k, expr_to_alias): self.do_analyze(v, expr_to_alias) + self.do_analyze(k, config_context, expr_to_alias): self.do_analyze( + v, config_context, expr_to_alias + ) for k, v in expr.assignments.items() }, ) if isinstance(expr, DeleteMergeExpression): return delete_merge_statement( - self.do_analyze(expr.condition, expr_to_alias) + self.do_analyze(expr.condition, config_context, expr_to_alias) if expr.condition else None ) if isinstance(expr, ListAgg): return list_agg( - self.do_analyze(expr.col, expr_to_alias, parse_local_name), + self.do_analyze( + expr.col, config_context, expr_to_alias, parse_local_name + ), str_to_sql(expr.delimiter), expr.is_distinct, ) @@ -476,7 +532,9 @@ def do_analyze( if isinstance(expr, ColumnSum): return column_sum( [ - self.do_analyze(col, expr_to_alias, parse_local_name) + self.do_analyze( + col, config_context, expr_to_alias, parse_local_name + ) for col in expr.exprs ] ) @@ -484,9 +542,13 @@ def do_analyze( if isinstance(expr, RankRelatedFunctionExpression): return rank_related_function_expression( expr.sql, - self.do_analyze(expr.expr, expr_to_alias, parse_local_name), + self.do_analyze( + expr.expr, config_context, expr_to_alias, parse_local_name + ), expr.offset, - self.do_analyze(expr.default, expr_to_alias, parse_local_name) + self.do_analyze( + expr.default, config_context, expr_to_alias, parse_local_name + ) if expr.default else None, expr.ignore_nulls, @@ -500,11 +562,14 @@ def table_function_expression_extractor( self, expr: TableFunctionExpression, expr_to_alias: Dict[str, str], + config_context: ConfigContext, parse_local_name=False, ) -> str: if isinstance(expr, FlattenFunction): return flatten_expression( - self.do_analyze(expr.input, expr_to_alias, parse_local_name), + self.do_analyze( + expr.input, config_context, expr_to_alias, parse_local_name + ), expr.path, expr.outer, expr.recursive, @@ -514,7 +579,7 @@ def table_function_expression_extractor( sql = function_expression( expr.func_name, [ - self.do_analyze(x, expr_to_alias, parse_local_name) + self.do_analyze(x, 
config_context, expr_to_alias, parse_local_name) for x in expr.args ], False, @@ -523,7 +588,9 @@ def table_function_expression_extractor( sql = named_arguments_function( expr.func_name, { - key: self.do_analyze(value, expr_to_alias, parse_local_name) + key: self.do_analyze( + value, config_context, expr_to_alias, parse_local_name + ) for key, value in expr.args.items() }, ) @@ -533,7 +600,7 @@ def table_function_expression_extractor( "NamedArgumentsTableFunction, GeneratorTableFunction, or FlattenFunction." ) partition_spec_sql = ( - self.do_analyze(expr.partition_spec, expr_to_alias) + self.do_analyze(expr.partition_spec, config_context, expr_to_alias) if expr.partition_spec else "" ) @@ -543,6 +610,7 @@ def unary_expression_extractor( self, expr: UnaryExpression, expr_to_alias: Dict[str, str], + config_context: ConfigContext, parse_local_name=False, keep_alias=True, ) -> str: @@ -554,7 +622,9 @@ def unary_expression_extractor( if v == expr.child.name: expr_to_alias[k] = quoted_name alias_exp = alias_expression( - self.do_analyze(expr.child, expr_to_alias, parse_local_name), + self.do_analyze( + expr.child, config_context, expr_to_alias, parse_local_name + ), quoted_name, ) @@ -562,19 +632,25 @@ def unary_expression_extractor( expr_str = expr_str.upper() if parse_local_name else expr_str return expr_str if isinstance(expr, UnresolvedAlias): - expr_str = self.do_analyze(expr.child, expr_to_alias, parse_local_name) + expr_str = self.do_analyze( + expr.child, config_context, expr_to_alias, parse_local_name + ) if parse_local_name: expr_str = expr_str.upper() return quote_name(expr_str.strip()) elif isinstance(expr, Cast): return cast_expression( - self.do_analyze(expr.child, expr_to_alias, parse_local_name), + self.do_analyze( + expr.child, config_context, expr_to_alias, parse_local_name + ), expr.to, expr.try_, ) else: return unary_expression( - self.do_analyze(expr.child, expr_to_alias, parse_local_name), + self.do_analyze( + expr.child, config_context, expr_to_alias, parse_local_name + ), expr.sql_operator, expr.operator_first, ) @@ -583,21 +659,25 @@ def binary_operator_extractor( self, expr: BinaryExpression, expr_to_alias: Dict[str, str], + config_context: ConfigContext, parse_local_name=False, ) -> str: - if self.config_context.eliminate_numeric_sql_value_cast_enabled: + if config_context.eliminate_numeric_sql_value_cast_enabled: left_sql_expr = self.to_sql_try_avoid_cast( - expr.left, expr_to_alias, parse_local_name + expr.left, expr_to_alias, config_context, parse_local_name ) right_sql_expr = self.to_sql_try_avoid_cast( expr.right, expr_to_alias, + config_context, parse_local_name, ) else: - left_sql_expr = self.do_analyze(expr.left, expr_to_alias, parse_local_name) + left_sql_expr = self.do_analyze( + expr.left, config_context, expr_to_alias, parse_local_name + ) right_sql_expr = self.do_analyze( - expr.right, expr_to_alias, parse_local_name + expr.right, config_context, expr_to_alias, parse_local_name ) operator = expr.sql_operator.lower() @@ -618,7 +698,10 @@ def binary_operator_extractor( ) def grouping_extractor( - self, expr: GroupingSet, expr_to_alias: Dict[str, str] + self, + expr: GroupingSet, + expr_to_alias: Dict[str, str], + config_context: ConfigContext, ) -> str: return self.do_analyze( FunctionExpression( @@ -626,6 +709,7 @@ def grouping_extractor( [c.child if isinstance(c, Alias) else c for c in expr.children], False, ), + config_context, expr_to_alias, ) @@ -637,29 +721,43 @@ def window_frame_boundary(self, offset: str) -> str: return offset def 
to_sql_try_avoid_cast( - self, expr: Expression, expr_to_alias: Dict[str, str], parse_local_name=False + self, + expr: Expression, + expr_to_alias: Dict[str, str], + config_context: ConfigContext, + parse_local_name=False, ) -> str: # if expression is a numeric literal, return the number without casting, # otherwise process as normal if isinstance(expr, Literal) and isinstance(expr.datatype, _NumericType): return numeric_to_sql_without_cast(expr.value, expr.datatype) else: - return self.do_analyze(expr, expr_to_alias, parse_local_name) + return self.do_analyze( + expr, config_context, expr_to_alias, parse_local_name + ) def resolve( - self, logical_plan: LogicalPlan, expr_to_alias: Optional[Dict[str, str]] = None + self, + logical_plan: LogicalPlan, + expr_to_alias: Optional[Dict[str, str]] = None, + config_context: Optional[ConfigContext] = None, ) -> MockExecutionPlan: self.subquery_plans = [] if expr_to_alias is None: expr_to_alias = {} - with self.config_context: - result = self.do_resolve(logical_plan, expr_to_alias) + if config_context is None: + config_context = ConfigContext(self.session) + + result = self.do_resolve(logical_plan, expr_to_alias, config_context) return result def do_resolve( - self, logical_plan: LogicalPlan, expr_to_alias: Dict[str, str] + self, + logical_plan: LogicalPlan, + expr_to_alias: Dict[str, str], + config_context: ConfigContext, ) -> MockExecutionPlan: resolved_children = {} expr_to_alias_maps = {} @@ -679,7 +777,7 @@ def do_resolve( expr_to_alias.update({p: q for p, q in v.items() if counts[p] < 2}) return self.do_resolve_with_resolved_children( - logical_plan, resolved_children, expr_to_alias + logical_plan, resolved_children, expr_to_alias, config_context ) def do_resolve_with_resolved_children( @@ -687,6 +785,7 @@ def do_resolve_with_resolved_children( logical_plan: LogicalPlan, resolved_children: Dict[LogicalPlan, SnowflakePlan], expr_to_alias: Dict[str, str], + config_context: ConfigContext, ) -> MockExecutionPlan: if isinstance(logical_plan, MockExecutionPlan): return logical_plan @@ -774,8 +873,12 @@ def do_resolve_with_resolved_children( logical_plan.child, SnowflakePlan ) and isinstance(logical_plan.child.source_plan, Sort) return self.plan_builder.limit( - self.to_sql_try_avoid_cast(logical_plan.limit_expr, expr_to_alias), - self.to_sql_try_avoid_cast(logical_plan.offset_expr, expr_to_alias), + self.to_sql_try_avoid_cast( + logical_plan.limit_expr, expr_to_alias, config_context + ), + self.to_sql_try_avoid_cast( + logical_plan.offset_expr, expr_to_alias, config_context + ), resolved_children[logical_plan.child], on_top_of_order_by, logical_plan, @@ -804,7 +907,9 @@ def do_resolve_with_resolved_children( query=resolved_children[logical_plan.child], stage_location=logical_plan.stage_location, source_plan=logical_plan, - partition_by=self.do_analyze(logical_plan.partition_by, expr_to_alias) + partition_by=self.do_analyze( + logical_plan.partition_by, config_context, expr_to_alias + ) if logical_plan.partition_by else None, file_format_name=logical_plan.file_format_name, diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index de7ef7c5099..45326a2cf82 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -11,7 +11,7 @@ import pytest -from snowflake.snowpark._internal.analyzer.snowflake_plan import ConfigContext +from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark.session import Session from snowflake.snowpark.types import 
IntegerType @@ -500,12 +500,6 @@ def finish(self): def test_config_context(session): - config_context = ConfigContext(session) - - # Check if all context configs are present in the session - for name in config_context.configs: - assert hasattr(session, name) - try: original_cte_optimization = session.cte_optimization_enabled original_eliminate_numeric_sql_value_cast_enabled = ( @@ -514,37 +508,30 @@ def test_config_context(session): original_query_compilation_stage_enabled = ( session._query_compilation_stage_enabled ) - with config_context: - # Active context - session.cte_optimization_enabled = not original_cte_optimization - session.eliminate_numeric_sql_value_cast_enabled = ( - not original_eliminate_numeric_sql_value_cast_enabled - ) - session._query_compilation_stage_enabled = ( - not original_query_compilation_stage_enabled - ) + config_context = ConfigContext(session) - assert config_context.cte_optimization_enabled == original_cte_optimization - assert ( - config_context.eliminate_numeric_sql_value_cast_enabled - == original_eliminate_numeric_sql_value_cast_enabled - ) - assert ( - config_context._query_compilation_stage_enabled - == original_query_compilation_stage_enabled - ) + # Check if all context configs are present in the session + for name in config_context.configs: + assert hasattr(session, name) - # Context is deactivated - assert ( - config_context.cte_optimization_enabled == session.cte_optimization_enabled + # change session configs + session.cte_optimization_enabled = not original_cte_optimization + session.eliminate_numeric_sql_value_cast_enabled = ( + not original_eliminate_numeric_sql_value_cast_enabled ) + session._query_compilation_stage_enabled = ( + not original_query_compilation_stage_enabled + ) + + # assert we read original config values + assert config_context.cte_optimization_enabled == original_cte_optimization assert ( config_context.eliminate_numeric_sql_value_cast_enabled - == session.eliminate_numeric_sql_value_cast_enabled + == original_eliminate_numeric_sql_value_cast_enabled ) assert ( config_context._query_compilation_stage_enabled - == session._query_compilation_stage_enabled + == original_query_compilation_stage_enabled ) with pytest.raises( From a85a144a311d9692cee1f90da6356025f4d027ae Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 25 Sep 2024 16:10:35 -0700 Subject: [PATCH 37/62] add config context to repeated subquery elimination resolution stage --- .../snowpark/_internal/compiler/plan_compiler.py | 4 +++- .../snowpark/_internal/compiler/query_generator.py | 6 ++++-- .../compiler/repeated_subquery_elimination.py | 11 +++++++++-- tests/integ/compiler/test_query_generator.py | 12 ++++++++---- 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 71b6ce16b97..b132861a406 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -120,7 +120,9 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: ] # 4. 
do a final pass of code generation - queries = query_generator.generate_queries(logical_plans) + queries = query_generator.generate_queries( + logical_plans, self.config_context + ) # log telemetry data deep_copy_time = deep_copy_end_time - start_time diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index dd6e56bce76..98920c8b1a7 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -68,7 +68,7 @@ def __init__( self.resolved_with_query_block: Dict[str, Query] = {} def generate_queries( - self, logical_plans: List[LogicalPlan] + self, logical_plans: List[LogicalPlan], config_context: ConfigContext ) -> Dict[PlanQueryType, List[Query]]: """ Generate final queries for the given set of logical plans. @@ -82,7 +82,9 @@ def generate_queries( ) # generate queries for each logical plan - snowflake_plans = [self.resolve(logical_plan) for logical_plan in logical_plans] + snowflake_plans = [ + self.resolve(logical_plan, config_context) for logical_plan in logical_plans + ] # merge all results into final set of queries queries = [] post_actions = [] diff --git a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py index 38e3b72a32b..9a7e36ee9fb 100644 --- a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py +++ b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Set +from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.cte_utils import find_duplicate_subtrees from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( LogicalPlan, @@ -49,9 +50,11 @@ def __init__( self, logical_plans: List[LogicalPlan], query_generator: QueryGenerator, + config_context: ConfigContext, ) -> None: self._logical_plans = logical_plans self._query_generator = query_generator + self._config_context = config_context def apply(self) -> List[LogicalPlan]: """ @@ -67,7 +70,9 @@ def apply(self) -> List[LogicalPlan]: # do a pass of resolve of the logical plan to make sure we get a valid # resolved plan to start the process. # If the plan is already a resolved plan, this step will be a no-op. 
- logical_plan = self._query_generator.resolve(logical_plan) + logical_plan = self._query_generator.resolve( + logical_plan, self._config_context + ) # apply the CTE optimization on the resolved plan duplicated_nodes, node_parents_map = find_duplicate_subtrees(logical_plan) @@ -139,7 +144,9 @@ def _update_parents( ) with_block._is_valid_for_replacement = True - resolved_with_block = self._query_generator.resolve(with_block) + resolved_with_block = self._query_generator.resolve( + with_block, self._config_context + ) _update_parents( node, should_replace_child=True, new_child=resolved_with_block ) diff --git a/tests/integ/compiler/test_query_generator.py b/tests/integ/compiler/test_query_generator.py index 5ce4c005ad3..cb203487415 100644 --- a/tests/integ/compiler/test_query_generator.py +++ b/tests/integ/compiler/test_query_generator.py @@ -89,7 +89,7 @@ def check_generated_plan_queries(plan: SnowflakePlan) -> None: assert plan.queries is None assert plan.post_actions is None # regenerate the queries - plan_queries = query_generator.generate_queries([source_plan]) + plan_queries = query_generator.generate_queries([source_plan], config_context=None) queries = [query.sql for query in plan_queries[PlanQueryType.QUERIES]] post_actions = [query.sql for query in plan_queries[PlanQueryType.POST_ACTIONS]] assert queries == original_queries @@ -191,7 +191,7 @@ def test_table_create_from_large_query_breakdown(session, plan_source_generator) comment=None, ) - queries = generator.generate_queries([create_table_source]) + queries = generator.generate_queries([create_table_source], config_context=None) assert len(queries[PlanQueryType.QUERIES]) == 1 assert len(queries[PlanQueryType.POST_ACTIONS]) == 0 @@ -321,7 +321,9 @@ def verify_multiple_create_queries( # reset the whole plan reset_plan_tree(df._plan) # regenerate the queries - plan_queries = query_generator.generate_queries([df._plan.source_plan]) + plan_queries = query_generator.generate_queries( + [df._plan.source_plan], config_context=None + ) queries = [query.sql.lstrip() for query in plan_queries[PlanQueryType.QUERIES]] post_actions = [ query.sql.lstrip() for query in plan_queries[PlanQueryType.POST_ACTIONS] @@ -362,7 +364,9 @@ def test_multiple_plan_query_generation(session): reset_plan_tree(snowflake_plan) reset_plan_tree(df_res._plan) logical_plans = [snowflake_plan.source_plan, df_res._plan.source_plan] - generated_queries = query_generator.generate_queries(logical_plans) + generated_queries = query_generator.generate_queries( + logical_plans, config_context=None + ) result_queries = [ query.sql.lstrip() for query in generated_queries[PlanQueryType.QUERIES] ] From a79ffb41852baa34e104692bfded2cb47b90f268 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 25 Sep 2024 22:17:42 -0700 Subject: [PATCH 38/62] fix tests --- src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py | 2 +- src/snowflake/snowpark/_internal/compiler/plan_compiler.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 34fad4a05ad..5559efc23c8 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -322,7 +322,7 @@ def replace_repeated_subquery_with_cte( # the common subquery elimination will be done through the new plan transformation. 
if ( not config_context.cte_optimization_enabled - or config_context.query_compilation_stage_enabled + or config_context._query_compilation_stage_enabled ): return self diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index b132861a406..178afa5a941 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -93,7 +93,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: cte_start_time = time.time() if self.config_context.cte_optimization_enabled: repeated_subquery_eliminator = RepeatedSubqueryElimination( - logical_plans, query_generator + logical_plans, query_generator, self.config_context ) logical_plans = repeated_subquery_eliminator.apply() From 4420350ec0f44f1b1c399ab16247478591d930af Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 25 Sep 2024 22:19:54 -0700 Subject: [PATCH 39/62] refactor --- .../snowpark/_internal/analyzer/analyzer.py | 2 +- .../_internal/analyzer/snowflake_plan.py | 24 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 4de2700f2d0..7dbff4aa64d 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -914,7 +914,7 @@ def resolve( self.generated_alias_maps = {} if config_context is None: config_context = ConfigContext(self.session) - self.plan_builder.config_context = config_context + self.plan_builder.set_config_context(config_context) result = self.do_resolve(logical_plan, config_context) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 5559efc23c8..36c213b0530 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -542,8 +542,10 @@ def __init__( # on the optimized plan. During the final query generation, no schema query is needed, # this helps reduces un-necessary overhead for the describing call. 
self._skip_schema_query = skip_schema_query - # TODO: SNOW-1541096 remove after old cte implementation is removed - self.config_context: Optional[ConfigContext] = None + self._config_context: Optional[ConfigContext] = None + + def set_config_context(self, config_context: ConfigContext) -> None: + self._config_context = config_context @SnowflakePlan.Decorator.wrap_exception def build( @@ -577,7 +579,7 @@ def build( ), "No schema query is available in child SnowflakePlan" new_schema_query = schema_query or sql_generator(child.schema_query) - config_context = self.config_context or ConfigContext(self.session) + config_context = self._config_context or ConfigContext(self.session) placeholder_query = ( sql_generator(select_child._id) if config_context.cte_optimization_enabled and select_child._id is not None @@ -617,7 +619,7 @@ def build_binary( right_schema_query = schema_value_statement(select_right.attributes) schema_query = sql_generator(left_schema_query, right_schema_query) - config_context = self.config_context or ConfigContext(self.session) + config_context = self._config_context or ConfigContext(self.session) placeholder_query = ( sql_generator(select_left._id, select_right._id) if config_context.cte_optimization_enabled @@ -944,7 +946,7 @@ def save_as_table( ) child = child.replace_repeated_subquery_with_cte( - self.config_context or ConfigContext(self.session) + self._config_context or ConfigContext(self.session) ) def get_create_table_as_select_plan(child: SnowflakePlan, replace, error): @@ -1134,7 +1136,7 @@ def create_or_replace_view( raise SnowparkClientExceptionMessages.PLAN_CREATE_VIEWS_FROM_SELECT_ONLY() child = child.replace_repeated_subquery_with_cte( - self.config_context or ConfigContext(self.session) + self._config_context or ConfigContext(self.session) ) return self.build( lambda x: create_or_replace_view_statement(name, x, is_temp, comment), @@ -1179,7 +1181,7 @@ def create_or_replace_dynamic_table( raise ValueError(f"Unknown create mode: {create_mode}") # pragma: no cover child = child.replace_repeated_subquery_with_cte( - self.config_context or ConfigContext(self.session) + self._config_context or ConfigContext(self.session) ) return self.build( lambda x: create_or_replace_dynamic_table_statement( @@ -1484,7 +1486,7 @@ def copy_into_location( **copy_options: Optional[Any], ) -> SnowflakePlan: query = query.replace_repeated_subquery_with_cte( - self.config_context or ConfigContext(self.session) + self._config_context or ConfigContext(self.session) ) return self.build( lambda x: copy_into_location( @@ -1513,7 +1515,7 @@ def update( ) -> SnowflakePlan: if source_data: source_data = source_data.replace_repeated_subquery_with_cte( - self.config_context or ConfigContext(self.session) + self._config_context or ConfigContext(self.session) ) return self.build( lambda x: update_statement( @@ -1546,7 +1548,7 @@ def delete( ) -> SnowflakePlan: if source_data: source_data = source_data.replace_repeated_subquery_with_cte( - self.config_context or ConfigContext(self.session) + self._config_context or ConfigContext(self.session) ) return self.build( lambda x: delete_statement( @@ -1577,7 +1579,7 @@ def merge( source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: source_data = source_data.replace_repeated_subquery_with_cte( - self.config_context or ConfigContext(self.session) + self._config_context or ConfigContext(self.session) ) return self.build( lambda x: merge_statement(table_name, x, join_expr, clauses), From 5f1eaa6da6c97eff4f4186becbdd70d08c94d53f Mon Sep 17 00:00:00 2001 
From: Afroz Alam Date: Fri, 27 Sep 2024 13:51:48 -0700 Subject: [PATCH 40/62] remove do_analyze --- .../snowpark/_internal/analyzer/analyzer.py | 278 ++++++++------- src/snowflake/snowpark/mock/_analyzer.py | 319 ++++++++++++------ 2 files changed, 345 insertions(+), 252 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 7dbff4aa64d..6f85552a5f4 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -173,29 +173,21 @@ def analyze( expr: Union[Expression, NamedExpression], df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], parse_local_name=False, + config_context: Optional[ConfigContext] = None, ) -> str: # Set the config context for analysis step - config_context = ConfigContext(self.session) - return self.do_analyze( - expr, df_aliased_col_name_to_real_col_name, config_context, parse_local_name - ) + if config_context is None: + config_context = ConfigContext(self.session) - def do_analyze( - self, - expr: Union[Expression, NamedExpression], - df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], - config_context: ConfigContext, - parse_local_name=False, - ) -> str: if isinstance(expr, GroupingSetsExpression): return grouping_set_expression( [ [ - self.do_analyze( + self.analyze( a, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) for a in arg ] @@ -205,39 +197,39 @@ def do_analyze( if isinstance(expr, Like): return like_expression( - self.do_analyze( + self.analyze( expr.expr, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), - self.do_analyze( + self.analyze( expr.pattern, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), ) if isinstance(expr, RegExp): return regexp_expression( - self.do_analyze( + self.analyze( expr.expr, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), - self.do_analyze( + self.analyze( expr.pattern, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), - self.do_analyze( + self.analyze( expr.parameters, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) if expr.parameters is not None else None, @@ -248,11 +240,11 @@ def do_analyze( expr.collation_spec.upper() if parse_local_name else expr.collation_spec ) return collate_expression( - self.do_analyze( + self.analyze( expr.expr, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), collation_spec, ) @@ -262,11 +254,11 @@ def do_analyze( if parse_local_name and isinstance(field, str): field = field.upper() return subfield_expression( - self.do_analyze( + self.analyze( expr.expr, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), field, ) @@ -275,26 +267,26 @@ def do_analyze( return case_when_expression( [ ( - self.do_analyze( + self.analyze( condition, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), - self.do_analyze( + self.analyze( value, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), ) for condition, value in expr.branches ], - 
self.do_analyze( + self.analyze( expr.else_value, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) if expr.else_value else "NULL", @@ -307,15 +299,15 @@ def do_analyze( resolved_expr = self.to_sql_try_avoid_cast( expression, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) else: - resolved_expr = self.do_analyze( + resolved_expr = self.analyze( expression, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) block_expressions.append(resolved_expr) @@ -328,24 +320,24 @@ def do_analyze( in_value = self.to_sql_try_avoid_cast( expression, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) else: - in_value = self.do_analyze( + in_value = self.analyze( expression, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) in_values.append(in_value) return in_expression( - self.do_analyze( + self.analyze( expr.columns, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), in_values, ) @@ -357,44 +349,44 @@ def do_analyze( if isinstance(expr, WindowExpression): return window_expression( - self.do_analyze( + self.analyze( expr.window_function, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), - self.do_analyze( + self.analyze( expr.window_spec, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), ) if isinstance(expr, WindowSpecDefinition): return window_spec_expression( [ - self.do_analyze( + self.analyze( x, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) for x in expr.partition_spec ], [ - self.do_analyze( + self.analyze( x, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) for x in expr.order_spec ], - self.do_analyze( + self.analyze( expr.frame_spec, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), ) if isinstance(expr, SpecifiedWindowFrame): @@ -474,7 +466,7 @@ def do_analyze( # This case is hit by df.col("*") return ",".join( [ - self.do_analyze( + self.analyze( e, df_aliased_col_name_to_real_col_name, config_context ) for e in expr.expressions @@ -490,11 +482,11 @@ def do_analyze( return function_expression( func_name, [ - self.do_analyze( + self.analyze( x, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) for x in expr.children ], @@ -514,22 +506,22 @@ def do_analyze( return table_function_partition_spec( expr.over, [ - self.do_analyze( + self.analyze( x, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) for x in expr.partition_spec ] if expr.partition_spec else [], [ - self.do_analyze( + self.analyze( x, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) for x in expr.order_spec ] @@ -547,11 +539,11 @@ def do_analyze( if isinstance(expr, SortOrder): return order_expression( - self.do_analyze( + self.analyze( expr.child, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), expr.direction.sql, expr.null_ordering.sql, @@ -563,14 +555,14 
@@ def do_analyze( if isinstance(expr, WithinGroup): return within_group_expression( - self.do_analyze( + self.analyze( expr.expr, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), [ - self.do_analyze( + self.analyze( e, df_aliased_col_name_to_real_col_name, config_context ) for e in expr.order_by_cols @@ -587,19 +579,19 @@ def do_analyze( if isinstance(expr, InsertMergeExpression): return insert_merge_statement( - self.do_analyze( + self.analyze( expr.condition, df_aliased_col_name_to_real_col_name, config_context ) if expr.condition else None, [ - self.do_analyze( + self.analyze( k, df_aliased_col_name_to_real_col_name, config_context ) for k in expr.keys ], [ - self.do_analyze( + self.analyze( v, df_aliased_col_name_to_real_col_name, config_context ) for v in expr.values @@ -608,15 +600,15 @@ def do_analyze( if isinstance(expr, UpdateMergeExpression): return update_merge_statement( - self.do_analyze( + self.analyze( expr.condition, df_aliased_col_name_to_real_col_name, config_context ) if expr.condition else None, { - self.do_analyze( + self.analyze( k, df_aliased_col_name_to_real_col_name, config_context - ): self.do_analyze( + ): self.analyze( v, df_aliased_col_name_to_real_col_name, config_context ) for k, v in expr.assignments.items() @@ -625,7 +617,7 @@ def do_analyze( if isinstance(expr, DeleteMergeExpression): return delete_merge_statement( - self.do_analyze( + self.analyze( expr.condition, df_aliased_col_name_to_real_col_name, config_context ) if expr.condition @@ -634,11 +626,11 @@ def do_analyze( if isinstance(expr, ListAgg): return list_agg( - self.do_analyze( + self.analyze( expr.col, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), str_to_sql(expr.delimiter), expr.is_distinct, @@ -647,11 +639,11 @@ def do_analyze( if isinstance(expr, ColumnSum): return column_sum( [ - self.do_analyze( + self.analyze( col, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) for col in expr.exprs ] @@ -660,18 +652,18 @@ def do_analyze( if isinstance(expr, RankRelatedFunctionExpression): return rank_related_function_expression( expr.sql, - self.do_analyze( + self.analyze( expr.expr, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), expr.offset, - self.do_analyze( + self.analyze( expr.default, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) if expr.default else None, @@ -691,11 +683,11 @@ def table_function_expression_extractor( ) -> str: if isinstance(expr, FlattenFunction): return flatten_expression( - self.do_analyze( + self.analyze( expr.input, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), expr.path, expr.outer, @@ -706,11 +698,11 @@ def table_function_expression_extractor( sql = function_expression( expr.func_name, [ - self.do_analyze( + self.analyze( x, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) for x in expr.args ], @@ -735,10 +727,10 @@ def table_function_expression_extractor( "NamedArgumentsTableFunction, GeneratorTableFunction, or FlattenFunction." 
) partition_spec_sql = ( - self.do_analyze( + self.analyze( expr.partition_spec, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ) if expr.partition_spec else "" @@ -766,42 +758,42 @@ def unary_expression_extractor( if v == expr.child.name: df_alias_dict[k] = quoted_name return alias_expression( - self.do_analyze( + self.analyze( expr.child, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), quoted_name, ) if isinstance(expr, UnresolvedAlias): - expr_str = self.do_analyze( + expr_str = self.analyze( expr.child, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) if parse_local_name: expr_str = expr_str.upper() return expr_str elif isinstance(expr, Cast): return cast_expression( - self.do_analyze( + self.analyze( expr.child, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), expr.to, expr.try_, ) else: return unary_expression( - self.do_analyze( + self.analyze( expr.child, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ), expr.sql_operator, expr.operator_first, @@ -828,17 +820,17 @@ def binary_operator_extractor( parse_local_name, ) else: - left_sql_expr = self.do_analyze( + left_sql_expr = self.analyze( expr.left, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) - right_sql_expr = self.do_analyze( + right_sql_expr = self.analyze( expr.right, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) if isinstance(expr, BinaryArithmeticExpression): return binary_arithmetic_expression( @@ -859,14 +851,14 @@ def binary_operator_extractor( def grouping_extractor( self, expr: GroupingSet, df_aliased_col_name_to_real_col_name, config_context ) -> str: - return self.do_analyze( + return self.analyze( FunctionExpression( expr.pretty_name.upper(), [c.child if isinstance(c, Alias) else c for c in expr.children], False, ), df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ) def window_frame_boundary(self, offset: str) -> str: @@ -900,11 +892,11 @@ def to_sql_try_avoid_cast( ): return str(expr.value).upper() else: - return self.do_analyze( + return self.analyze( expr, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, + config_context=config_context, ) def resolve( @@ -965,7 +957,7 @@ def do_resolve( logical_plan, resolved_children, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ) res.df_aliased_col_name_to_real_col_name.update( df_aliased_col_name_to_real_col_name @@ -984,10 +976,10 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, TableFunctionJoin): return self.plan_builder.join_table_function( - self.do_analyze( + self.analyze( logical_plan.table_function, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ), resolved_children[logical_plan.children[0]], logical_plan, @@ -998,20 +990,20 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, TableFunctionRelation): return self.plan_builder.from_table_function( - self.do_analyze( + self.analyze( logical_plan.table_function, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ), logical_plan, ) if isinstance(logical_plan, Lateral): return 
self.plan_builder.lateral( - self.do_analyze( + self.analyze( logical_plan.table_function, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ), resolved_children[logical_plan.children[0]], logical_plan, @@ -1026,7 +1018,7 @@ def do_resolve_with_resolved_children( for expr in logical_plan.grouping_expressions ], [ - self.do_analyze( + self.analyze( expr, df_aliased_col_name_to_real_col_name, config_context ) for expr in logical_plan.aggregate_expressions @@ -1039,7 +1031,7 @@ def do_resolve_with_resolved_children( return self.plan_builder.project( list( map( - lambda x: self.do_analyze( + lambda x: self.analyze( x, df_aliased_col_name_to_real_col_name, config_context ), logical_plan.project_list, @@ -1051,10 +1043,10 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Filter): return self.plan_builder.filter( - self.do_analyze( + self.analyze( logical_plan.condition, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ), resolved_children[logical_plan.child], logical_plan, @@ -1071,19 +1063,19 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Join): join_condition = ( - self.do_analyze( + self.analyze( logical_plan.join_condition, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ) if logical_plan.join_condition else "" ) match_condition = ( - self.do_analyze( + self.analyze( logical_plan.match_condition, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ) if logical_plan.match_condition else "" @@ -1101,7 +1093,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Sort): return self.plan_builder.sort( [ - self.do_analyze( + self.analyze( x, df_aliased_col_name_to_real_col_name, config_context ) for x in logical_plan.order @@ -1167,7 +1159,7 @@ def do_resolve_with_resolved_children( mode=logical_plan.mode, table_type=logical_plan.table_type, clustering_keys=[ - self.do_analyze( + self.analyze( x, df_aliased_col_name_to_real_col_name, config_context ) for x in logical_plan.clustering_exprs @@ -1194,12 +1186,12 @@ def do_resolve_with_resolved_children( self.to_sql_try_avoid_cast( logical_plan.limit_expr, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ), self.to_sql_try_avoid_cast( logical_plan.offset_expr, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ), resolved_children[logical_plan.child], on_top_of_order_by, @@ -1227,7 +1219,7 @@ def do_resolve_with_resolved_children( ] child = self.plan_builder.project( [ - self.do_analyze( + self.analyze( col, df_aliased_col_name_to_real_col_name, config_context ) for col in project_exprs @@ -1245,36 +1237,36 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan.pivot_values, List): pivot_values = [ - self.do_analyze( + self.analyze( pv, df_aliased_col_name_to_real_col_name, config_context ) for pv in logical_plan.pivot_values ] elif isinstance(logical_plan.pivot_values, ScalarSubquery): - pivot_values = self.do_analyze( + pivot_values = self.analyze( logical_plan.pivot_values, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ) else: pivot_values = None pivot_plan = self.plan_builder.pivot( - self.do_analyze( + self.analyze( logical_plan.pivot_column, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ), pivot_values, - self.do_analyze( + self.analyze( 
logical_plan.aggregates[0], df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ), - self.do_analyze( + self.analyze( logical_plan.default_on_null, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ) if logical_plan.default_on_null else None, @@ -1299,7 +1291,7 @@ def do_resolve_with_resolved_children( logical_plan.value_column, logical_plan.name_column, [ - self.do_analyze( + self.analyze( c, df_aliased_col_name_to_real_col_name, config_context ) for c in logical_plan.column_list @@ -1343,7 +1335,7 @@ def do_resolve_with_resolved_children( refresh_mode=logical_plan.refresh_mode, initialize=logical_plan.initialize, clustering_keys=[ - self.do_analyze( + self.analyze( x, df_aliased_col_name_to_real_col_name, config_context ) for x in logical_plan.clustering_exprs @@ -1377,7 +1369,7 @@ def do_resolve_with_resolved_children( validation_mode=logical_plan.validation_mode, column_names=logical_plan.column_names, transformations=[ - self.do_analyze( + self.analyze( x, df_aliased_col_name_to_real_col_name, config_context ) for x in logical_plan.transformations @@ -1394,10 +1386,10 @@ def do_resolve_with_resolved_children( query=resolved_children[logical_plan.child], stage_location=logical_plan.stage_location, source_plan=logical_plan, - partition_by=self.do_analyze( + partition_by=self.analyze( logical_plan.partition_by, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ) if logical_plan.partition_by else None, @@ -1412,17 +1404,17 @@ def do_resolve_with_resolved_children( return self.plan_builder.update( logical_plan.table_name, { - self.do_analyze( + self.analyze( k, df_aliased_col_name_to_real_col_name, config_context - ): self.do_analyze( + ): self.analyze( v, df_aliased_col_name_to_real_col_name, config_context ) for k, v in logical_plan.assignments.items() }, - self.do_analyze( + self.analyze( logical_plan.condition, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ) if logical_plan.condition else None, @@ -1435,10 +1427,10 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, TableDelete): return self.plan_builder.delete( logical_plan.table_name, - self.do_analyze( + self.analyze( logical_plan.condition, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ) if logical_plan.condition else None, @@ -1455,13 +1447,13 @@ def do_resolve_with_resolved_children( resolved_children[logical_plan.source] if logical_plan.source else logical_plan.source, - self.do_analyze( + self.analyze( logical_plan.join_expr, df_aliased_col_name_to_real_col_name, - config_context, + config_context=config_context, ), [ - self.do_analyze( + self.analyze( c, df_aliased_col_name_to_real_col_name, config_context ) for c in logical_plan.clauses diff --git a/src/snowflake/snowpark/mock/_analyzer.py b/src/snowflake/snowpark/mock/_analyzer.py index f26fa786126..dce1b13dc18 100644 --- a/src/snowflake/snowpark/mock/_analyzer.py +++ b/src/snowflake/snowpark/mock/_analyzer.py @@ -161,23 +161,7 @@ def analyze( expr_to_alias: Optional[Dict[str, str]] = None, parse_local_name=False, keep_alias=True, - ) -> Union[str, List[str]]: - config_context = ConfigContext(self.session) - return self.do_analyze( - expr, - config_context, - expr_to_alias, - parse_local_name, - keep_alias, - ) - - def do_analyze( - self, - expr: Union[Expression, NamedExpression], - config_context: ConfigContext, - expr_to_alias: Optional[Dict[str, str]] = 
None, - parse_local_name=False, - keep_alias=True, + config_context: Optional[ConfigContext] = None, ) -> Union[str, List[str]]: """ Args: @@ -190,6 +174,8 @@ def do_analyze( """ if expr_to_alias is None: expr_to_alias = {} + if config_context is None: + config_context = ConfigContext(self.session) if isinstance(expr, GroupingSetsExpression): self._conn.log_not_supported_error( external_feature_name="DataFrame.group_by_grouping_sets", @@ -198,24 +184,39 @@ def do_analyze( if isinstance(expr, Like): return like_expression( - self.do_analyze( - expr.expr, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.expr, + expr_to_alias, + parse_local_name, + config_context=config_context, ), - self.do_analyze( - expr.pattern, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.pattern, + expr_to_alias, + parse_local_name, + config_context=config_context, ), ) if isinstance(expr, RegExp): return regexp_expression( - self.do_analyze( - expr.expr, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.expr, + expr_to_alias, + parse_local_name, + config_context=config_context, ), - self.do_analyze( - expr.pattern, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.pattern, + expr_to_alias, + parse_local_name, + config_context=config_context, ), - self.do_analyze( - expr.parameters, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.parameters, + expr_to_alias, + parse_local_name, + config_context=config_context, ) if expr.parameters is not None else None, @@ -226,8 +227,11 @@ def do_analyze( expr.collation_spec.upper() if parse_local_name else expr.collation_spec ) return collate_expression( - self.do_analyze( - expr.expr, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.expr, + expr_to_alias, + parse_local_name, + config_context=config_context, ), collation_spec, ) @@ -237,8 +241,11 @@ def do_analyze( if parse_local_name and isinstance(field, str): field = field.upper() return subfield_expression( - self.do_analyze( - expr.expr, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.expr, + expr_to_alias, + parse_local_name, + config_context=config_context, ), field, ) @@ -247,17 +254,26 @@ def do_analyze( return case_when_expression( [ ( - self.do_analyze( - condition, config_context, expr_to_alias, parse_local_name + self.analyze( + condition, + expr_to_alias, + parse_local_name, + config_context=config_context, ), - self.do_analyze( - value, config_context, expr_to_alias, parse_local_name + self.analyze( + value, + expr_to_alias, + parse_local_name, + config_context=config_context, ), ) for condition, value in expr.branches ], - self.do_analyze( - expr.else_value, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.else_value, + expr_to_alias, + parse_local_name, + config_context=config_context, ) if expr.else_value else "NULL", @@ -270,15 +286,15 @@ def do_analyze( resolved_expr = self.to_sql_try_avoid_cast( expression, expr_to_alias, - config_context, parse_local_name, + config_context=config_context, ) else: - resolved_expr = self.do_analyze( + resolved_expr = self.analyze( expression, - config_context, expr_to_alias, parse_local_name, + config_context=config_context, ) block_expressions.append(resolved_expr) @@ -291,21 +307,24 @@ def do_analyze( in_value = self.to_sql_try_avoid_cast( expression, expr_to_alias, - config_context, parse_local_name, + config_context=config_context, ) else: - in_value = self.do_analyze( + in_value = self.analyze( 
expression, - config_context, expr_to_alias, parse_local_name, + config_context=config_context, ) in_values.append(in_value) return in_expression( - self.do_analyze( - expr.columns, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.columns, + expr_to_alias, + parse_local_name, + config_context=config_context, ), in_values, ) @@ -318,14 +337,14 @@ def do_analyze( if isinstance(expr, WindowExpression): return window_expression( - self.do_analyze( + self.analyze( expr.window_function, - config_context, + config_context=config_context, parse_local_name=parse_local_name, ), - self.do_analyze( + self.analyze( expr.window_spec, - config_context, + config_context=config_context, parse_local_name=parse_local_name, ), ) @@ -333,20 +352,24 @@ def do_analyze( if isinstance(expr, WindowSpecDefinition): return window_spec_expression( [ - self.do_analyze( - x, config_context, parse_local_name=parse_local_name + self.analyze( + x, + config_context=config_context, + parse_local_name=parse_local_name, ) for x in expr.partition_spec ], [ - self.do_analyze( - x, config_context, parse_local_name=parse_local_name + self.analyze( + x, + config_context=config_context, + parse_local_name=parse_local_name, ) for x in expr.order_spec ], - self.do_analyze( + self.analyze( expr.frame_spec, - config_context, + config_context=config_context, parse_local_name=parse_local_name, ), ) @@ -406,7 +429,7 @@ def do_analyze( return "*" else: return [ - self.do_analyze(e, config_context, expr_to_alias) + self.analyze(e, expr_to_alias, config_context=config_context) for e in expr.expressions ] @@ -419,7 +442,12 @@ def do_analyze( return function_expression( func_name, [ - self.do_analyze(x, config_context, expr_to_alias, parse_local_name) + self.analyze( + x, + expr_to_alias, + parse_local_name, + config_context=config_context, + ) for x in expr.children ], False, @@ -434,13 +462,23 @@ def do_analyze( return table_function_partition_spec( expr.over, [ - self.do_analyze(x, config_context, expr_to_alias, parse_local_name) + self.analyze( + x, + expr_to_alias, + parse_local_name, + config_context=config_context, + ) for x in expr.partition_spec ] if expr.partition_spec else [], [ - self.do_analyze(x, config_context, expr_to_alias, parse_local_name) + self.analyze( + x, + expr_to_alias, + parse_local_name, + config_context=config_context, + ) for x in expr.order_spec ] if expr.order_spec @@ -451,15 +489,18 @@ def do_analyze( return self.unary_expression_extractor( expr, expr_to_alias, - config_context, parse_local_name, keep_alias=keep_alias, + config_context=config_context, ) if isinstance(expr, SortOrder): return order_expression( - self.do_analyze( - expr.child, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.child, + expr_to_alias, + parse_local_name, + config_context=config_context, ), expr.direction.sql, expr.null_ordering.sql, @@ -471,11 +512,14 @@ def do_analyze( if isinstance(expr, WithinGroup): return within_group_expression( - self.do_analyze( - expr.expr, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.expr, + expr_to_alias, + parse_local_name, + config_context=config_context, ), [ - self.do_analyze(e, config_context, expr_to_alias) + self.analyze(e, expr_to_alias, config_context=config_context) for e in expr.order_by_cols ], ) @@ -484,46 +528,58 @@ def do_analyze( return self.binary_operator_extractor( expr, expr_to_alias, - config_context, parse_local_name, + config_context=config_context, ) if isinstance(expr, InsertMergeExpression): return 
insert_merge_statement( - self.do_analyze(expr.condition, config_context, expr_to_alias) + self.analyze( + expr.condition, expr_to_alias, config_context=config_context + ) if expr.condition else None, - [self.do_analyze(k, config_context, expr_to_alias) for k in expr.keys], [ - self.do_analyze(v, config_context, expr_to_alias) + self.analyze(k, expr_to_alias, config_context=config_context) + for k in expr.keys + ], + [ + self.analyze(v, expr_to_alias, config_context=config_context) for v in expr.values ], ) if isinstance(expr, UpdateMergeExpression): return update_merge_statement( - self.do_analyze(expr.condition, config_context, expr_to_alias) + self.analyze( + expr.condition, expr_to_alias, config_context=config_context + ) if expr.condition else None, { - self.do_analyze(k, config_context, expr_to_alias): self.do_analyze( - v, config_context, expr_to_alias - ) + self.analyze( + k, expr_to_alias, config_context=config_context + ): self.analyze(v, expr_to_alias, config_context=config_context) for k, v in expr.assignments.items() }, ) if isinstance(expr, DeleteMergeExpression): return delete_merge_statement( - self.do_analyze(expr.condition, config_context, expr_to_alias) + self.analyze( + expr.condition, expr_to_alias, config_context=config_context + ) if expr.condition else None ) if isinstance(expr, ListAgg): return list_agg( - self.do_analyze( - expr.col, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.col, + expr_to_alias, + parse_local_name, + config_context=config_context, ), str_to_sql(expr.delimiter), expr.is_distinct, @@ -532,8 +588,11 @@ def do_analyze( if isinstance(expr, ColumnSum): return column_sum( [ - self.do_analyze( - col, config_context, expr_to_alias, parse_local_name + self.analyze( + col, + expr_to_alias, + parse_local_name, + config_context=config_context, ) for col in expr.exprs ] @@ -542,12 +601,18 @@ def do_analyze( if isinstance(expr, RankRelatedFunctionExpression): return rank_related_function_expression( expr.sql, - self.do_analyze( - expr.expr, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.expr, + expr_to_alias, + parse_local_name, + config_context=config_context, ), expr.offset, - self.do_analyze( - expr.default, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.default, + expr_to_alias, + parse_local_name, + config_context=config_context, ) if expr.default else None, @@ -567,8 +632,11 @@ def table_function_expression_extractor( ) -> str: if isinstance(expr, FlattenFunction): return flatten_expression( - self.do_analyze( - expr.input, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.input, + expr_to_alias, + parse_local_name, + config_context=config_context, ), expr.path, expr.outer, @@ -579,7 +647,12 @@ def table_function_expression_extractor( sql = function_expression( expr.func_name, [ - self.do_analyze(x, config_context, expr_to_alias, parse_local_name) + self.analyze( + x, + expr_to_alias, + parse_local_name, + config_context=config_context, + ) for x in expr.args ], False, @@ -588,8 +661,11 @@ def table_function_expression_extractor( sql = named_arguments_function( expr.func_name, { - key: self.do_analyze( - value, config_context, expr_to_alias, parse_local_name + key: self.analyze( + value, + expr_to_alias, + parse_local_name, + config_context=config_context, ) for key, value in expr.args.items() }, @@ -600,7 +676,9 @@ def table_function_expression_extractor( "NamedArgumentsTableFunction, GeneratorTableFunction, or FlattenFunction." 
) partition_spec_sql = ( - self.do_analyze(expr.partition_spec, config_context, expr_to_alias) + self.analyze( + expr.partition_spec, expr_to_alias, config_context=config_context + ) if expr.partition_spec else "" ) @@ -622,8 +700,11 @@ def unary_expression_extractor( if v == expr.child.name: expr_to_alias[k] = quoted_name alias_exp = alias_expression( - self.do_analyze( - expr.child, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.child, + expr_to_alias, + parse_local_name, + config_context=config_context, ), quoted_name, ) @@ -632,24 +713,33 @@ def unary_expression_extractor( expr_str = expr_str.upper() if parse_local_name else expr_str return expr_str if isinstance(expr, UnresolvedAlias): - expr_str = self.do_analyze( - expr.child, config_context, expr_to_alias, parse_local_name + expr_str = self.analyze( + expr.child, + expr_to_alias, + parse_local_name, + config_context=config_context, ) if parse_local_name: expr_str = expr_str.upper() return quote_name(expr_str.strip()) elif isinstance(expr, Cast): return cast_expression( - self.do_analyze( - expr.child, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.child, + expr_to_alias, + parse_local_name, + config_context=config_context, ), expr.to, expr.try_, ) else: return unary_expression( - self.do_analyze( - expr.child, config_context, expr_to_alias, parse_local_name + self.analyze( + expr.child, + expr_to_alias, + parse_local_name, + config_context=config_context, ), expr.sql_operator, expr.operator_first, @@ -664,20 +754,29 @@ def binary_operator_extractor( ) -> str: if config_context.eliminate_numeric_sql_value_cast_enabled: left_sql_expr = self.to_sql_try_avoid_cast( - expr.left, expr_to_alias, config_context, parse_local_name + expr.left, + expr_to_alias, + parse_local_name, + config_context=config_context, ) right_sql_expr = self.to_sql_try_avoid_cast( expr.right, expr_to_alias, - config_context, parse_local_name, + config_context=config_context, ) else: - left_sql_expr = self.do_analyze( - expr.left, config_context, expr_to_alias, parse_local_name + left_sql_expr = self.analyze( + expr.left, + expr_to_alias, + parse_local_name, + config_context=config_context, ) - right_sql_expr = self.do_analyze( - expr.right, config_context, expr_to_alias, parse_local_name + right_sql_expr = self.analyze( + expr.right, + expr_to_alias, + parse_local_name, + config_context=config_context, ) operator = expr.sql_operator.lower() @@ -703,14 +802,14 @@ def grouping_extractor( expr_to_alias: Dict[str, str], config_context: ConfigContext, ) -> str: - return self.do_analyze( + return self.analyze( FunctionExpression( expr.pretty_name.upper(), [c.child if isinstance(c, Alias) else c for c in expr.children], False, ), - config_context, expr_to_alias, + config_context=config_context, ) def window_frame_boundary(self, offset: str) -> str: @@ -732,8 +831,8 @@ def to_sql_try_avoid_cast( if isinstance(expr, Literal) and isinstance(expr.datatype, _NumericType): return numeric_to_sql_without_cast(expr.value, expr.datatype) else: - return self.do_analyze( - expr, config_context, expr_to_alias, parse_local_name + return self.analyze( + expr, expr_to_alias, parse_local_name, config_context=config_context ) def resolve( @@ -835,7 +934,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Sort): return self.plan_builder.sort( - list(map(self.do_analyze, logical_plan.order)), + list(map(self.analyze, logical_plan.order)), resolved_children[logical_plan.child], logical_plan, ) @@ -907,8 +1006,10 @@ def 
do_resolve_with_resolved_children( query=resolved_children[logical_plan.child], stage_location=logical_plan.stage_location, source_plan=logical_plan, - partition_by=self.do_analyze( - logical_plan.partition_by, config_context, expr_to_alias + partition_by=self.analyze( + logical_plan.partition_by, + expr_to_alias, + config_context=config_context, ) if logical_plan.partition_by else None, From 9d62017872fe9dff1fa63301ad4ad4b410ea27bd Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 27 Sep 2024 14:05:41 -0700 Subject: [PATCH 41/62] fix --- .../snowpark/_internal/analyzer/analyzer.py | 88 ++++++++++++++----- src/snowflake/snowpark/mock/_analyzer.py | 5 +- 2 files changed, 70 insertions(+), 23 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 6f85552a5f4..fbf34014644 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -467,7 +467,9 @@ def analyze( return ",".join( [ self.analyze( - e, df_aliased_col_name_to_real_col_name, config_context + e, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for e in expr.expressions ] @@ -563,7 +565,9 @@ def analyze( ), [ self.analyze( - e, df_aliased_col_name_to_real_col_name, config_context + e, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for e in expr.order_by_cols ], @@ -580,19 +584,25 @@ def analyze( if isinstance(expr, InsertMergeExpression): return insert_merge_statement( self.analyze( - expr.condition, df_aliased_col_name_to_real_col_name, config_context + expr.condition, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) if expr.condition else None, [ self.analyze( - k, df_aliased_col_name_to_real_col_name, config_context + k, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for k in expr.keys ], [ self.analyze( - v, df_aliased_col_name_to_real_col_name, config_context + v, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for v in expr.values ], @@ -601,15 +611,21 @@ def analyze( if isinstance(expr, UpdateMergeExpression): return update_merge_statement( self.analyze( - expr.condition, df_aliased_col_name_to_real_col_name, config_context + expr.condition, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) if expr.condition else None, { self.analyze( - k, df_aliased_col_name_to_real_col_name, config_context + k, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ): self.analyze( - v, df_aliased_col_name_to_real_col_name, config_context + v, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for k, v in expr.assignments.items() }, @@ -618,7 +634,9 @@ def analyze( if isinstance(expr, DeleteMergeExpression): return delete_merge_statement( self.analyze( - expr.condition, df_aliased_col_name_to_real_col_name, config_context + expr.condition, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) if expr.condition else None @@ -1013,13 +1031,17 @@ def do_resolve_with_resolved_children( return self.plan_builder.aggregate( [ self.to_sql_try_avoid_cast( - expr, df_aliased_col_name_to_real_col_name, config_context + expr, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for expr in logical_plan.grouping_expressions ], [ self.analyze( - expr, df_aliased_col_name_to_real_col_name, config_context + expr, + df_aliased_col_name_to_real_col_name, + 
config_context=config_context, ) for expr in logical_plan.aggregate_expressions ], @@ -1032,7 +1054,9 @@ def do_resolve_with_resolved_children( list( map( lambda x: self.analyze( - x, df_aliased_col_name_to_real_col_name, config_context + x, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ), logical_plan.project_list, ) @@ -1094,7 +1118,9 @@ def do_resolve_with_resolved_children( return self.plan_builder.sort( [ self.analyze( - x, df_aliased_col_name_to_real_col_name, config_context + x, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for x in logical_plan.order ], @@ -1160,7 +1186,9 @@ def do_resolve_with_resolved_children( table_type=logical_plan.table_type, clustering_keys=[ self.analyze( - x, df_aliased_col_name_to_real_col_name, config_context + x, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for x in logical_plan.clustering_exprs ], @@ -1220,7 +1248,9 @@ def do_resolve_with_resolved_children( child = self.plan_builder.project( [ self.analyze( - col, df_aliased_col_name_to_real_col_name, config_context + col, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for col in project_exprs ], @@ -1238,7 +1268,9 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan.pivot_values, List): pivot_values = [ self.analyze( - pv, df_aliased_col_name_to_real_col_name, config_context + pv, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for pv in logical_plan.pivot_values ] @@ -1292,7 +1324,9 @@ def do_resolve_with_resolved_children( logical_plan.name_column, [ self.analyze( - c, df_aliased_col_name_to_real_col_name, config_context + c, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for c in logical_plan.column_list ], @@ -1336,7 +1370,9 @@ def do_resolve_with_resolved_children( initialize=logical_plan.initialize, clustering_keys=[ self.analyze( - x, df_aliased_col_name_to_real_col_name, config_context + x, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for x in logical_plan.clustering_exprs ], @@ -1370,7 +1406,9 @@ def do_resolve_with_resolved_children( column_names=logical_plan.column_names, transformations=[ self.analyze( - x, df_aliased_col_name_to_real_col_name, config_context + x, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for x in logical_plan.transformations ] @@ -1405,9 +1443,13 @@ def do_resolve_with_resolved_children( logical_plan.table_name, { self.analyze( - k, df_aliased_col_name_to_real_col_name, config_context + k, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ): self.analyze( - v, df_aliased_col_name_to_real_col_name, config_context + v, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for k, v in logical_plan.assignments.items() }, @@ -1454,7 +1496,9 @@ def do_resolve_with_resolved_children( ), [ self.analyze( - c, df_aliased_col_name_to_real_col_name, config_context + c, + df_aliased_col_name_to_real_col_name, + config_context=config_context, ) for c in logical_plan.clauses ], diff --git a/src/snowflake/snowpark/mock/_analyzer.py b/src/snowflake/snowpark/mock/_analyzer.py index dce1b13dc18..9234946b4a8 100644 --- a/src/snowflake/snowpark/mock/_analyzer.py +++ b/src/snowflake/snowpark/mock/_analyzer.py @@ -876,7 +876,10 @@ def do_resolve( expr_to_alias.update({p: q for p, q in v.items() if counts[p] < 2}) return self.do_resolve_with_resolved_children( - logical_plan, resolved_children, expr_to_alias, 
config_context + logical_plan, + resolved_children, + expr_to_alias, + config_context=config_context, ) def do_resolve_with_resolved_children( From b58aa8b8b1666f044495d5494f4f8bbd5962fb8a Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 27 Sep 2024 14:16:02 -0700 Subject: [PATCH 42/62] fix --- .../snowpark/_internal/analyzer/analyzer.py | 10 +++++----- src/snowflake/snowpark/mock/_analyzer.py | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index fbf34014644..978403127a4 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -299,8 +299,8 @@ def analyze( resolved_expr = self.to_sql_try_avoid_cast( expression, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, - config_context=config_context, ) else: resolved_expr = self.analyze( @@ -320,8 +320,8 @@ def analyze( in_value = self.to_sql_try_avoid_cast( expression, df_aliased_col_name_to_real_col_name, + config_context, parse_local_name, - config_context=config_context, ) else: in_value = self.analyze( @@ -1033,7 +1033,7 @@ def do_resolve_with_resolved_children( self.to_sql_try_avoid_cast( expr, df_aliased_col_name_to_real_col_name, - config_context=config_context, + config_context, ) for expr in logical_plan.grouping_expressions ], @@ -1214,12 +1214,12 @@ def do_resolve_with_resolved_children( self.to_sql_try_avoid_cast( logical_plan.limit_expr, df_aliased_col_name_to_real_col_name, - config_context=config_context, + config_context, ), self.to_sql_try_avoid_cast( logical_plan.offset_expr, df_aliased_col_name_to_real_col_name, - config_context=config_context, + config_context, ), resolved_children[logical_plan.child], on_top_of_order_by, diff --git a/src/snowflake/snowpark/mock/_analyzer.py b/src/snowflake/snowpark/mock/_analyzer.py index 9234946b4a8..ec0cccab06c 100644 --- a/src/snowflake/snowpark/mock/_analyzer.py +++ b/src/snowflake/snowpark/mock/_analyzer.py @@ -286,8 +286,8 @@ def analyze( resolved_expr = self.to_sql_try_avoid_cast( expression, expr_to_alias, + config_context, parse_local_name, - config_context=config_context, ) else: resolved_expr = self.analyze( @@ -307,8 +307,8 @@ def analyze( in_value = self.to_sql_try_avoid_cast( expression, expr_to_alias, + config_context, parse_local_name, - config_context=config_context, ) else: in_value = self.analyze( @@ -489,9 +489,9 @@ def analyze( return self.unary_expression_extractor( expr, expr_to_alias, + config_context, parse_local_name, keep_alias=keep_alias, - config_context=config_context, ) if isinstance(expr, SortOrder): @@ -528,8 +528,8 @@ def analyze( return self.binary_operator_extractor( expr, expr_to_alias, + config_context, parse_local_name, - config_context=config_context, ) if isinstance(expr, InsertMergeExpression): @@ -756,14 +756,14 @@ def binary_operator_extractor( left_sql_expr = self.to_sql_try_avoid_cast( expr.left, expr_to_alias, + config_context, parse_local_name, - config_context=config_context, ) right_sql_expr = self.to_sql_try_avoid_cast( expr.right, expr_to_alias, + config_context, parse_local_name, - config_context=config_context, ) else: left_sql_expr = self.analyze( @@ -862,7 +862,7 @@ def do_resolve( expr_to_alias_maps = {} for c in logical_plan.children: _expr_to_alias = {} - resolved_children[c] = self.resolve(c, _expr_to_alias) + resolved_children[c] = self.resolve(c, _expr_to_alias, config_context) 
expr_to_alias_maps[c] = _expr_to_alias # get counts of expr_to_alias keys From db3703365c18c9d19e8e2ad8f7b33183a64b2320 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 27 Sep 2024 14:40:05 -0700 Subject: [PATCH 43/62] fix --- .../compiler/large_query_breakdown.py | 12 ++++++-- .../_internal/compiler/plan_compiler.py | 30 +++++++++---------- .../_internal/compiler/query_generator.py | 7 +++-- .../snowpark/_internal/compiler/utils.py | 27 ++++++++++++----- 4 files changed, 48 insertions(+), 28 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py index ac9f24cf532..9c3dd5f9f8a 100644 --- a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py +++ b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py @@ -14,6 +14,7 @@ Intersect, Union, ) +from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( get_complexity_score, ) @@ -113,12 +114,14 @@ def __init__( session: Session, query_generator: QueryGenerator, logical_plans: List[LogicalPlan], - complexity_bounds: Tuple[int, int], + config_context: ConfigContext, ) -> None: self.session = session self._query_generator = query_generator self.logical_plans = logical_plans self._parent_map = defaultdict(set) + self._config_context = config_context + complexity_bounds = config_context.large_query_breakdown_complexity_bounds self.complexity_score_lower_bound = complexity_bounds[0] self.complexity_score_upper_bound = complexity_bounds[1] @@ -139,7 +142,9 @@ def apply(self) -> List[LogicalPlan]: # Similar to the repeated subquery elimination, we rely on # nodes of the plan to be SnowflakePlan or Selectable. Here, # we resolve the plan to make sure we get a valid plan tree. 
- resolved_plan = self._query_generator.resolve(logical_plan) + resolved_plan = self._query_generator.resolve( + logical_plan, self._config_context + ) partition_plans = self._try_to_breakdown_plan(resolved_plan) resulting_plans.extend(partition_plans) @@ -259,7 +264,8 @@ def _get_partitioned_plan(self, root: TreeNode, child: TreeNode) -> SnowflakePla child, table_type="temp", creation_source=TableCreationSource.LARGE_QUERY_BREAKDOWN, - ) + ), + self._config_context, ) # Update the ancestors with the temp table selectable diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 178afa5a941..f3b14b39b5c 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -47,7 +47,7 @@ class PlanCompiler: def __init__(self, plan: SnowflakePlan) -> None: self._plan = plan - self.config_context = ConfigContext(self._plan.session) + self._config_context = ConfigContext(self._plan.session) def should_start_query_compilation(self) -> bool: """ @@ -67,10 +67,10 @@ def should_start_query_compilation(self) -> bool: return ( not isinstance(current_session._conn, MockServerConnection) and (self._plan.source_plan is not None) - and self.config_context._query_compilation_stage_enabled + and self._config_context._query_compilation_stage_enabled and ( - self.config_context.cte_optimization_enabled - or self.config_context.large_query_breakdown_enabled + self._config_context.cte_optimization_enabled + or self._config_context.large_query_breakdown_enabled ) ) @@ -86,14 +86,14 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: deep_copy_end_time = time.time() # 2. create a code generator with the original plan - query_generator = create_query_generator(self._plan) + query_generator = create_query_generator(self._plan, self._config_context) # 3. apply each optimizations if needed # CTE optimization cte_start_time = time.time() - if self.config_context.cte_optimization_enabled: + if self._config_context.cte_optimization_enabled: repeated_subquery_eliminator = RepeatedSubqueryElimination( - logical_plans, query_generator, self.config_context + logical_plans, query_generator, self._config_context ) logical_plans = repeated_subquery_eliminator.apply() @@ -104,12 +104,12 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: ] # Large query breakdown - if self.config_context.large_query_breakdown_enabled: + if self._config_context.large_query_breakdown_enabled: large_query_breakdown = LargeQueryBreakdown( self._plan.session, query_generator, logical_plans, - self.config_context.large_query_breakdown_complexity_bounds, + self._config_context, ) logical_plans = large_query_breakdown.apply() @@ -120,9 +120,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: ] # 4. 
do a final pass of code generation - queries = query_generator.generate_queries( - logical_plans, self.config_context - ) + queries = query_generator.generate_queries(logical_plans) # log telemetry data deep_copy_time = deep_copy_end_time - start_time @@ -131,9 +129,9 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: total_time = time.time() - start_time session = self._plan.session summary_value = { - TelemetryField.CTE_OPTIMIZATION_ENABLED.value: self.config_context.cte_optimization_enabled, - TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: self.config_context.large_query_breakdown_enabled, - CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: self.config_context.large_query_breakdown_complexity_bounds, + TelemetryField.CTE_OPTIMIZATION_ENABLED.value: self._config_context.cte_optimization_enabled, + TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: self._config_context.large_query_breakdown_enabled, + CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: self._config_context.large_query_breakdown_complexity_bounds, CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time, @@ -151,7 +149,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: else: final_plan = self._plan final_plan = final_plan.replace_repeated_subquery_with_cte( - self.config_context + self._config_context ) return { PlanQueryType.QUERIES: final_plan.queries, diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index 98920c8b1a7..dc6654274f5 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -50,6 +50,7 @@ class QueryGenerator(Analyzer): def __init__( self, session: Session, + config_context: ConfigContext, snowflake_create_table_plan_info: Optional[SnowflakeCreateTablePlanInfo] = None, ) -> None: super().__init__(session) @@ -59,6 +60,7 @@ def __init__( self._snowflake_create_table_plan_info: Optional[ SnowflakeCreateTablePlanInfo ] = snowflake_create_table_plan_info + self.config_context = config_context # Records the definition of all the with query blocks encountered during the code generation. # This information will be used to generate the final query of a SnowflakePlan with the # correct CTE definition. @@ -68,7 +70,7 @@ def __init__( self.resolved_with_query_block: Dict[str, Query] = {} def generate_queries( - self, logical_plans: List[LogicalPlan], config_context: ConfigContext + self, logical_plans: List[LogicalPlan] ) -> Dict[PlanQueryType, List[Query]]: """ Generate final queries for the given set of logical plans. 
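
Note: the ConfigContext class referenced throughout these patches is defined outside the hunks shown here. A minimal sketch of the pattern, assuming the session exposes same-named settings and the _lock added in PATCH 01, looks roughly like the following; the real class may read or cache different fields:

class ConfigContext:
    """Hypothetical sketch only: snapshot session configuration once, under
    the session lock, so a single analysis/compilation pass sees a consistent
    view even if another thread changes a setting mid-way."""

    def __init__(self, session) -> None:
        self.session = session
        with session._lock:
            # Attribute names follow the ones dereferenced on config_context
            # in the diffs above; treat them as assumptions, not the real API.
            self.cte_optimization_enabled = session.cte_optimization_enabled
            self.large_query_breakdown_enabled = (
                session.large_query_breakdown_enabled
            )
            self.large_query_breakdown_complexity_bounds = (
                session.large_query_breakdown_complexity_bounds
            )
            self.eliminate_numeric_sql_value_cast_enabled = (
                session.eliminate_numeric_sql_value_cast_enabled
            )
            self._query_compilation_stage_enabled = (
                session._query_compilation_stage_enabled
            )

Snapshotting once per compilation is what lets later patches drop repeated session reads inside analyze() and resolve().
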
@@ -83,7 +85,8 @@ def generate_queries( # generate queries for each logical plan snowflake_plans = [ - self.resolve(logical_plan, config_context) for logical_plan in logical_plans + self.resolve(logical_plan, self.config_context) + for logical_plan in logical_plans ] # merge all results into final set of queries queries = [] diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py index 82c2b090487..7829f354ffa 100644 --- a/src/snowflake/snowpark/_internal/compiler/utils.py +++ b/src/snowflake/snowpark/_internal/compiler/utils.py @@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Union from snowflake.snowpark._internal.analyzer.binary_plan_node import BinaryNode +from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.select_statement import ( Selectable, SelectSnowflakePlan, @@ -39,7 +40,9 @@ TreeNode = Union[SnowflakePlan, Selectable] -def create_query_generator(plan: SnowflakePlan) -> QueryGenerator: +def create_query_generator( + plan: SnowflakePlan, config_context: ConfigContext +) -> QueryGenerator: """ Helper function to construct the query generator for a given valid SnowflakePlan. """ @@ -64,12 +67,16 @@ def create_query_generator(plan: SnowflakePlan) -> QueryGenerator: # resolved plan, and the resolve will be a no-op. # NOTE that here we rely on the fact that the SnowflakeCreateTable node is the root # of a source plan. Test will fail if that assumption is broken. - resolved_child = plan.session._analyzer.resolve(create_table_node.query) + resolved_child = plan.session._analyzer.resolve( + create_table_node.query, config_context + ) snowflake_create_table_plan_info = SnowflakeCreateTablePlanInfo( create_table_node.table_name, resolved_child.attributes ) - return QueryGenerator(plan.session, snowflake_create_table_plan_info) + return QueryGenerator( + plan.session, config_context, snowflake_create_table_plan_info + ) def resolve_and_update_snowflake_plan( @@ -83,7 +90,9 @@ def resolve_and_update_snowflake_plan( if node.source_plan is None: return - new_snowflake_plan = query_generator.resolve(node.source_plan) + new_snowflake_plan = query_generator.resolve( + node.source_plan, query_generator.config_context + ) # copy over the newly resolved fields to make it an in-place update node.queries = new_snowflake_plan.queries @@ -117,7 +126,7 @@ def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selecta if isinstance(plan, Selectable): return plan - snowflake_plan = query_generator.resolve(plan) + snowflake_plan = query_generator.resolve(plan, query_generator.config_context) return SelectSnowflakePlan(snowflake_plan, analyzer=query_generator) if not parent._is_valid_for_replacement: @@ -165,12 +174,16 @@ def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selecta parent.query = new_child elif isinstance(parent, (TableUpdate, TableDelete)): - snowflake_plan = query_generator.resolve(new_child) + snowflake_plan = query_generator.resolve( + new_child, query_generator.config_context + ) parent.children = [snowflake_plan] parent.source_data = snowflake_plan elif isinstance(parent, TableMerge): - snowflake_plan = query_generator.resolve(new_child) + snowflake_plan = query_generator.resolve( + new_child, query_generator.config_context + ) parent.children = [snowflake_plan] parent.source = snowflake_plan From dddd15f9c7fd6d7509aa6a9b6c87e25517a76202 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 27 Sep 2024 15:02:29 
-0700 Subject: [PATCH 44/62] fix unit tests --- .../compiler/test_replace_child_and_update_node.py | 13 +++++++++++-- tests/unit/conftest.py | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/unit/compiler/test_replace_child_and_update_node.py b/tests/unit/compiler/test_replace_child_and_update_node.py index 05098165a1b..7c3599df6cb 100644 --- a/tests/unit/compiler/test_replace_child_and_update_node.py +++ b/tests/unit/compiler/test_replace_child_and_update_node.py @@ -8,6 +8,7 @@ import pytest from snowflake.snowpark._internal.analyzer.binary_plan_node import Inner, Join, Union +from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.select_statement import ( Selectable, SelectableEntity, @@ -69,8 +70,15 @@ def mock_snowflake_plan() -> SnowflakePlan: @pytest.fixture(scope="function") -def mock_query_generator(mock_session) -> QueryGenerator: - def mock_resolve(x): +def mock_config_context() -> ConfigContext: + fake_config_context = mock.create_autospec(ConfigContext) + fake_config_context._query_compilation_stage_enabled = False + # fake_config_context.cte_optimization_enabled = False + + +@pytest.fixture(scope="function") +def mock_query_generator(mock_session, mock_config_context) -> QueryGenerator: + def mock_resolve(x, y): snowflake_plan = mock_snowflake_plan() snowflake_plan.source_plan = x if hasattr(x, "post_actions"): @@ -80,6 +88,7 @@ def mock_resolve(x): fake_query_generator = mock.create_autospec(QueryGenerator) fake_query_generator.resolve.side_effect = mock_resolve fake_query_generator.session = mock_session + fake_query_generator.config_context = mock_config_context return fake_query_generator diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 986927b65e4..c6c0b0cb508 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -48,7 +48,7 @@ def mock_snowflake_plan(mock_query) -> Analyzer: @pytest.fixture(scope="module") def mock_analyzer(mock_snowflake_plan) -> Analyzer: - def mock_resolve(x): + def mock_resolve(x, y=None): mock_snowflake_plan.source_plan = x return mock_snowflake_plan From 57ee9e85cfd5d66f8b7aa5352489cb6976a23768 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 27 Sep 2024 15:21:44 -0700 Subject: [PATCH 45/62] simplify --- .../_internal/compiler/large_query_breakdown.py | 12 +++--------- .../snowpark/_internal/compiler/plan_compiler.py | 4 ++-- .../snowpark/_internal/compiler/query_generator.py | 7 +++++-- .../compiler/repeated_subquery_elimination.py | 11 ++--------- src/snowflake/snowpark/_internal/compiler/utils.py | 14 ++++---------- 5 files changed, 16 insertions(+), 32 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py index 9c3dd5f9f8a..ac9f24cf532 100644 --- a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py +++ b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py @@ -14,7 +14,6 @@ Intersect, Union, ) -from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( get_complexity_score, ) @@ -114,14 +113,12 @@ def __init__( session: Session, query_generator: QueryGenerator, logical_plans: List[LogicalPlan], - config_context: ConfigContext, + complexity_bounds: Tuple[int, int], ) -> None: self.session = session self._query_generator = query_generator self.logical_plans = logical_plans 
self._parent_map = defaultdict(set) - self._config_context = config_context - complexity_bounds = config_context.large_query_breakdown_complexity_bounds self.complexity_score_lower_bound = complexity_bounds[0] self.complexity_score_upper_bound = complexity_bounds[1] @@ -142,9 +139,7 @@ def apply(self) -> List[LogicalPlan]: # Similar to the repeated subquery elimination, we rely on # nodes of the plan to be SnowflakePlan or Selectable. Here, # we resolve the plan to make sure we get a valid plan tree. - resolved_plan = self._query_generator.resolve( - logical_plan, self._config_context - ) + resolved_plan = self._query_generator.resolve(logical_plan) partition_plans = self._try_to_breakdown_plan(resolved_plan) resulting_plans.extend(partition_plans) @@ -264,8 +259,7 @@ def _get_partitioned_plan(self, root: TreeNode, child: TreeNode) -> SnowflakePla child, table_type="temp", creation_source=TableCreationSource.LARGE_QUERY_BREAKDOWN, - ), - self._config_context, + ) ) # Update the ancestors with the temp table selectable diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index f3b14b39b5c..7a2d878c191 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -93,7 +93,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: cte_start_time = time.time() if self._config_context.cte_optimization_enabled: repeated_subquery_eliminator = RepeatedSubqueryElimination( - logical_plans, query_generator, self._config_context + logical_plans, query_generator ) logical_plans = repeated_subquery_eliminator.apply() @@ -109,7 +109,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: self._plan.session, query_generator, logical_plans, - self._config_context, + self._config_context.large_query_breakdown_complexity_bounds, ) logical_plans = large_query_breakdown.apply() diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index dc6654274f5..8bfc298dcdf 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -60,7 +60,7 @@ def __init__( self._snowflake_create_table_plan_info: Optional[ SnowflakeCreateTablePlanInfo ] = snowflake_create_table_plan_info - self.config_context = config_context + self._config_context = config_context # Records the definition of all the with query blocks encountered during the code generation. # This information will be used to generate the final query of a SnowflakePlan with the # correct CTE definition. 
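
For orientation, after this simplification the compile path threads a single ConfigContext roughly as sketched below. The function name and its two parameters are placeholders standing in for the PlanCompiler's real inputs; the imported names come from the files touched in the hunks above:

from snowflake.snowpark._internal.analyzer.config_context import ConfigContext
from snowflake.snowpark._internal.compiler.large_query_breakdown import (
    LargeQueryBreakdown,
)
from snowflake.snowpark._internal.compiler.repeated_subquery_elimination import (
    RepeatedSubqueryElimination,
)
from snowflake.snowpark._internal.compiler.utils import create_query_generator


def compile_with_config_snapshot(plan, logical_plans):
    # One snapshot per compilation; every optimization pass and resolve()
    # call reuses it, so a concurrent config change cannot flip optimization
    # behavior halfway through a compile.
    config_context = ConfigContext(plan.session)
    query_generator = create_query_generator(plan, config_context)

    if config_context.cte_optimization_enabled:
        logical_plans = RepeatedSubqueryElimination(
            logical_plans, query_generator
        ).apply()

    if config_context.large_query_breakdown_enabled:
        logical_plans = LargeQueryBreakdown(
            plan.session,
            query_generator,
            logical_plans,
            config_context.large_query_breakdown_complexity_bounds,
        ).apply()

    # QueryGenerator.resolve() now supplies the stored _config_context itself,
    # so generate_queries no longer takes a config argument.
    return query_generator.generate_queries(logical_plans)
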
@@ -85,7 +85,7 @@ def generate_queries( # generate queries for each logical plan snowflake_plans = [ - self.resolve(logical_plan, self.config_context) + self.resolve(logical_plan, self._config_context) for logical_plan in logical_plans ] # merge all results into final set of queries @@ -111,6 +111,9 @@ def generate_queries( PlanQueryType.POST_ACTIONS: post_actions, } + def resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: + return super().resolve(logical_plan, self._config_context) + def do_resolve_with_resolved_children( self, logical_plan: LogicalPlan, diff --git a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py index 9a7e36ee9fb..38e3b72a32b 100644 --- a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py +++ b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py @@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Set -from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.cte_utils import find_duplicate_subtrees from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( LogicalPlan, @@ -50,11 +49,9 @@ def __init__( self, logical_plans: List[LogicalPlan], query_generator: QueryGenerator, - config_context: ConfigContext, ) -> None: self._logical_plans = logical_plans self._query_generator = query_generator - self._config_context = config_context def apply(self) -> List[LogicalPlan]: """ @@ -70,9 +67,7 @@ def apply(self) -> List[LogicalPlan]: # do a pass of resolve of the logical plan to make sure we get a valid # resolved plan to start the process. # If the plan is already a resolved plan, this step will be a no-op. 
- logical_plan = self._query_generator.resolve( - logical_plan, self._config_context - ) + logical_plan = self._query_generator.resolve(logical_plan) # apply the CTE optimization on the resolved plan duplicated_nodes, node_parents_map = find_duplicate_subtrees(logical_plan) @@ -144,9 +139,7 @@ def _update_parents( ) with_block._is_valid_for_replacement = True - resolved_with_block = self._query_generator.resolve( - with_block, self._config_context - ) + resolved_with_block = self._query_generator.resolve(with_block) _update_parents( node, should_replace_child=True, new_child=resolved_with_block ) diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py index 7829f354ffa..fecf6c36a3e 100644 --- a/src/snowflake/snowpark/_internal/compiler/utils.py +++ b/src/snowflake/snowpark/_internal/compiler/utils.py @@ -90,9 +90,7 @@ def resolve_and_update_snowflake_plan( if node.source_plan is None: return - new_snowflake_plan = query_generator.resolve( - node.source_plan, query_generator.config_context - ) + new_snowflake_plan = query_generator.resolve(node.source_plan) # copy over the newly resolved fields to make it an in-place update node.queries = new_snowflake_plan.queries @@ -126,7 +124,7 @@ def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selecta if isinstance(plan, Selectable): return plan - snowflake_plan = query_generator.resolve(plan, query_generator.config_context) + snowflake_plan = query_generator.resolve(plan) return SelectSnowflakePlan(snowflake_plan, analyzer=query_generator) if not parent._is_valid_for_replacement: @@ -174,16 +172,12 @@ def to_selectable(plan: LogicalPlan, query_generator: QueryGenerator) -> Selecta parent.query = new_child elif isinstance(parent, (TableUpdate, TableDelete)): - snowflake_plan = query_generator.resolve( - new_child, query_generator.config_context - ) + snowflake_plan = query_generator.resolve(new_child) parent.children = [snowflake_plan] parent.source_data = snowflake_plan elif isinstance(parent, TableMerge): - snowflake_plan = query_generator.resolve( - new_child, query_generator.config_context - ) + snowflake_plan = query_generator.resolve(new_child) parent.children = [snowflake_plan] parent.source = snowflake_plan From 809a86ecb23e00d18155fc292a98d1f31211579f Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 27 Sep 2024 15:25:39 -0700 Subject: [PATCH 46/62] simplify --- .../snowpark/_internal/compiler/query_generator.py | 5 +---- .../compiler/test_replace_child_and_update_node.py | 11 ++++------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index 8bfc298dcdf..d8d2762f834 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -84,10 +84,7 @@ def generate_queries( ) # generate queries for each logical plan - snowflake_plans = [ - self.resolve(logical_plan, self._config_context) - for logical_plan in logical_plans - ] + snowflake_plans = [self.resolve(logical_plan) for logical_plan in logical_plans] # merge all results into final set of queries queries = [] post_actions = [] diff --git a/tests/unit/compiler/test_replace_child_and_update_node.py b/tests/unit/compiler/test_replace_child_and_update_node.py index 7c3599df6cb..1c8851715df 100644 --- a/tests/unit/compiler/test_replace_child_and_update_node.py +++ 
b/tests/unit/compiler/test_replace_child_and_update_node.py @@ -69,15 +69,10 @@ def mock_snowflake_plan() -> SnowflakePlan: return fake_snowflake_plan -@pytest.fixture(scope="function") -def mock_config_context() -> ConfigContext: - fake_config_context = mock.create_autospec(ConfigContext) - fake_config_context._query_compilation_stage_enabled = False - # fake_config_context.cte_optimization_enabled = False @pytest.fixture(scope="function") -def mock_query_generator(mock_session, mock_config_context) -> QueryGenerator: +def mock_query_generator(mock_session) -> QueryGenerator: def mock_resolve(x, y): snowflake_plan = mock_snowflake_plan() snowflake_plan.source_plan = x @@ -88,7 +83,9 @@ def mock_resolve(x, y): fake_query_generator = mock.create_autospec(QueryGenerator) fake_query_generator.resolve.side_effect = mock_resolve fake_query_generator.session = mock_session - fake_query_generator.config_context = mock_config_context + fake_config_context = mock.create_autospec(ConfigContext) + fake_config_context._query_compilation_stage_enabled = False + fake_query_generator.config_context = fake_config_context return fake_query_generator From 6021ab8117299c79a22b91f5b600db6e00a07862 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 27 Sep 2024 15:26:34 -0700 Subject: [PATCH 47/62] simplify --- tests/unit/compiler/test_replace_child_and_update_node.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/compiler/test_replace_child_and_update_node.py b/tests/unit/compiler/test_replace_child_and_update_node.py index 1c8851715df..558e47530c3 100644 --- a/tests/unit/compiler/test_replace_child_and_update_node.py +++ b/tests/unit/compiler/test_replace_child_and_update_node.py @@ -69,8 +69,6 @@ def mock_snowflake_plan() -> SnowflakePlan: return fake_snowflake_plan - - @pytest.fixture(scope="function") def mock_query_generator(mock_session) -> QueryGenerator: def mock_resolve(x, y): From 43986f6a132e00109da6938cde8d5958bc6a79d2 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 27 Sep 2024 15:29:06 -0700 Subject: [PATCH 48/62] simplify --- tests/integ/compiler/test_query_generator.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/integ/compiler/test_query_generator.py b/tests/integ/compiler/test_query_generator.py index cb203487415..5ce4c005ad3 100644 --- a/tests/integ/compiler/test_query_generator.py +++ b/tests/integ/compiler/test_query_generator.py @@ -89,7 +89,7 @@ def check_generated_plan_queries(plan: SnowflakePlan) -> None: assert plan.queries is None assert plan.post_actions is None # regenerate the queries - plan_queries = query_generator.generate_queries([source_plan], config_context=None) + plan_queries = query_generator.generate_queries([source_plan]) queries = [query.sql for query in plan_queries[PlanQueryType.QUERIES]] post_actions = [query.sql for query in plan_queries[PlanQueryType.POST_ACTIONS]] assert queries == original_queries @@ -191,7 +191,7 @@ def test_table_create_from_large_query_breakdown(session, plan_source_generator) comment=None, ) - queries = generator.generate_queries([create_table_source], config_context=None) + queries = generator.generate_queries([create_table_source]) assert len(queries[PlanQueryType.QUERIES]) == 1 assert len(queries[PlanQueryType.POST_ACTIONS]) == 0 @@ -321,9 +321,7 @@ def verify_multiple_create_queries( # reset the whole plan reset_plan_tree(df._plan) # regenerate the queries - plan_queries = query_generator.generate_queries( - [df._plan.source_plan], config_context=None - ) + plan_queries = 
query_generator.generate_queries([df._plan.source_plan]) queries = [query.sql.lstrip() for query in plan_queries[PlanQueryType.QUERIES]] post_actions = [ query.sql.lstrip() for query in plan_queries[PlanQueryType.POST_ACTIONS] @@ -364,9 +362,7 @@ def test_multiple_plan_query_generation(session): reset_plan_tree(snowflake_plan) reset_plan_tree(df_res._plan) logical_plans = [snowflake_plan.source_plan, df_res._plan.source_plan] - generated_queries = query_generator.generate_queries( - logical_plans, config_context=None - ) + generated_queries = query_generator.generate_queries(logical_plans) result_queries = [ query.sql.lstrip() for query in generated_queries[PlanQueryType.QUERIES] ] From 0430e923f3b374e05149a37130d93315141d3122 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 27 Sep 2024 15:41:36 -0700 Subject: [PATCH 49/62] simplify --- src/snowflake/snowpark/mock/_analyzer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/snowflake/snowpark/mock/_analyzer.py b/src/snowflake/snowpark/mock/_analyzer.py index ec0cccab06c..3f59859ab94 100644 --- a/src/snowflake/snowpark/mock/_analyzer.py +++ b/src/snowflake/snowpark/mock/_analyzer.py @@ -339,13 +339,13 @@ def analyze( return window_expression( self.analyze( expr.window_function, - config_context=config_context, parse_local_name=parse_local_name, + config_context=config_context, ), self.analyze( expr.window_spec, - config_context=config_context, parse_local_name=parse_local_name, + config_context=config_context, ), ) @@ -354,23 +354,23 @@ def analyze( [ self.analyze( x, - config_context=config_context, parse_local_name=parse_local_name, + config_context=config_context, ) for x in expr.partition_spec ], [ self.analyze( x, - config_context=config_context, parse_local_name=parse_local_name, + config_context=config_context, ) for x in expr.order_spec ], self.analyze( expr.frame_spec, - config_context=config_context, parse_local_name=parse_local_name, + config_context=config_context, ), ) From 095b04ec24946505eb34f0012b9a240203ddf6d3 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Mon, 30 Sep 2024 10:00:16 -0700 Subject: [PATCH 50/62] remove config context --- .../snowpark/_internal/analyzer/analyzer.py | 135 +++-------------- .../_internal/analyzer/config_context.py | 37 ----- .../_internal/analyzer/snowflake_plan.py | 59 ++------ .../_internal/compiler/plan_compiler.py | 28 ++-- .../_internal/compiler/query_generator.py | 10 +- .../snowpark/_internal/compiler/utils.py | 11 +- src/snowflake/snowpark/mock/_analyzer.py | 143 ++++++++---------- tests/integ/test_multithreading.py | 50 ------ .../test_replace_child_and_update_node.py | 6 +- tests/unit/conftest.py | 2 +- 10 files changed, 113 insertions(+), 368 deletions(-) delete mode 100644 src/snowflake/snowpark/_internal/analyzer/config_context.py diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 978403127a4..7ca2b883ef0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -47,7 +47,6 @@ BinaryExpression, ) from snowflake.snowpark._internal.analyzer.binary_plan_node import Join, SetOperation -from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.datatype_mapper import ( numeric_to_sql_without_cast, str_to_sql, @@ -173,12 +172,7 @@ def analyze( expr: Union[Expression, NamedExpression], df_aliased_col_name_to_real_col_name: DefaultDict[str, 
Dict[str, str]], parse_local_name=False, - config_context: Optional[ConfigContext] = None, ) -> str: - # Set the config context for analysis step - if config_context is None: - config_context = ConfigContext(self.session) - if isinstance(expr, GroupingSetsExpression): return grouping_set_expression( [ @@ -187,7 +181,6 @@ def analyze( a, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) for a in arg ] @@ -201,13 +194,11 @@ def analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), self.analyze( expr.pattern, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), ) @@ -217,19 +208,16 @@ def analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), self.analyze( expr.pattern, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), self.analyze( expr.parameters, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) if expr.parameters is not None else None, @@ -244,7 +232,6 @@ def analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), collation_spec, ) @@ -258,7 +245,6 @@ def analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), field, ) @@ -271,13 +257,11 @@ def analyze( condition, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), self.analyze( value, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), ) for condition, value in expr.branches @@ -286,7 +270,6 @@ def analyze( expr.else_value, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) if expr.else_value else "NULL", @@ -295,11 +278,10 @@ def analyze( if isinstance(expr, MultipleExpression): block_expressions = [] for expression in expr.expressions: - if config_context.eliminate_numeric_sql_value_cast_enabled: + if self.session.eliminate_numeric_sql_value_cast_enabled: resolved_expr = self.to_sql_try_avoid_cast( expression, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, ) else: @@ -307,7 +289,6 @@ def analyze( expression, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) block_expressions.append(resolved_expr) @@ -316,11 +297,10 @@ def analyze( if isinstance(expr, InExpression): in_values = [] for expression in expr.values: - if config_context.eliminate_numeric_sql_value_cast_enabled: + if self.session.eliminate_numeric_sql_value_cast_enabled: in_value = self.to_sql_try_avoid_cast( expression, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, ) else: @@ -328,7 +308,6 @@ def analyze( expression, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) in_values.append(in_value) @@ -337,15 +316,12 @@ def analyze( expr.columns, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), in_values, ) if isinstance(expr, GroupingSet): - return self.grouping_extractor( - expr, df_aliased_col_name_to_real_col_name, config_context - ) + return self.grouping_extractor(expr, df_aliased_col_name_to_real_col_name) if isinstance(expr, WindowExpression): return window_expression( @@ -353,13 +329,11 @@ def analyze( expr.window_function, df_aliased_col_name_to_real_col_name, parse_local_name, - 
config_context=config_context, ), self.analyze( expr.window_spec, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), ) if isinstance(expr, WindowSpecDefinition): @@ -369,7 +343,6 @@ def analyze( x, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) for x in expr.partition_spec ], @@ -378,7 +351,6 @@ def analyze( x, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) for x in expr.order_spec ], @@ -386,7 +358,6 @@ def analyze( expr.frame_spec, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), ) if isinstance(expr, SpecifiedWindowFrame): @@ -394,12 +365,12 @@ def analyze( expr.frame_type.sql, self.window_frame_boundary( self.to_sql_try_avoid_cast( - expr.lower, df_aliased_col_name_to_real_col_name, config_context + expr.lower, df_aliased_col_name_to_real_col_name ) ), self.window_frame_boundary( self.to_sql_try_avoid_cast( - expr.upper, df_aliased_col_name_to_real_col_name, config_context + expr.upper, df_aliased_col_name_to_real_col_name ) ), ) @@ -443,9 +414,7 @@ def analyze( return function_expression( func_name, [ - self.to_sql_try_avoid_cast( - c, df_aliased_col_name_to_real_col_name, config_context - ) + self.to_sql_try_avoid_cast(c, df_aliased_col_name_to_real_col_name) for c in expr.children ], expr.is_distinct, @@ -469,7 +438,6 @@ def analyze( self.analyze( e, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for e in expr.expressions ] @@ -488,7 +456,6 @@ def analyze( x, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) for x in expr.children ], @@ -501,7 +468,7 @@ def analyze( expr.api_call_source, TelemetryField.FUNC_CAT_USAGE.value ) return self.table_function_expression_extractor( - expr, df_aliased_col_name_to_real_col_name, config_context + expr, df_aliased_col_name_to_real_col_name ) if isinstance(expr, TableFunctionPartitionSpecDefinition): @@ -512,7 +479,6 @@ def analyze( x, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) for x in expr.partition_spec ] @@ -523,7 +489,6 @@ def analyze( x, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) for x in expr.order_spec ] @@ -535,7 +500,6 @@ def analyze( return self.unary_expression_extractor( expr, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, ) @@ -545,7 +509,6 @@ def analyze( expr.child, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), expr.direction.sql, expr.null_ordering.sql, @@ -561,13 +524,11 @@ def analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), [ self.analyze( e, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for e in expr.order_by_cols ], @@ -577,7 +538,6 @@ def analyze( return self.binary_operator_extractor( expr, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, ) @@ -586,7 +546,6 @@ def analyze( self.analyze( expr.condition, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) if expr.condition else None, @@ -594,7 +553,6 @@ def analyze( self.analyze( k, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for k in expr.keys ], @@ -602,7 +560,6 @@ def analyze( self.analyze( v, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for v in expr.values ], @@ -613,7 +570,6 @@ 
def analyze( self.analyze( expr.condition, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) if expr.condition else None, @@ -621,11 +577,9 @@ def analyze( self.analyze( k, df_aliased_col_name_to_real_col_name, - config_context=config_context, ): self.analyze( v, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for k, v in expr.assignments.items() }, @@ -636,7 +590,6 @@ def analyze( self.analyze( expr.condition, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) if expr.condition else None @@ -648,7 +601,6 @@ def analyze( expr.col, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), str_to_sql(expr.delimiter), expr.is_distinct, @@ -661,7 +613,6 @@ def analyze( col, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) for col in expr.exprs ] @@ -674,14 +625,12 @@ def analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), expr.offset, self.analyze( expr.default, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) if expr.default else None, @@ -696,7 +645,6 @@ def table_function_expression_extractor( self, expr: TableFunctionExpression, df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], - config_context: ConfigContext, parse_local_name=False, ) -> str: if isinstance(expr, FlattenFunction): @@ -705,7 +653,6 @@ def table_function_expression_extractor( expr.input, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), expr.path, expr.outer, @@ -720,7 +667,6 @@ def table_function_expression_extractor( x, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) for x in expr.args ], @@ -733,7 +679,6 @@ def table_function_expression_extractor( key: self.to_sql_try_avoid_cast( value, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, ) for key, value in expr.args.items() @@ -748,7 +693,6 @@ def table_function_expression_extractor( self.analyze( expr.partition_spec, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) if expr.partition_spec else "" @@ -759,7 +703,6 @@ def unary_expression_extractor( self, expr: UnaryExpression, df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], - config_context: ConfigContext, parse_local_name=False, ) -> str: if isinstance(expr, Alias): @@ -780,7 +723,6 @@ def unary_expression_extractor( expr.child, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), quoted_name, ) @@ -789,7 +731,6 @@ def unary_expression_extractor( expr.child, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) if parse_local_name: expr_str = expr_str.upper() @@ -800,7 +741,6 @@ def unary_expression_extractor( expr.child, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), expr.to, expr.try_, @@ -811,7 +751,6 @@ def unary_expression_extractor( expr.child, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ), expr.sql_operator, expr.operator_first, @@ -821,20 +760,17 @@ def binary_operator_extractor( self, expr: BinaryExpression, df_aliased_col_name_to_real_col_name, - config_context: ConfigContext, parse_local_name=False, ) -> str: - if config_context.eliminate_numeric_sql_value_cast_enabled: + if self.session.eliminate_numeric_sql_value_cast_enabled: 
left_sql_expr = self.to_sql_try_avoid_cast( expr.left, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, ) right_sql_expr = self.to_sql_try_avoid_cast( expr.right, df_aliased_col_name_to_real_col_name, - config_context, parse_local_name, ) else: @@ -842,13 +778,11 @@ def binary_operator_extractor( expr.left, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) right_sql_expr = self.analyze( expr.right, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) if isinstance(expr, BinaryArithmeticExpression): return binary_arithmetic_expression( @@ -867,7 +801,7 @@ def binary_operator_extractor( ) def grouping_extractor( - self, expr: GroupingSet, df_aliased_col_name_to_real_col_name, config_context + self, expr: GroupingSet, df_aliased_col_name_to_real_col_name ) -> str: return self.analyze( FunctionExpression( @@ -876,7 +810,6 @@ def grouping_extractor( False, ), df_aliased_col_name_to_real_col_name, - config_context=config_context, ) def window_frame_boundary(self, offset: str) -> str: @@ -890,7 +823,6 @@ def to_sql_try_avoid_cast( self, expr: Expression, df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], - config_context: ConfigContext, parse_local_name: bool = False, ) -> str: """ @@ -914,19 +846,18 @@ def to_sql_try_avoid_cast( expr, df_aliased_col_name_to_real_col_name, parse_local_name, - config_context=config_context, ) def resolve( - self, logical_plan: LogicalPlan, config_context: Optional[ConfigContext] = None + self, + logical_plan: LogicalPlan, ) -> SnowflakePlan: self.subquery_plans = [] self.generated_alias_maps = {} - if config_context is None: - config_context = ConfigContext(self.session) - self.plan_builder.set_config_context(config_context) - result = self.do_resolve(logical_plan, config_context) + result = self.do_resolve( + logical_plan, + ) result.add_aliases(self.generated_alias_maps) @@ -936,7 +867,8 @@ def resolve( return result def do_resolve( - self, logical_plan: LogicalPlan, config_context: ConfigContext + self, + logical_plan: LogicalPlan, ) -> SnowflakePlan: resolved_children = {} df_aliased_col_name_to_real_col_name: DefaultDict[ @@ -944,7 +876,9 @@ def do_resolve( ] = defaultdict(dict) for c in logical_plan.children: # post-order traversal of the tree - resolved = self.resolve(c, config_context) + resolved = self.resolve( + c, + ) df_aliased_col_name_to_real_col_name.update( resolved.df_aliased_col_name_to_real_col_name ) @@ -975,7 +909,6 @@ def do_resolve( logical_plan, resolved_children, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) res.df_aliased_col_name_to_real_col_name.update( df_aliased_col_name_to_real_col_name @@ -987,7 +920,6 @@ def do_resolve_with_resolved_children( logical_plan: LogicalPlan, resolved_children: Dict[LogicalPlan, SnowflakePlan], df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], - config_context: ConfigContext, ) -> SnowflakePlan: if isinstance(logical_plan, SnowflakePlan): return logical_plan @@ -997,7 +929,6 @@ def do_resolve_with_resolved_children( self.analyze( logical_plan.table_function, df_aliased_col_name_to_real_col_name, - config_context=config_context, ), resolved_children[logical_plan.children[0]], logical_plan, @@ -1011,7 +942,6 @@ def do_resolve_with_resolved_children( self.analyze( logical_plan.table_function, df_aliased_col_name_to_real_col_name, - config_context=config_context, ), logical_plan, ) @@ -1021,7 +951,6 @@ def 
do_resolve_with_resolved_children( self.analyze( logical_plan.table_function, df_aliased_col_name_to_real_col_name, - config_context=config_context, ), resolved_children[logical_plan.children[0]], logical_plan, @@ -1033,7 +962,6 @@ def do_resolve_with_resolved_children( self.to_sql_try_avoid_cast( expr, df_aliased_col_name_to_real_col_name, - config_context, ) for expr in logical_plan.grouping_expressions ], @@ -1041,7 +969,6 @@ def do_resolve_with_resolved_children( self.analyze( expr, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for expr in logical_plan.aggregate_expressions ], @@ -1056,7 +983,6 @@ def do_resolve_with_resolved_children( lambda x: self.analyze( x, df_aliased_col_name_to_real_col_name, - config_context=config_context, ), logical_plan.project_list, ) @@ -1070,7 +996,6 @@ def do_resolve_with_resolved_children( self.analyze( logical_plan.condition, df_aliased_col_name_to_real_col_name, - config_context=config_context, ), resolved_children[logical_plan.child], logical_plan, @@ -1090,7 +1015,6 @@ def do_resolve_with_resolved_children( self.analyze( logical_plan.join_condition, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) if logical_plan.join_condition else "" @@ -1099,7 +1023,6 @@ def do_resolve_with_resolved_children( self.analyze( logical_plan.match_condition, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) if logical_plan.match_condition else "" @@ -1120,7 +1043,6 @@ def do_resolve_with_resolved_children( self.analyze( x, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for x in logical_plan.order ], @@ -1188,7 +1110,6 @@ def do_resolve_with_resolved_children( self.analyze( x, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for x in logical_plan.clustering_exprs ], @@ -1214,12 +1135,10 @@ def do_resolve_with_resolved_children( self.to_sql_try_avoid_cast( logical_plan.limit_expr, df_aliased_col_name_to_real_col_name, - config_context, ), self.to_sql_try_avoid_cast( logical_plan.offset_expr, df_aliased_col_name_to_real_col_name, - config_context, ), resolved_children[logical_plan.child], on_top_of_order_by, @@ -1250,7 +1169,6 @@ def do_resolve_with_resolved_children( self.analyze( col, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for col in project_exprs ], @@ -1270,7 +1188,6 @@ def do_resolve_with_resolved_children( self.analyze( pv, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for pv in logical_plan.pivot_values ] @@ -1278,7 +1195,6 @@ def do_resolve_with_resolved_children( pivot_values = self.analyze( logical_plan.pivot_values, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) else: pivot_values = None @@ -1287,18 +1203,15 @@ def do_resolve_with_resolved_children( self.analyze( logical_plan.pivot_column, df_aliased_col_name_to_real_col_name, - config_context=config_context, ), pivot_values, self.analyze( logical_plan.aggregates[0], df_aliased_col_name_to_real_col_name, - config_context=config_context, ), self.analyze( logical_plan.default_on_null, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) if logical_plan.default_on_null else None, @@ -1326,7 +1239,6 @@ def do_resolve_with_resolved_children( self.analyze( c, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for c in logical_plan.column_list ], @@ -1372,7 +1284,6 @@ def do_resolve_with_resolved_children( self.analyze( x, df_aliased_col_name_to_real_col_name, - 
config_context=config_context, ) for x in logical_plan.clustering_exprs ], @@ -1408,7 +1319,6 @@ def do_resolve_with_resolved_children( self.analyze( x, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for x in logical_plan.transformations ] @@ -1427,7 +1337,6 @@ def do_resolve_with_resolved_children( partition_by=self.analyze( logical_plan.partition_by, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) if logical_plan.partition_by else None, @@ -1445,18 +1354,15 @@ def do_resolve_with_resolved_children( self.analyze( k, df_aliased_col_name_to_real_col_name, - config_context=config_context, ): self.analyze( v, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for k, v in logical_plan.assignments.items() }, self.analyze( logical_plan.condition, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) if logical_plan.condition else None, @@ -1472,7 +1378,6 @@ def do_resolve_with_resolved_children( self.analyze( logical_plan.condition, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) if logical_plan.condition else None, @@ -1492,13 +1397,11 @@ def do_resolve_with_resolved_children( self.analyze( logical_plan.join_expr, df_aliased_col_name_to_real_col_name, - config_context=config_context, ), [ self.analyze( c, df_aliased_col_name_to_real_col_name, - config_context=config_context, ) for c in logical_plan.clauses ], diff --git a/src/snowflake/snowpark/_internal/analyzer/config_context.py b/src/snowflake/snowpark/_internal/analyzer/config_context.py deleted file mode 100644 index c45d335d8a2..00000000000 --- a/src/snowflake/snowpark/_internal/analyzer/config_context.py +++ /dev/null @@ -1,37 +0,0 @@ -# -# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. -# - -from typing import Any - - -class ConfigContext: - """Class to help reading a snapshot of configuration attributes from a session object. - - On instantiation, this object stores the configuration from the session object - and returns the stored configuration attributes when requested. - """ - - def __init__(self, session) -> None: - self.session = session - self.configs = { - "_query_compilation_stage_enabled", - "cte_optimization_enabled", - "eliminate_numeric_sql_value_cast_enabled", - "large_query_breakdown_complexity_bounds", - "large_query_breakdown_enabled", - } - self._create_snapshot() - - def __getattr__(self, name: str) -> Any: - if name in self.configs: - return getattr(self.session, name) - raise AttributeError(f"ConfigContext has no attribute {name}") - - def _create_snapshot(self) -> "ConfigContext": - """Reads the configuration attributes from the session object and stores them - in the context object. 
- """ - for name in self.configs: - setattr(self, name, getattr(self.session, name)) - return self diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 36c213b0530..bbda752e33d 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -23,7 +23,6 @@ Union, ) -from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, sum_node_complexities, @@ -314,15 +313,15 @@ def children_plan_nodes(self) -> List[Union["Selectable", "SnowflakePlan"]]: return [] def replace_repeated_subquery_with_cte( - self, config_context: ConfigContext + self, ) -> "SnowflakePlan": # parameter protection # the common subquery elimination will be applied if cte_optimization is not enabled # and the new compilation stage is not enabled. When new compilation stage is enabled, # the common subquery elimination will be done through the new plan transformation. if ( - not config_context.cte_optimization_enabled - or config_context._query_compilation_stage_enabled + not self.session.cte_optimization_enabled + or self.session._query_compilation_stage_enabled ): return self @@ -358,16 +357,8 @@ def replace_repeated_subquery_with_cte( # create CTE query final_query = create_cte_query(self, duplicate_plan_set) - with self.session._lock: - # copy depends on the cte_optimization_enabled value. We should keep it - # consistent with the current context. - original_cte_optimization = self.session.cte_optimization_enabled - self.session.cte_optimization_enabled = ( - config_context.cte_optimization_enabled - ) - plan = copy.copy(self) - self.session.cte_optimization_enabled = original_cte_optimization # all other parts of query are unchanged, but just replace the original query + plan = copy.copy(self) plan.queries[-1].sql = final_query return plan @@ -542,10 +533,6 @@ def __init__( # on the optimized plan. During the final query generation, no schema query is needed, # this helps reduces un-necessary overhead for the describing call. 
self._skip_schema_query = skip_schema_query - self._config_context: Optional[ConfigContext] = None - - def set_config_context(self, config_context: ConfigContext) -> None: - self._config_context = config_context @SnowflakePlan.Decorator.wrap_exception def build( @@ -579,10 +566,9 @@ def build( ), "No schema query is available in child SnowflakePlan" new_schema_query = schema_query or sql_generator(child.schema_query) - config_context = self._config_context or ConfigContext(self.session) placeholder_query = ( sql_generator(select_child._id) - if config_context.cte_optimization_enabled and select_child._id is not None + if self.session.cte_optimization_enabled and select_child._id is not None else None ) @@ -619,10 +605,9 @@ def build_binary( right_schema_query = schema_value_statement(select_right.attributes) schema_query = sql_generator(left_schema_query, right_schema_query) - config_context = self._config_context or ConfigContext(self.session) placeholder_query = ( sql_generator(select_left._id, select_right._id) - if config_context.cte_optimization_enabled + if self.session.cte_optimization_enabled and select_left._id is not None and select_right._id is not None else None @@ -654,8 +639,8 @@ def build_binary( referenced_ctes: Set[str] = set() if ( - config_context.cte_optimization_enabled - and config_context._query_compilation_stage_enabled + self.session.cte_optimization_enabled + and self.session._query_compilation_stage_enabled ): # When the cte optimization and the new compilation stage is enabled, # the referred cte tables are propagated from left and right can have @@ -945,9 +930,7 @@ def save_as_table( column_definition_with_hidden_columns, ) - child = child.replace_repeated_subquery_with_cte( - self._config_context or ConfigContext(self.session) - ) + child = child.replace_repeated_subquery_with_cte() def get_create_table_as_select_plan(child: SnowflakePlan, replace, error): return self.build( @@ -1135,9 +1118,7 @@ def create_or_replace_view( if not is_sql_select_statement(child.queries[0].sql.lower().strip()): raise SnowparkClientExceptionMessages.PLAN_CREATE_VIEWS_FROM_SELECT_ONLY() - child = child.replace_repeated_subquery_with_cte( - self._config_context or ConfigContext(self.session) - ) + child = child.replace_repeated_subquery_with_cte() return self.build( lambda x: create_or_replace_view_statement(name, x, is_temp, comment), child, @@ -1180,9 +1161,7 @@ def create_or_replace_dynamic_table( # should never reach here raise ValueError(f"Unknown create mode: {create_mode}") # pragma: no cover - child = child.replace_repeated_subquery_with_cte( - self._config_context or ConfigContext(self.session) - ) + child = child.replace_repeated_subquery_with_cte() return self.build( lambda x: create_or_replace_dynamic_table_statement( name=name, @@ -1485,9 +1464,7 @@ def copy_into_location( header: bool = False, **copy_options: Optional[Any], ) -> SnowflakePlan: - query = query.replace_repeated_subquery_with_cte( - self._config_context or ConfigContext(self.session) - ) + query = query.replace_repeated_subquery_with_cte() return self.build( lambda x: copy_into_location( query=x, @@ -1514,9 +1491,7 @@ def update( source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: if source_data: - source_data = source_data.replace_repeated_subquery_with_cte( - self._config_context or ConfigContext(self.session) - ) + source_data = source_data.replace_repeated_subquery_with_cte() return self.build( lambda x: update_statement( table_name, @@ -1547,9 +1522,7 @@ def delete( source_plan: 
Optional[LogicalPlan], ) -> SnowflakePlan: if source_data: - source_data = source_data.replace_repeated_subquery_with_cte( - self._config_context or ConfigContext(self.session) - ) + source_data = source_data.replace_repeated_subquery_with_cte() return self.build( lambda x: delete_statement( table_name, @@ -1578,9 +1551,7 @@ def merge( clauses: List[str], source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: - source_data = source_data.replace_repeated_subquery_with_cte( - self._config_context or ConfigContext(self.session) - ) + source_data = source_data.replace_repeated_subquery_with_cte() return self.build( lambda x: merge_statement(table_name, x, join_expr, clauses), source_data, diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 7a2d878c191..38a1bf3acd8 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -6,7 +6,6 @@ import time from typing import Dict, List -from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( get_complexity_score, ) @@ -47,7 +46,6 @@ class PlanCompiler: def __init__(self, plan: SnowflakePlan) -> None: self._plan = plan - self._config_context = ConfigContext(self._plan.session) def should_start_query_compilation(self) -> bool: """ @@ -67,14 +65,15 @@ def should_start_query_compilation(self) -> bool: return ( not isinstance(current_session._conn, MockServerConnection) and (self._plan.source_plan is not None) - and self._config_context._query_compilation_stage_enabled + and current_session._query_compilation_stage_enabled and ( - self._config_context.cte_optimization_enabled - or self._config_context.large_query_breakdown_enabled + current_session.cte_optimization_enabled + or current_session.large_query_breakdown_enabled ) ) def compile(self) -> Dict[PlanQueryType, List[Query]]: + session = self._plan.session if self.should_start_query_compilation(): # preparation for compilation # 1. make a copy of the original plan @@ -86,12 +85,12 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: deep_copy_end_time = time.time() # 2. create a code generator with the original plan - query_generator = create_query_generator(self._plan, self._config_context) + query_generator = create_query_generator(self._plan) # 3.
apply each optimizations if needed # CTE optimization cte_start_time = time.time() - if self._config_context.cte_optimization_enabled: + if session.cte_optimization_enabled: repeated_subquery_eliminator = RepeatedSubqueryElimination( logical_plans, query_generator ) @@ -104,12 +103,12 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: ] # Large query breakdown - if self._config_context.large_query_breakdown_enabled: + if session.large_query_breakdown_enabled: large_query_breakdown = LargeQueryBreakdown( self._plan.session, query_generator, logical_plans, - self._config_context.large_query_breakdown_complexity_bounds, + session.large_query_breakdown_complexity_bounds, ) logical_plans = large_query_breakdown.apply() @@ -127,11 +126,10 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: cte_time = cte_end_time - cte_start_time large_query_breakdown_time = large_query_breakdown_end_time - cte_end_time total_time = time.time() - start_time - session = self._plan.session summary_value = { - TelemetryField.CTE_OPTIMIZATION_ENABLED.value: self._config_context.cte_optimization_enabled, - TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: self._config_context.large_query_breakdown_enabled, - CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: self._config_context.large_query_breakdown_complexity_bounds, + TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled, + TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled, + CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds, CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time, @@ -148,9 +146,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: return queries else: final_plan = self._plan - final_plan = final_plan.replace_repeated_subquery_with_cte( - self._config_context - ) + final_plan = final_plan.replace_repeated_subquery_with_cte() return { PlanQueryType.QUERIES: final_plan.queries, PlanQueryType.POST_ACTIONS: final_plan.post_actions, diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index d8d2762f834..187629832e6 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -8,7 +8,6 @@ from snowflake.snowpark._internal.analyzer.expression import Attribute from snowflake.snowpark._internal.analyzer.select_statement import Selectable from snowflake.snowpark._internal.analyzer.snowflake_plan import ( - ConfigContext, CreateViewCommand, PlanQueryType, Query, @@ -50,7 +49,6 @@ class QueryGenerator(Analyzer): def __init__( self, session: Session, - config_context: ConfigContext, snowflake_create_table_plan_info: Optional[SnowflakeCreateTablePlanInfo] = None, ) -> None: super().__init__(session) @@ -60,7 +58,6 @@ def __init__( self._snowflake_create_table_plan_info: Optional[ SnowflakeCreateTablePlanInfo ] = snowflake_create_table_plan_info - self._config_context = config_context # Records the definition of all the with query blocks encountered during the code generation. # This information will be used to generate the final query of a SnowflakePlan with the # correct CTE definition. 
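After this commit the snapshot object is gone and the optimization flags are read straight from the session at the point of use. A rough sketch of the resulting shape, again with simplified, hypothetical names (the real logic lives in plan_compiler.py and query_generator.py; only the session attribute names below are taken from the diff):

    # Illustrative sketch only -- not the library's actual classes.
    class DirectFlagPlanCompiler:
        def __init__(self, plan):
            self._plan = plan

        def compile(self):
            session = self._plan.session  # flags are read from the session at use time
            apply_cte = session.cte_optimization_enabled
            apply_breakdown = session.large_query_breakdown_enabled
            bounds = (
                session.large_query_breakdown_complexity_bounds
                if apply_breakdown
                else None
            )
            # the real compiler would now run the enabled passes in order
            return apply_cte, apply_breakdown, bounds

This trades the snapshot's isolation for simplicity; consistency of the flags during a compilation presumably now relies on the session-level lock introduced earlier in the series.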
@@ -109,21 +106,20 @@ def generate_queries( } def resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: - return super().resolve(logical_plan, self._config_context) + return super().resolve(logical_plan) def do_resolve_with_resolved_children( self, logical_plan: LogicalPlan, resolved_children: Dict[LogicalPlan, SnowflakePlan], df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]], - config_context: ConfigContext, ) -> SnowflakePlan: if isinstance(logical_plan, SnowflakePlan): if logical_plan.queries is None: assert logical_plan.source_plan is not None # when encounter a SnowflakePlan with no queries, try to re-resolve # the source plan to construct the result - res = self.do_resolve(logical_plan.source_plan, config_context) + res = self.do_resolve(logical_plan.source_plan) resolved_children[logical_plan] = res resolved_plan = res else: @@ -214,7 +210,6 @@ def do_resolve_with_resolved_children( logical_plan, resolved_children, df_aliased_col_name_to_real_col_name, - config_context, ) elif isinstance(logical_plan, Selectable): @@ -241,7 +236,6 @@ def do_resolve_with_resolved_children( logical_plan, resolved_children, df_aliased_col_name_to_real_col_name, - config_context, ) resolved_plan._is_valid_for_replacement = True diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py index fecf6c36a3e..bf52ca89b79 100644 --- a/src/snowflake/snowpark/_internal/compiler/utils.py +++ b/src/snowflake/snowpark/_internal/compiler/utils.py @@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Union from snowflake.snowpark._internal.analyzer.binary_plan_node import BinaryNode -from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.select_statement import ( Selectable, SelectSnowflakePlan, @@ -41,7 +40,7 @@ def create_query_generator( - plan: SnowflakePlan, config_context: ConfigContext + plan: SnowflakePlan, ) -> QueryGenerator: """ Helper function to construct the query generator for a given valid SnowflakePlan. @@ -67,16 +66,12 @@ def create_query_generator( # resolved plan, and the resolve will be a no-op. # NOTE that here we rely on the fact that the SnowflakeCreateTable node is the root # of a source plan. Test will fail if that assumption is broken. 
- resolved_child = plan.session._analyzer.resolve( - create_table_node.query, config_context - ) + resolved_child = plan.session._analyzer.resolve(create_table_node.query) snowflake_create_table_plan_info = SnowflakeCreateTablePlanInfo( create_table_node.table_name, resolved_child.attributes ) - return QueryGenerator( - plan.session, config_context, snowflake_create_table_plan_info - ) + return QueryGenerator(plan.session, snowflake_create_table_plan_info) def resolve_and_update_snowflake_plan( diff --git a/src/snowflake/snowpark/mock/_analyzer.py b/src/snowflake/snowpark/mock/_analyzer.py index 3f59859ab94..df4ec62d4a8 100644 --- a/src/snowflake/snowpark/mock/_analyzer.py +++ b/src/snowflake/snowpark/mock/_analyzer.py @@ -44,7 +44,6 @@ BinaryExpression, ) from snowflake.snowpark._internal.analyzer.binary_plan_node import Join, SetOperation -from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.datatype_mapper import ( numeric_to_sql_without_cast, str_to_sql, @@ -161,7 +160,6 @@ def analyze( expr_to_alias: Optional[Dict[str, str]] = None, parse_local_name=False, keep_alias=True, - config_context: Optional[ConfigContext] = None, ) -> Union[str, List[str]]: """ Args: @@ -174,8 +172,6 @@ def analyze( """ if expr_to_alias is None: expr_to_alias = {} - if config_context is None: - config_context = ConfigContext(self.session) if isinstance(expr, GroupingSetsExpression): self._conn.log_not_supported_error( external_feature_name="DataFrame.group_by_grouping_sets", @@ -188,13 +184,11 @@ def analyze( expr.expr, expr_to_alias, parse_local_name, - config_context=config_context, ), self.analyze( expr.pattern, expr_to_alias, parse_local_name, - config_context=config_context, ), ) @@ -204,19 +198,16 @@ def analyze( expr.expr, expr_to_alias, parse_local_name, - config_context=config_context, ), self.analyze( expr.pattern, expr_to_alias, parse_local_name, - config_context=config_context, ), self.analyze( expr.parameters, expr_to_alias, parse_local_name, - config_context=config_context, ) if expr.parameters is not None else None, @@ -231,7 +222,6 @@ def analyze( expr.expr, expr_to_alias, parse_local_name, - config_context=config_context, ), collation_spec, ) @@ -245,7 +235,6 @@ def analyze( expr.expr, expr_to_alias, parse_local_name, - config_context=config_context, ), field, ) @@ -258,13 +247,11 @@ def analyze( condition, expr_to_alias, parse_local_name, - config_context=config_context, ), self.analyze( value, expr_to_alias, parse_local_name, - config_context=config_context, ), ) for condition, value in expr.branches @@ -273,7 +260,6 @@ def analyze( expr.else_value, expr_to_alias, parse_local_name, - config_context=config_context, ) if expr.else_value else "NULL", @@ -282,11 +268,10 @@ def analyze( if isinstance(expr, MultipleExpression): block_expressions = [] for expression in expr.expressions: - if config_context.eliminate_numeric_sql_value_cast_enabled: + if self.session.eliminate_numeric_sql_value_cast_enabled: resolved_expr = self.to_sql_try_avoid_cast( expression, expr_to_alias, - config_context, parse_local_name, ) else: @@ -294,7 +279,6 @@ def analyze( expression, expr_to_alias, parse_local_name, - config_context=config_context, ) block_expressions.append(resolved_expr) @@ -303,11 +287,10 @@ def analyze( if isinstance(expr, InExpression): in_values = [] for expression in expr.values: - if config_context.eliminate_numeric_sql_value_cast_enabled: + if self.session.eliminate_numeric_sql_value_cast_enabled: in_value = 
self.to_sql_try_avoid_cast( expression, expr_to_alias, - config_context, parse_local_name, ) else: @@ -315,7 +298,6 @@ def analyze( expression, expr_to_alias, parse_local_name, - config_context=config_context, ) in_values.append(in_value) @@ -324,7 +306,6 @@ def analyze( expr.columns, expr_to_alias, parse_local_name, - config_context=config_context, ), in_values, ) @@ -340,12 +321,10 @@ def analyze( self.analyze( expr.window_function, parse_local_name=parse_local_name, - config_context=config_context, ), self.analyze( expr.window_spec, parse_local_name=parse_local_name, - config_context=config_context, ), ) @@ -355,7 +334,6 @@ def analyze( self.analyze( x, parse_local_name=parse_local_name, - config_context=config_context, ) for x in expr.partition_spec ], @@ -363,14 +341,12 @@ def analyze( self.analyze( x, parse_local_name=parse_local_name, - config_context=config_context, ) for x in expr.order_spec ], self.analyze( expr.frame_spec, parse_local_name=parse_local_name, - config_context=config_context, ), ) @@ -378,10 +354,16 @@ def analyze( return specified_window_frame_expression( expr.frame_type.sql, self.window_frame_boundary( - self.to_sql_try_avoid_cast(expr.lower, {}, config_context) + self.to_sql_try_avoid_cast( + expr.lower, + {}, + ) ), self.window_frame_boundary( - self.to_sql_try_avoid_cast(expr.upper, {}, config_context) + self.to_sql_try_avoid_cast( + expr.upper, + {}, + ) ), ) @@ -412,7 +394,10 @@ def analyze( children = [] for c in expr.children: - extracted = self.to_sql_try_avoid_cast(c, expr_to_alias, config_context) + extracted = self.to_sql_try_avoid_cast( + c, + expr_to_alias, + ) if isinstance(extracted, list): children.extend(extracted) else: @@ -429,7 +414,10 @@ def analyze( return "*" else: return [ - self.analyze(e, expr_to_alias, config_context=config_context) + self.analyze( + e, + expr_to_alias, + ) for e in expr.expressions ] @@ -446,7 +434,6 @@ def analyze( x, expr_to_alias, parse_local_name, - config_context=config_context, ) for x in expr.children ], @@ -455,7 +442,8 @@ def analyze( if isinstance(expr, TableFunctionExpression): return self.table_function_expression_extractor( - expr, expr_to_alias, config_context + expr, + expr_to_alias, ) if isinstance(expr, TableFunctionPartitionSpecDefinition): @@ -466,7 +454,6 @@ def analyze( x, expr_to_alias, parse_local_name, - config_context=config_context, ) for x in expr.partition_spec ] @@ -477,7 +464,6 @@ def analyze( x, expr_to_alias, parse_local_name, - config_context=config_context, ) for x in expr.order_spec ] @@ -489,7 +475,6 @@ def analyze( return self.unary_expression_extractor( expr, expr_to_alias, - config_context, parse_local_name, keep_alias=keep_alias, ) @@ -500,7 +485,6 @@ def analyze( expr.child, expr_to_alias, parse_local_name, - config_context=config_context, ), expr.direction.sql, expr.null_ordering.sql, @@ -516,10 +500,12 @@ def analyze( expr.expr, expr_to_alias, parse_local_name, - config_context=config_context, ), [ - self.analyze(e, expr_to_alias, config_context=config_context) + self.analyze( + e, + expr_to_alias, + ) for e in expr.order_by_cols ], ) @@ -528,23 +514,29 @@ def analyze( return self.binary_operator_extractor( expr, expr_to_alias, - config_context, parse_local_name, ) if isinstance(expr, InsertMergeExpression): return insert_merge_statement( self.analyze( - expr.condition, expr_to_alias, config_context=config_context + expr.condition, + expr_to_alias, ) if expr.condition else None, [ - self.analyze(k, expr_to_alias, config_context=config_context) + self.analyze( + k, + 
expr_to_alias, + ) for k in expr.keys ], [ - self.analyze(v, expr_to_alias, config_context=config_context) + self.analyze( + v, + expr_to_alias, + ) for v in expr.values ], ) @@ -552,14 +544,16 @@ def analyze( if isinstance(expr, UpdateMergeExpression): return update_merge_statement( self.analyze( - expr.condition, expr_to_alias, config_context=config_context + expr.condition, + expr_to_alias, ) if expr.condition else None, { - self.analyze( - k, expr_to_alias, config_context=config_context - ): self.analyze(v, expr_to_alias, config_context=config_context) + self.analyze(k, expr_to_alias,): self.analyze( + v, + expr_to_alias, + ) for k, v in expr.assignments.items() }, ) @@ -567,7 +561,8 @@ def analyze( if isinstance(expr, DeleteMergeExpression): return delete_merge_statement( self.analyze( - expr.condition, expr_to_alias, config_context=config_context + expr.condition, + expr_to_alias, ) if expr.condition else None @@ -579,7 +574,6 @@ def analyze( expr.col, expr_to_alias, parse_local_name, - config_context=config_context, ), str_to_sql(expr.delimiter), expr.is_distinct, @@ -592,7 +586,6 @@ def analyze( col, expr_to_alias, parse_local_name, - config_context=config_context, ) for col in expr.exprs ] @@ -605,14 +598,12 @@ def analyze( expr.expr, expr_to_alias, parse_local_name, - config_context=config_context, ), expr.offset, self.analyze( expr.default, expr_to_alias, parse_local_name, - config_context=config_context, ) if expr.default else None, @@ -627,7 +618,6 @@ def table_function_expression_extractor( self, expr: TableFunctionExpression, expr_to_alias: Dict[str, str], - config_context: ConfigContext, parse_local_name=False, ) -> str: if isinstance(expr, FlattenFunction): @@ -636,7 +626,6 @@ def table_function_expression_extractor( expr.input, expr_to_alias, parse_local_name, - config_context=config_context, ), expr.path, expr.outer, @@ -651,7 +640,6 @@ def table_function_expression_extractor( x, expr_to_alias, parse_local_name, - config_context=config_context, ) for x in expr.args ], @@ -665,7 +653,6 @@ def table_function_expression_extractor( value, expr_to_alias, parse_local_name, - config_context=config_context, ) for key, value in expr.args.items() }, @@ -677,7 +664,8 @@ def table_function_expression_extractor( ) partition_spec_sql = ( self.analyze( - expr.partition_spec, expr_to_alias, config_context=config_context + expr.partition_spec, + expr_to_alias, ) if expr.partition_spec else "" @@ -688,7 +676,6 @@ def unary_expression_extractor( self, expr: UnaryExpression, expr_to_alias: Dict[str, str], - config_context: ConfigContext, parse_local_name=False, keep_alias=True, ) -> str: @@ -704,7 +691,6 @@ def unary_expression_extractor( expr.child, expr_to_alias, parse_local_name, - config_context=config_context, ), quoted_name, ) @@ -717,7 +703,6 @@ def unary_expression_extractor( expr.child, expr_to_alias, parse_local_name, - config_context=config_context, ) if parse_local_name: expr_str = expr_str.upper() @@ -728,7 +713,6 @@ def unary_expression_extractor( expr.child, expr_to_alias, parse_local_name, - config_context=config_context, ), expr.to, expr.try_, @@ -739,7 +723,6 @@ def unary_expression_extractor( expr.child, expr_to_alias, parse_local_name, - config_context=config_context, ), expr.sql_operator, expr.operator_first, @@ -749,20 +732,17 @@ def binary_operator_extractor( self, expr: BinaryExpression, expr_to_alias: Dict[str, str], - config_context: ConfigContext, parse_local_name=False, ) -> str: - if config_context.eliminate_numeric_sql_value_cast_enabled: + if 
self.session.eliminate_numeric_sql_value_cast_enabled: left_sql_expr = self.to_sql_try_avoid_cast( expr.left, expr_to_alias, - config_context, parse_local_name, ) right_sql_expr = self.to_sql_try_avoid_cast( expr.right, expr_to_alias, - config_context, parse_local_name, ) else: @@ -770,13 +750,11 @@ def binary_operator_extractor( expr.left, expr_to_alias, parse_local_name, - config_context=config_context, ) right_sql_expr = self.analyze( expr.right, expr_to_alias, parse_local_name, - config_context=config_context, ) operator = expr.sql_operator.lower() @@ -800,7 +778,6 @@ def grouping_extractor( self, expr: GroupingSet, expr_to_alias: Dict[str, str], - config_context: ConfigContext, ) -> str: return self.analyze( FunctionExpression( @@ -809,7 +786,6 @@ def grouping_extractor( False, ), expr_to_alias, - config_context=config_context, ) def window_frame_boundary(self, offset: str) -> str: @@ -823,7 +799,6 @@ def to_sql_try_avoid_cast( self, expr: Expression, expr_to_alias: Dict[str, str], - config_context: ConfigContext, parse_local_name=False, ) -> str: # if expression is a numeric literal, return the number without casting, @@ -832,23 +807,24 @@ def to_sql_try_avoid_cast( return numeric_to_sql_without_cast(expr.value, expr.datatype) else: return self.analyze( - expr, expr_to_alias, parse_local_name, config_context=config_context + expr, + expr_to_alias, + parse_local_name, ) def resolve( self, logical_plan: LogicalPlan, expr_to_alias: Optional[Dict[str, str]] = None, - config_context: Optional[ConfigContext] = None, ) -> MockExecutionPlan: self.subquery_plans = [] if expr_to_alias is None: expr_to_alias = {} - if config_context is None: - config_context = ConfigContext(self.session) - - result = self.do_resolve(logical_plan, expr_to_alias, config_context) + result = self.do_resolve( + logical_plan, + expr_to_alias, + ) return result @@ -856,13 +832,15 @@ def do_resolve( self, logical_plan: LogicalPlan, expr_to_alias: Dict[str, str], - config_context: ConfigContext, ) -> MockExecutionPlan: resolved_children = {} expr_to_alias_maps = {} for c in logical_plan.children: _expr_to_alias = {} - resolved_children[c] = self.resolve(c, _expr_to_alias, config_context) + resolved_children[c] = self.resolve( + c, + _expr_to_alias, + ) expr_to_alias_maps[c] = _expr_to_alias # get counts of expr_to_alias keys @@ -879,7 +857,6 @@ def do_resolve( logical_plan, resolved_children, expr_to_alias, - config_context=config_context, ) def do_resolve_with_resolved_children( @@ -887,7 +864,6 @@ def do_resolve_with_resolved_children( logical_plan: LogicalPlan, resolved_children: Dict[LogicalPlan, SnowflakePlan], expr_to_alias: Dict[str, str], - config_context: ConfigContext, ) -> MockExecutionPlan: if isinstance(logical_plan, MockExecutionPlan): return logical_plan @@ -976,10 +952,12 @@ def do_resolve_with_resolved_children( ) and isinstance(logical_plan.child.source_plan, Sort) return self.plan_builder.limit( self.to_sql_try_avoid_cast( - logical_plan.limit_expr, expr_to_alias, config_context + logical_plan.limit_expr, + expr_to_alias, ), self.to_sql_try_avoid_cast( - logical_plan.offset_expr, expr_to_alias, config_context + logical_plan.offset_expr, + expr_to_alias, ), resolved_children[logical_plan.child], on_top_of_order_by, @@ -1012,7 +990,6 @@ def do_resolve_with_resolved_children( partition_by=self.analyze( logical_plan.partition_by, expr_to_alias, - config_context=config_context, ) if logical_plan.partition_by else None, diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py 
index 3ef92a63f39..ea3ae1730e9 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -11,7 +11,6 @@ import pytest -from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark.session import Session from snowflake.snowpark.types import IntegerType @@ -503,52 +502,3 @@ def finish(self): with ThreadPoolExecutor(max_workers=10) as executor: for i in range(10): executor.submit(register_and_test_udaf, session, i) - - -def test_config_context(session): - try: - original_cte_optimization = session.cte_optimization_enabled - original_eliminate_numeric_sql_value_cast_enabled = ( - session.eliminate_numeric_sql_value_cast_enabled - ) - original_query_compilation_stage_enabled = ( - session._query_compilation_stage_enabled - ) - config_context = ConfigContext(session) - - # Check if all context configs are present in the session - for name in config_context.configs: - assert hasattr(session, name) - - # change session configs - session.cte_optimization_enabled = not original_cte_optimization - session.eliminate_numeric_sql_value_cast_enabled = ( - not original_eliminate_numeric_sql_value_cast_enabled - ) - session._query_compilation_stage_enabled = ( - not original_query_compilation_stage_enabled - ) - - # assert we read original config values - assert config_context.cte_optimization_enabled == original_cte_optimization - assert ( - config_context.eliminate_numeric_sql_value_cast_enabled - == original_eliminate_numeric_sql_value_cast_enabled - ) - assert ( - config_context._query_compilation_stage_enabled - == original_query_compilation_stage_enabled - ) - - with pytest.raises( - AttributeError, match="ConfigContext has no attribute no_such_config" - ): - config_context.no_such_config - finally: - session.cte_optimization_enabled = original_cte_optimization - session.eliminate_numeric_sql_value_cast_enabled = ( - original_eliminate_numeric_sql_value_cast_enabled - ) - session._query_compilation_stage_enabled = ( - original_query_compilation_stage_enabled - ) diff --git a/tests/unit/compiler/test_replace_child_and_update_node.py b/tests/unit/compiler/test_replace_child_and_update_node.py index 558e47530c3..05098165a1b 100644 --- a/tests/unit/compiler/test_replace_child_and_update_node.py +++ b/tests/unit/compiler/test_replace_child_and_update_node.py @@ -8,7 +8,6 @@ import pytest from snowflake.snowpark._internal.analyzer.binary_plan_node import Inner, Join, Union -from snowflake.snowpark._internal.analyzer.config_context import ConfigContext from snowflake.snowpark._internal.analyzer.select_statement import ( Selectable, SelectableEntity, @@ -71,7 +70,7 @@ def mock_snowflake_plan() -> SnowflakePlan: @pytest.fixture(scope="function") def mock_query_generator(mock_session) -> QueryGenerator: - def mock_resolve(x, y): + def mock_resolve(x): snowflake_plan = mock_snowflake_plan() snowflake_plan.source_plan = x if hasattr(x, "post_actions"): @@ -81,9 +80,6 @@ def mock_resolve(x, y): fake_query_generator = mock.create_autospec(QueryGenerator) fake_query_generator.resolve.side_effect = mock_resolve fake_query_generator.session = mock_session - fake_config_context = mock.create_autospec(ConfigContext) - fake_config_context._query_compilation_stage_enabled = False - fake_query_generator.config_context = fake_config_context return fake_query_generator diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index c6c0b0cb508..986927b65e4 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -48,7 +48,7 @@ def 
mock_snowflake_plan(mock_query) -> Analyzer: @pytest.fixture(scope="module") def mock_analyzer(mock_snowflake_plan) -> Analyzer: - def mock_resolve(x, y=None): + def mock_resolve(x): mock_snowflake_plan.source_plan = x return mock_snowflake_plan From 32707f9eadc9f39c5499688315caf1288809619c Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Mon, 30 Sep 2024 10:43:31 -0700 Subject: [PATCH 51/62] min-diff --- .../snowpark/_internal/analyzer/analyzer.py | 132 +++-------- src/snowflake/snowpark/mock/_analyzer.py | 215 +++--------------- 2 files changed, 70 insertions(+), 277 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 7ca2b883ef0..277eadb4c91 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -178,9 +178,7 @@ def analyze( [ [ self.analyze( - a, - df_aliased_col_name_to_real_col_name, - parse_local_name, + a, df_aliased_col_name_to_real_col_name, parse_local_name ) for a in arg ] @@ -191,28 +189,20 @@ def analyze( if isinstance(expr, Like): return like_expression( self.analyze( - expr.expr, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), self.analyze( - expr.pattern, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.pattern, df_aliased_col_name_to_real_col_name, parse_local_name ), ) if isinstance(expr, RegExp): return regexp_expression( self.analyze( - expr.expr, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), self.analyze( - expr.pattern, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.pattern, df_aliased_col_name_to_real_col_name, parse_local_name ), self.analyze( expr.parameters, @@ -229,9 +219,7 @@ def analyze( ) return collate_expression( self.analyze( - expr.expr, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), collation_spec, ) @@ -242,9 +230,7 @@ def analyze( field = field.upper() return subfield_expression( self.analyze( - expr.expr, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), field, ) @@ -313,9 +299,7 @@ def analyze( in_values.append(in_value) return in_expression( self.analyze( - expr.columns, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.columns, df_aliased_col_name_to_real_col_name, parse_local_name ), in_values, ) @@ -340,17 +324,13 @@ def analyze( return window_spec_expression( [ self.analyze( - x, - df_aliased_col_name_to_real_col_name, - parse_local_name, + x, df_aliased_col_name_to_real_col_name, parse_local_name ) for x in expr.partition_spec ], [ self.analyze( - x, - df_aliased_col_name_to_real_col_name, - parse_local_name, + x, df_aliased_col_name_to_real_col_name, parse_local_name ) for x in expr.order_spec ], @@ -453,9 +433,7 @@ def analyze( func_name, [ self.analyze( - x, - df_aliased_col_name_to_real_col_name, - parse_local_name, + x, df_aliased_col_name_to_real_col_name, parse_local_name ) for x in expr.children ], @@ -476,9 +454,7 @@ def analyze( expr.over, [ self.analyze( - x, - df_aliased_col_name_to_real_col_name, - parse_local_name, + x, df_aliased_col_name_to_real_col_name, parse_local_name ) for x in expr.partition_spec ] @@ -486,9 +462,7 @@ def analyze( else [], [ self.analyze( - x, - 
df_aliased_col_name_to_real_col_name, - parse_local_name, + x, df_aliased_col_name_to_real_col_name, parse_local_name ) for x in expr.order_spec ] @@ -498,17 +472,13 @@ def analyze( if isinstance(expr, UnaryExpression): return self.unary_expression_extractor( - expr, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr, df_aliased_col_name_to_real_col_name, parse_local_name ) if isinstance(expr, SortOrder): return order_expression( self.analyze( - expr.child, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.child, df_aliased_col_name_to_real_col_name, parse_local_name ), expr.direction.sql, expr.null_ordering.sql, @@ -521,9 +491,7 @@ def analyze( if isinstance(expr, WithinGroup): return within_group_expression( self.analyze( - expr.expr, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), [ self.analyze( @@ -536,9 +504,7 @@ def analyze( if isinstance(expr, BinaryExpression): return self.binary_operator_extractor( - expr, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr, df_aliased_col_name_to_real_col_name, parse_local_name ) if isinstance(expr, InsertMergeExpression): @@ -598,9 +564,7 @@ def analyze( if isinstance(expr, ListAgg): return list_agg( self.analyze( - expr.col, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.col, df_aliased_col_name_to_real_col_name, parse_local_name ), str_to_sql(expr.delimiter), expr.is_distinct, @@ -610,9 +574,7 @@ def analyze( return column_sum( [ self.analyze( - col, - df_aliased_col_name_to_real_col_name, - parse_local_name, + col, df_aliased_col_name_to_real_col_name, parse_local_name ) for col in expr.exprs ] @@ -622,15 +584,11 @@ def analyze( return rank_related_function_expression( expr.sql, self.analyze( - expr.expr, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), expr.offset, self.analyze( - expr.default, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.default, df_aliased_col_name_to_real_col_name, parse_local_name ) if expr.default else None, @@ -650,9 +608,7 @@ def table_function_expression_extractor( if isinstance(expr, FlattenFunction): return flatten_expression( self.analyze( - expr.input, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.input, df_aliased_col_name_to_real_col_name, parse_local_name ), expr.path, expr.outer, @@ -664,9 +620,7 @@ def table_function_expression_extractor( expr.func_name, [ self.analyze( - x, - df_aliased_col_name_to_real_col_name, - parse_local_name, + x, df_aliased_col_name_to_real_col_name, parse_local_name ) for x in expr.args ], @@ -677,9 +631,7 @@ def table_function_expression_extractor( expr.func_name, { key: self.to_sql_try_avoid_cast( - value, - df_aliased_col_name_to_real_col_name, - parse_local_name, + value, df_aliased_col_name_to_real_col_name, parse_local_name ) for key, value in expr.args.items() }, @@ -720,17 +672,13 @@ def unary_expression_extractor( df_alias_dict[k] = quoted_name return alias_expression( self.analyze( - expr.child, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.child, df_aliased_col_name_to_real_col_name, parse_local_name ), quoted_name, ) if isinstance(expr, UnresolvedAlias): expr_str = self.analyze( - expr.child, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.child, df_aliased_col_name_to_real_col_name, parse_local_name ) if parse_local_name: expr_str = expr_str.upper() @@ 
-738,9 +686,7 @@ def unary_expression_extractor( elif isinstance(expr, Cast): return cast_expression( self.analyze( - expr.child, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.child, df_aliased_col_name_to_real_col_name, parse_local_name ), expr.to, expr.try_, @@ -748,9 +694,7 @@ def unary_expression_extractor( else: return unary_expression( self.analyze( - expr.child, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.child, df_aliased_col_name_to_real_col_name, parse_local_name ), expr.sql_operator, expr.operator_first, @@ -764,25 +708,17 @@ def binary_operator_extractor( ) -> str: if self.session.eliminate_numeric_sql_value_cast_enabled: left_sql_expr = self.to_sql_try_avoid_cast( - expr.left, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.left, df_aliased_col_name_to_real_col_name, parse_local_name ) right_sql_expr = self.to_sql_try_avoid_cast( - expr.right, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.right, df_aliased_col_name_to_real_col_name, parse_local_name ) else: left_sql_expr = self.analyze( - expr.left, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.left, df_aliased_col_name_to_real_col_name, parse_local_name ) right_sql_expr = self.analyze( - expr.right, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr.right, df_aliased_col_name_to_real_col_name, parse_local_name ) if isinstance(expr, BinaryArithmeticExpression): return binary_arithmetic_expression( @@ -843,9 +779,7 @@ def to_sql_try_avoid_cast( return str(expr.value).upper() else: return self.analyze( - expr, - df_aliased_col_name_to_real_col_name, - parse_local_name, + expr, df_aliased_col_name_to_real_col_name, parse_local_name ) def resolve( diff --git a/src/snowflake/snowpark/mock/_analyzer.py b/src/snowflake/snowpark/mock/_analyzer.py index df4ec62d4a8..e4cd33caf15 100644 --- a/src/snowflake/snowpark/mock/_analyzer.py +++ b/src/snowflake/snowpark/mock/_analyzer.py @@ -180,35 +180,15 @@ def analyze( if isinstance(expr, Like): return like_expression( - self.analyze( - expr.expr, - expr_to_alias, - parse_local_name, - ), - self.analyze( - expr.pattern, - expr_to_alias, - parse_local_name, - ), + self.analyze(expr.expr, expr_to_alias, parse_local_name), + self.analyze(expr.pattern, expr_to_alias, parse_local_name), ) if isinstance(expr, RegExp): return regexp_expression( - self.analyze( - expr.expr, - expr_to_alias, - parse_local_name, - ), - self.analyze( - expr.pattern, - expr_to_alias, - parse_local_name, - ), - self.analyze( - expr.parameters, - expr_to_alias, - parse_local_name, - ) + self.analyze(expr.expr, expr_to_alias, parse_local_name), + self.analyze(expr.pattern, expr_to_alias, parse_local_name), + self.analyze(expr.parameters, expr_to_alias, parse_local_name) if expr.parameters is not None else None, ) @@ -218,11 +198,7 @@ def analyze( expr.collation_spec.upper() if parse_local_name else expr.collation_spec ) return collate_expression( - self.analyze( - expr.expr, - expr_to_alias, - parse_local_name, - ), + self.analyze(expr.expr, expr_to_alias, parse_local_name), collation_spec, ) @@ -231,11 +207,7 @@ def analyze( if parse_local_name and isinstance(field, str): field = field.upper() return subfield_expression( - self.analyze( - expr.expr, - expr_to_alias, - parse_local_name, - ), + self.analyze(expr.expr, expr_to_alias, parse_local_name), field, ) @@ -243,24 +215,12 @@ def analyze( return case_when_expression( [ ( - self.analyze( - condition, - expr_to_alias, - parse_local_name, - ), - 
self.analyze( - value, - expr_to_alias, - parse_local_name, - ), + self.analyze(condition, expr_to_alias, parse_local_name), + self.analyze(value, expr_to_alias, parse_local_name), ) for condition, value in expr.branches ], - self.analyze( - expr.else_value, - expr_to_alias, - parse_local_name, - ) + self.analyze(expr.else_value, expr_to_alias, parse_local_name) if expr.else_value else "NULL", ) @@ -270,15 +230,11 @@ def analyze( for expression in expr.expressions: if self.session.eliminate_numeric_sql_value_cast_enabled: resolved_expr = self.to_sql_try_avoid_cast( - expression, - expr_to_alias, - parse_local_name, + expression, expr_to_alias, parse_local_name ) else: resolved_expr = self.analyze( - expression, - expr_to_alias, - parse_local_name, + expression, expr_to_alias, parse_local_name ) block_expressions.append(resolved_expr) @@ -289,24 +245,14 @@ def analyze( for expression in expr.values: if self.session.eliminate_numeric_sql_value_cast_enabled: in_value = self.to_sql_try_avoid_cast( - expression, - expr_to_alias, - parse_local_name, + expression, expr_to_alias, parse_local_name ) else: - in_value = self.analyze( - expression, - expr_to_alias, - parse_local_name, - ) + in_value = self.analyze(expression, expr_to_alias, parse_local_name) in_values.append(in_value) return in_expression( - self.analyze( - expr.columns, - expr_to_alias, - parse_local_name, - ), + self.analyze(expr.columns, expr_to_alias, parse_local_name), in_values, ) @@ -430,11 +376,7 @@ def analyze( return function_expression( func_name, [ - self.analyze( - x, - expr_to_alias, - parse_local_name, - ) + self.analyze(x, expr_to_alias, parse_local_name) for x in expr.children ], False, @@ -450,21 +392,13 @@ def analyze( return table_function_partition_spec( expr.over, [ - self.analyze( - x, - expr_to_alias, - parse_local_name, - ) + self.analyze(x, expr_to_alias, parse_local_name) for x in expr.partition_spec ] if expr.partition_spec else [], [ - self.analyze( - x, - expr_to_alias, - parse_local_name, - ) + self.analyze(x, expr_to_alias, parse_local_name) for x in expr.order_spec ] if expr.order_spec @@ -481,11 +415,7 @@ def analyze( if isinstance(expr, SortOrder): return order_expression( - self.analyze( - expr.child, - expr_to_alias, - parse_local_name, - ), + self.analyze(expr.child, expr_to_alias, parse_local_name), expr.direction.sql, expr.null_ordering.sql, ) @@ -496,11 +426,7 @@ def analyze( if isinstance(expr, WithinGroup): return within_group_expression( - self.analyze( - expr.expr, - expr_to_alias, - parse_local_name, - ), + self.analyze(expr.expr, expr_to_alias, parse_local_name), [ self.analyze( e, @@ -511,11 +437,7 @@ def analyze( ) if isinstance(expr, BinaryExpression): - return self.binary_operator_extractor( - expr, - expr_to_alias, - parse_local_name, - ) + return self.binary_operator_extractor(expr, expr_to_alias, parse_local_name) if isinstance(expr, InsertMergeExpression): return insert_merge_statement( @@ -570,11 +492,7 @@ def analyze( if isinstance(expr, ListAgg): return list_agg( - self.analyze( - expr.col, - expr_to_alias, - parse_local_name, - ), + self.analyze(expr.col, expr_to_alias, parse_local_name), str_to_sql(expr.delimiter), expr.is_distinct, ) @@ -582,11 +500,7 @@ def analyze( if isinstance(expr, ColumnSum): return column_sum( [ - self.analyze( - col, - expr_to_alias, - parse_local_name, - ) + self.analyze(col, expr_to_alias, parse_local_name) for col in expr.exprs ] ) @@ -594,17 +508,9 @@ def analyze( if isinstance(expr, RankRelatedFunctionExpression): return 
rank_related_function_expression( expr.sql, - self.analyze( - expr.expr, - expr_to_alias, - parse_local_name, - ), + self.analyze(expr.expr, expr_to_alias, parse_local_name), expr.offset, - self.analyze( - expr.default, - expr_to_alias, - parse_local_name, - ) + self.analyze(expr.default, expr_to_alias, parse_local_name) if expr.default else None, expr.ignore_nulls, @@ -622,11 +528,7 @@ def table_function_expression_extractor( ) -> str: if isinstance(expr, FlattenFunction): return flatten_expression( - self.analyze( - expr.input, - expr_to_alias, - parse_local_name, - ), + self.analyze(expr.input, expr_to_alias, parse_local_name), expr.path, expr.outer, expr.recursive, @@ -635,25 +537,14 @@ def table_function_expression_extractor( elif isinstance(expr, PosArgumentsTableFunction): sql = function_expression( expr.func_name, - [ - self.analyze( - x, - expr_to_alias, - parse_local_name, - ) - for x in expr.args - ], + [self.analyze(x, expr_to_alias, parse_local_name) for x in expr.args], False, ) elif isinstance(expr, (NamedArgumentsTableFunction, GeneratorTableFunction)): sql = named_arguments_function( expr.func_name, { - key: self.analyze( - value, - expr_to_alias, - parse_local_name, - ) + key: self.analyze(value, expr_to_alias, parse_local_name) for key, value in expr.args.items() }, ) @@ -687,11 +578,7 @@ def unary_expression_extractor( if v == expr.child.name: expr_to_alias[k] = quoted_name alias_exp = alias_expression( - self.analyze( - expr.child, - expr_to_alias, - parse_local_name, - ), + self.analyze(expr.child, expr_to_alias, parse_local_name), quoted_name, ) @@ -699,31 +586,19 @@ def unary_expression_extractor( expr_str = expr_str.upper() if parse_local_name else expr_str return expr_str if isinstance(expr, UnresolvedAlias): - expr_str = self.analyze( - expr.child, - expr_to_alias, - parse_local_name, - ) + expr_str = self.analyze(expr.child, expr_to_alias, parse_local_name) if parse_local_name: expr_str = expr_str.upper() return quote_name(expr_str.strip()) elif isinstance(expr, Cast): return cast_expression( - self.analyze( - expr.child, - expr_to_alias, - parse_local_name, - ), + self.analyze(expr.child, expr_to_alias, parse_local_name), expr.to, expr.try_, ) else: return unary_expression( - self.analyze( - expr.child, - expr_to_alias, - parse_local_name, - ), + self.analyze(expr.child, expr_to_alias, parse_local_name), expr.sql_operator, expr.operator_first, ) @@ -736,26 +611,14 @@ def binary_operator_extractor( ) -> str: if self.session.eliminate_numeric_sql_value_cast_enabled: left_sql_expr = self.to_sql_try_avoid_cast( - expr.left, - expr_to_alias, - parse_local_name, + expr.left, expr_to_alias, parse_local_name ) right_sql_expr = self.to_sql_try_avoid_cast( - expr.right, - expr_to_alias, - parse_local_name, + expr.right, expr_to_alias, parse_local_name ) else: - left_sql_expr = self.analyze( - expr.left, - expr_to_alias, - parse_local_name, - ) - right_sql_expr = self.analyze( - expr.right, - expr_to_alias, - parse_local_name, - ) + left_sql_expr = self.analyze(expr.left, expr_to_alias, parse_local_name) + right_sql_expr = self.analyze(expr.right, expr_to_alias, parse_local_name) operator = expr.sql_operator.lower() if isinstance(expr, BinaryArithmeticExpression): @@ -806,11 +669,7 @@ def to_sql_try_avoid_cast( if isinstance(expr, Literal) and isinstance(expr.datatype, _NumericType): return numeric_to_sql_without_cast(expr.value, expr.datatype) else: - return self.analyze( - expr, - expr_to_alias, - parse_local_name, - ) + return self.analyze(expr, expr_to_alias, 
parse_local_name) def resolve( self, From 3bf678deb0438fb71f250e49a7f9a962fe340d7b Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Mon, 30 Sep 2024 10:54:07 -0700 Subject: [PATCH 52/62] min-diff --- .../snowpark/_internal/analyzer/analyzer.py | 183 +++++------------- .../_internal/analyzer/snowflake_plan.py | 10 +- .../_internal/compiler/plan_compiler.py | 8 +- .../_internal/compiler/query_generator.py | 11 +- .../snowpark/_internal/compiler/utils.py | 2 +- src/snowflake/snowpark/mock/_analyzer.py | 181 +++++------------ 6 files changed, 108 insertions(+), 287 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 277eadb4c91..d8622299ea9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -415,10 +415,7 @@ def analyze( # This case is hit by df.col("*") return ",".join( [ - self.analyze( - e, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(e, df_aliased_col_name_to_real_col_name) for e in expr.expressions ] ) @@ -494,10 +491,7 @@ def analyze( expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name ), [ - self.analyze( - e, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(e, df_aliased_col_name_to_real_col_name) for e in expr.order_by_cols ], ) @@ -509,43 +503,27 @@ def analyze( if isinstance(expr, InsertMergeExpression): return insert_merge_statement( - self.analyze( - expr.condition, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(expr.condition, df_aliased_col_name_to_real_col_name) if expr.condition else None, [ - self.analyze( - k, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(k, df_aliased_col_name_to_real_col_name) for k in expr.keys ], [ - self.analyze( - v, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(v, df_aliased_col_name_to_real_col_name) for v in expr.values ], ) if isinstance(expr, UpdateMergeExpression): return update_merge_statement( - self.analyze( - expr.condition, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(expr.condition, df_aliased_col_name_to_real_col_name) if expr.condition else None, { - self.analyze( - k, - df_aliased_col_name_to_real_col_name, - ): self.analyze( - v, - df_aliased_col_name_to_real_col_name, + self.analyze(k, df_aliased_col_name_to_real_col_name): self.analyze( + v, df_aliased_col_name_to_real_col_name ) for k, v in expr.assignments.items() }, @@ -553,10 +531,7 @@ def analyze( if isinstance(expr, DeleteMergeExpression): return delete_merge_statement( - self.analyze( - expr.condition, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(expr.condition, df_aliased_col_name_to_real_col_name) if expr.condition else None ) @@ -642,10 +617,7 @@ def table_function_expression_extractor( "NamedArgumentsTableFunction, GeneratorTableFunction, or FlattenFunction." 
) partition_spec_sql = ( - self.analyze( - expr.partition_spec, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(expr.partition_spec, df_aliased_col_name_to_real_col_name) if expr.partition_spec else "" ) @@ -711,7 +683,9 @@ def binary_operator_extractor( expr.left, df_aliased_col_name_to_real_col_name, parse_local_name ) right_sql_expr = self.to_sql_try_avoid_cast( - expr.right, df_aliased_col_name_to_real_col_name, parse_local_name + expr.right, + df_aliased_col_name_to_real_col_name, + parse_local_name, ) else: left_sql_expr = self.analyze( @@ -782,16 +756,11 @@ def to_sql_try_avoid_cast( expr, df_aliased_col_name_to_real_col_name, parse_local_name ) - def resolve( - self, - logical_plan: LogicalPlan, - ) -> SnowflakePlan: + def resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: self.subquery_plans = [] self.generated_alias_maps = {} - result = self.do_resolve( - logical_plan, - ) + result = self.do_resolve(logical_plan) result.add_aliases(self.generated_alias_maps) @@ -800,19 +769,14 @@ def resolve( return result - def do_resolve( - self, - logical_plan: LogicalPlan, - ) -> SnowflakePlan: + def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: resolved_children = {} df_aliased_col_name_to_real_col_name: DefaultDict[ str, Dict[str, str] ] = defaultdict(dict) for c in logical_plan.children: # post-order traversal of the tree - resolved = self.resolve( - c, - ) + resolved = self.resolve(c) df_aliased_col_name_to_real_col_name.update( resolved.df_aliased_col_name_to_real_col_name ) @@ -840,9 +804,7 @@ def do_resolve( self.alias_maps_to_use = use_maps res = self.do_resolve_with_resolved_children( - logical_plan, - resolved_children, - df_aliased_col_name_to_real_col_name, + logical_plan, resolved_children, df_aliased_col_name_to_real_col_name ) res.df_aliased_col_name_to_real_col_name.update( df_aliased_col_name_to_real_col_name @@ -861,8 +823,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, TableFunctionJoin): return self.plan_builder.join_table_function( self.analyze( - logical_plan.table_function, - df_aliased_col_name_to_real_col_name, + logical_plan.table_function, df_aliased_col_name_to_real_col_name ), resolved_children[logical_plan.children[0]], logical_plan, @@ -874,8 +835,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, TableFunctionRelation): return self.plan_builder.from_table_function( self.analyze( - logical_plan.table_function, - df_aliased_col_name_to_real_col_name, + logical_plan.table_function, df_aliased_col_name_to_real_col_name ), logical_plan, ) @@ -883,8 +843,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Lateral): return self.plan_builder.lateral( self.analyze( - logical_plan.table_function, - df_aliased_col_name_to_real_col_name, + logical_plan.table_function, df_aliased_col_name_to_real_col_name ), resolved_children[logical_plan.children[0]], logical_plan, @@ -894,16 +853,12 @@ def do_resolve_with_resolved_children( return self.plan_builder.aggregate( [ self.to_sql_try_avoid_cast( - expr, - df_aliased_col_name_to_real_col_name, + expr, df_aliased_col_name_to_real_col_name ) for expr in logical_plan.grouping_expressions ], [ - self.analyze( - expr, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(expr, df_aliased_col_name_to_real_col_name) for expr in logical_plan.aggregate_expressions ], resolved_children[logical_plan.child], @@ -914,10 +869,7 @@ def do_resolve_with_resolved_children( return self.plan_builder.project( list( map( - lambda x: self.analyze( - 
x, - df_aliased_col_name_to_real_col_name, - ), + lambda x: self.analyze(x, df_aliased_col_name_to_real_col_name), logical_plan.project_list, ) ), @@ -928,8 +880,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Filter): return self.plan_builder.filter( self.analyze( - logical_plan.condition, - df_aliased_col_name_to_real_col_name, + logical_plan.condition, df_aliased_col_name_to_real_col_name ), resolved_children[logical_plan.child], logical_plan, @@ -947,16 +898,14 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Join): join_condition = ( self.analyze( - logical_plan.join_condition, - df_aliased_col_name_to_real_col_name, + logical_plan.join_condition, df_aliased_col_name_to_real_col_name ) if logical_plan.join_condition else "" ) match_condition = ( self.analyze( - logical_plan.match_condition, - df_aliased_col_name_to_real_col_name, + logical_plan.match_condition, df_aliased_col_name_to_real_col_name ) if logical_plan.match_condition else "" @@ -974,10 +923,7 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan, Sort): return self.plan_builder.sort( [ - self.analyze( - x, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(x, df_aliased_col_name_to_real_col_name) for x in logical_plan.order ], resolved_children[logical_plan.child], @@ -1041,10 +987,7 @@ def do_resolve_with_resolved_children( mode=logical_plan.mode, table_type=logical_plan.table_type, clustering_keys=[ - self.analyze( - x, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(x, df_aliased_col_name_to_real_col_name) for x in logical_plan.clustering_exprs ], comment=logical_plan.comment, @@ -1067,12 +1010,10 @@ def do_resolve_with_resolved_children( ) and isinstance(logical_plan.child.source_plan, Sort) return self.plan_builder.limit( self.to_sql_try_avoid_cast( - logical_plan.limit_expr, - df_aliased_col_name_to_real_col_name, + logical_plan.limit_expr, df_aliased_col_name_to_real_col_name ), self.to_sql_try_avoid_cast( - logical_plan.offset_expr, - df_aliased_col_name_to_real_col_name, + logical_plan.offset_expr, df_aliased_col_name_to_real_col_name ), resolved_children[logical_plan.child], on_top_of_order_by, @@ -1100,10 +1041,7 @@ def do_resolve_with_resolved_children( ] child = self.plan_builder.project( [ - self.analyze( - col, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(col, df_aliased_col_name_to_real_col_name) for col in project_exprs ], resolved_children[logical_plan.child], @@ -1119,33 +1057,26 @@ def do_resolve_with_resolved_children( if isinstance(logical_plan.pivot_values, List): pivot_values = [ - self.analyze( - pv, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(pv, df_aliased_col_name_to_real_col_name) for pv in logical_plan.pivot_values ] elif isinstance(logical_plan.pivot_values, ScalarSubquery): pivot_values = self.analyze( - logical_plan.pivot_values, - df_aliased_col_name_to_real_col_name, + logical_plan.pivot_values, df_aliased_col_name_to_real_col_name ) else: pivot_values = None pivot_plan = self.plan_builder.pivot( self.analyze( - logical_plan.pivot_column, - df_aliased_col_name_to_real_col_name, + logical_plan.pivot_column, df_aliased_col_name_to_real_col_name ), pivot_values, self.analyze( - logical_plan.aggregates[0], - df_aliased_col_name_to_real_col_name, + logical_plan.aggregates[0], df_aliased_col_name_to_real_col_name ), self.analyze( - logical_plan.default_on_null, - df_aliased_col_name_to_real_col_name, + logical_plan.default_on_null, df_aliased_col_name_to_real_col_name ) if 
logical_plan.default_on_null else None, @@ -1170,10 +1101,7 @@ def do_resolve_with_resolved_children( logical_plan.value_column, logical_plan.name_column, [ - self.analyze( - c, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(c, df_aliased_col_name_to_real_col_name) for c in logical_plan.column_list ], resolved_children[logical_plan.child], @@ -1215,10 +1143,7 @@ def do_resolve_with_resolved_children( refresh_mode=logical_plan.refresh_mode, initialize=logical_plan.initialize, clustering_keys=[ - self.analyze( - x, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(x, df_aliased_col_name_to_real_col_name) for x in logical_plan.clustering_exprs ], is_transient=logical_plan.is_transient, @@ -1250,10 +1175,7 @@ def do_resolve_with_resolved_children( validation_mode=logical_plan.validation_mode, column_names=logical_plan.column_names, transformations=[ - self.analyze( - x, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(x, df_aliased_col_name_to_real_col_name) for x in logical_plan.transformations ] if logical_plan.transformations @@ -1269,8 +1191,7 @@ def do_resolve_with_resolved_children( stage_location=logical_plan.stage_location, source_plan=logical_plan, partition_by=self.analyze( - logical_plan.partition_by, - df_aliased_col_name_to_real_col_name, + logical_plan.partition_by, df_aliased_col_name_to_real_col_name ) if logical_plan.partition_by else None, @@ -1285,18 +1206,13 @@ def do_resolve_with_resolved_children( return self.plan_builder.update( logical_plan.table_name, { - self.analyze( - k, - df_aliased_col_name_to_real_col_name, - ): self.analyze( - v, - df_aliased_col_name_to_real_col_name, + self.analyze(k, df_aliased_col_name_to_real_col_name): self.analyze( + v, df_aliased_col_name_to_real_col_name ) for k, v in logical_plan.assignments.items() }, self.analyze( - logical_plan.condition, - df_aliased_col_name_to_real_col_name, + logical_plan.condition, df_aliased_col_name_to_real_col_name ) if logical_plan.condition else None, @@ -1310,8 +1226,7 @@ def do_resolve_with_resolved_children( return self.plan_builder.delete( logical_plan.table_name, self.analyze( - logical_plan.condition, - df_aliased_col_name_to_real_col_name, + logical_plan.condition, df_aliased_col_name_to_real_col_name ) if logical_plan.condition else None, @@ -1329,14 +1244,10 @@ def do_resolve_with_resolved_children( if logical_plan.source else logical_plan.source, self.analyze( - logical_plan.join_expr, - df_aliased_col_name_to_real_col_name, + logical_plan.join_expr, df_aliased_col_name_to_real_col_name ), [ - self.analyze( - c, - df_aliased_col_name_to_real_col_name, - ) + self.analyze(c, df_aliased_col_name_to_real_col_name) for c in logical_plan.clauses ], logical_plan, diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index bbda752e33d..9750deba3f9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -312,15 +312,13 @@ def children_plan_nodes(self) -> List[Union["Selectable", "SnowflakePlan"]]: else: return [] - def replace_repeated_subquery_with_cte( - self, - ) -> "SnowflakePlan": + def replace_repeated_subquery_with_cte(self) -> "SnowflakePlan": # parameter protection # the common subquery elimination will be applied if cte_optimization is not enabled # and the new compilation stage is not enabled. 
When new compilation stage is enabled, # the common subquery elimination will be done through the new plan transformation. if ( - not self.session.cte_optimization_enabled + not self.session._cte_optimization_enabled or self.session._query_compilation_stage_enabled ): return self @@ -568,7 +566,7 @@ def build( placeholder_query = ( sql_generator(select_child._id) - if self.session.cte_optimization_enabled and select_child._id is not None + if self.session._cte_optimization_enabled and select_child._id is not None else None ) @@ -607,7 +605,7 @@ def build_binary( placeholder_query = ( sql_generator(select_left._id, select_right._id) - if self.session.cte_optimization_enabled + if self.session._cte_optimization_enabled and select_left._id is not None and select_right._id is not None else None diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 38a1bf3acd8..e8ce38f0edb 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -65,16 +65,16 @@ def should_start_query_compilation(self) -> bool: return ( not isinstance(current_session._conn, MockServerConnection) and (self._plan.source_plan is not None) - and self.current_session._query_compilation_stage_enabled + and current_session._query_compilation_stage_enabled and ( - self.current_session.cte_optimization_enabled - or self.current_session.large_query_breakdown_enabled + current_session.cte_optimization_enabled + or current_session.large_query_breakdown_enabled ) ) def compile(self) -> Dict[PlanQueryType, List[Query]]: - session = self._plan.session if self.should_start_query_compilation(): + session = self._plan.session # preparation for compilation # 1. 
make a copy of the original plan start_time = time.time() diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index 187629832e6..2cde864c062 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -105,9 +105,6 @@ def generate_queries( PlanQueryType.POST_ACTIONS: post_actions, } - def resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan: - return super().resolve(logical_plan) - def do_resolve_with_resolved_children( self, logical_plan: LogicalPlan, @@ -207,9 +204,7 @@ def do_resolve_with_resolved_children( copied_resolved_child.queries = final_queries[PlanQueryType.QUERIES] resolved_children[logical_plan.children[0]] = copied_resolved_child resolved_plan = super().do_resolve_with_resolved_children( - logical_plan, - resolved_children, - df_aliased_col_name_to_real_col_name, + logical_plan, resolved_children, df_aliased_col_name_to_real_col_name ) elif isinstance(logical_plan, Selectable): @@ -233,9 +228,7 @@ def do_resolve_with_resolved_children( else: resolved_plan = super().do_resolve_with_resolved_children( - logical_plan, - resolved_children, - df_aliased_col_name_to_real_col_name, + logical_plan, resolved_children, df_aliased_col_name_to_real_col_name ) resolved_plan._is_valid_for_replacement = True diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py index bf52ca89b79..60c2952f303 100644 --- a/src/snowflake/snowpark/_internal/compiler/utils.py +++ b/src/snowflake/snowpark/_internal/compiler/utils.py @@ -40,7 +40,7 @@ def create_query_generator( - plan: SnowflakePlan, + plan: SnowflakePlan ) -> QueryGenerator: """ Helper function to construct the query generator for a given valid SnowflakePlan. 
diff --git a/src/snowflake/snowpark/mock/_analyzer.py b/src/snowflake/snowpark/mock/_analyzer.py index e4cd33caf15..666654917ea 100644 --- a/src/snowflake/snowpark/mock/_analyzer.py +++ b/src/snowflake/snowpark/mock/_analyzer.py @@ -198,8 +198,7 @@ def analyze( expr.collation_spec.upper() if parse_local_name else expr.collation_spec ) return collate_expression( - self.analyze(expr.expr, expr_to_alias, parse_local_name), - collation_spec, + self.analyze(expr.expr, expr_to_alias, parse_local_name), collation_spec ) if isinstance(expr, (SubfieldString, SubfieldInt)): @@ -207,8 +206,7 @@ def analyze( if parse_local_name and isinstance(field, str): field = field.upper() return subfield_expression( - self.analyze(expr.expr, expr_to_alias, parse_local_name), - field, + self.analyze(expr.expr, expr_to_alias, parse_local_name), field ) if isinstance(expr, CaseWhen): @@ -230,11 +228,15 @@ def analyze( for expression in expr.expressions: if self.session.eliminate_numeric_sql_value_cast_enabled: resolved_expr = self.to_sql_try_avoid_cast( - expression, expr_to_alias, parse_local_name + expression, + expr_to_alias, + parse_local_name, ) else: resolved_expr = self.analyze( - expression, expr_to_alias, parse_local_name + expression, + expr_to_alias, + parse_local_name, ) block_expressions.append(resolved_expr) @@ -245,10 +247,16 @@ def analyze( for expression in expr.values: if self.session.eliminate_numeric_sql_value_cast_enabled: in_value = self.to_sql_try_avoid_cast( - expression, expr_to_alias, parse_local_name + expression, + expr_to_alias, + parse_local_name, ) else: - in_value = self.analyze(expression, expr_to_alias, parse_local_name) + in_value = self.analyze( + expression, + expr_to_alias, + parse_local_name, + ) in_values.append(in_value) return in_expression( @@ -277,17 +285,11 @@ def analyze( if isinstance(expr, WindowSpecDefinition): return window_spec_expression( [ - self.analyze( - x, - parse_local_name=parse_local_name, - ) + self.analyze(x, parse_local_name=parse_local_name) for x in expr.partition_spec ], [ - self.analyze( - x, - parse_local_name=parse_local_name, - ) + self.analyze(x, parse_local_name=parse_local_name) for x in expr.order_spec ], self.analyze( @@ -299,18 +301,8 @@ def analyze( if isinstance(expr, SpecifiedWindowFrame): return specified_window_frame_expression( expr.frame_type.sql, - self.window_frame_boundary( - self.to_sql_try_avoid_cast( - expr.lower, - {}, - ) - ), - self.window_frame_boundary( - self.to_sql_try_avoid_cast( - expr.upper, - {}, - ) - ), + self.window_frame_boundary(self.to_sql_try_avoid_cast(expr.lower, {})), + self.window_frame_boundary(self.to_sql_try_avoid_cast(expr.upper, {})), ) if isinstance(expr, UnspecifiedFrame): @@ -340,10 +332,7 @@ def analyze( children = [] for c in expr.children: - extracted = self.to_sql_try_avoid_cast( - c, - expr_to_alias, - ) + extracted = self.to_sql_try_avoid_cast(c, expr_to_alias) if isinstance(extracted, list): children.extend(extracted) else: @@ -359,13 +348,7 @@ def analyze( if not expr.expressions: return "*" else: - return [ - self.analyze( - e, - expr_to_alias, - ) - for e in expr.expressions - ] + return [self.analyze(e, expr_to_alias) for e in expr.expressions] if isinstance(expr, SnowflakeUDF): if expr.api_call_source is not None: @@ -383,10 +366,7 @@ def analyze( ) if isinstance(expr, TableFunctionExpression): - return self.table_function_expression_extractor( - expr, - expr_to_alias, - ) + return self.table_function_expression_extractor(expr, expr_to_alias) if isinstance(expr, 
TableFunctionPartitionSpecDefinition): return table_function_partition_spec( @@ -427,67 +407,35 @@ def analyze( if isinstance(expr, WithinGroup): return within_group_expression( self.analyze(expr.expr, expr_to_alias, parse_local_name), - [ - self.analyze( - e, - expr_to_alias, - ) - for e in expr.order_by_cols - ], + [self.analyze(e, expr_to_alias) for e in expr.order_by_cols], ) if isinstance(expr, BinaryExpression): - return self.binary_operator_extractor(expr, expr_to_alias, parse_local_name) + return self.binary_operator_extractor( + expr, + expr_to_alias, + parse_local_name, + ) if isinstance(expr, InsertMergeExpression): return insert_merge_statement( - self.analyze( - expr.condition, - expr_to_alias, - ) - if expr.condition - else None, - [ - self.analyze( - k, - expr_to_alias, - ) - for k in expr.keys - ], - [ - self.analyze( - v, - expr_to_alias, - ) - for v in expr.values - ], + self.analyze(expr.condition, expr_to_alias) if expr.condition else None, + [self.analyze(k, expr_to_alias) for k in expr.keys], + [self.analyze(v, expr_to_alias) for v in expr.values], ) if isinstance(expr, UpdateMergeExpression): return update_merge_statement( - self.analyze( - expr.condition, - expr_to_alias, - ) - if expr.condition - else None, + self.analyze(expr.condition, expr_to_alias) if expr.condition else None, { - self.analyze(k, expr_to_alias,): self.analyze( - v, - expr_to_alias, - ) + self.analyze(k, expr_to_alias): self.analyze(v, expr_to_alias) for k, v in expr.assignments.items() }, ) if isinstance(expr, DeleteMergeExpression): return delete_merge_statement( - self.analyze( - expr.condition, - expr_to_alias, - ) - if expr.condition - else None + self.analyze(expr.condition, expr_to_alias) if expr.condition else None ) if isinstance(expr, ListAgg): @@ -554,10 +502,7 @@ def table_function_expression_extractor( "NamedArgumentsTableFunction, GeneratorTableFunction, or FlattenFunction." 
) partition_spec_sql = ( - self.analyze( - expr.partition_spec, - expr_to_alias, - ) + self.analyze(expr.partition_spec, expr_to_alias) if expr.partition_spec else "" ) @@ -578,8 +523,7 @@ def unary_expression_extractor( if v == expr.child.name: expr_to_alias[k] = quoted_name alias_exp = alias_expression( - self.analyze(expr.child, expr_to_alias, parse_local_name), - quoted_name, + self.analyze(expr.child, expr_to_alias, parse_local_name), quoted_name ) expr_str = alias_exp if keep_alias else expr.name or keep_alias @@ -614,7 +558,9 @@ def binary_operator_extractor( expr.left, expr_to_alias, parse_local_name ) right_sql_expr = self.to_sql_try_avoid_cast( - expr.right, expr_to_alias, parse_local_name + expr.right, + expr_to_alias, + parse_local_name, ) else: left_sql_expr = self.analyze(expr.left, expr_to_alias, parse_local_name) @@ -638,9 +584,7 @@ def binary_operator_extractor( ) def grouping_extractor( - self, - expr: GroupingSet, - expr_to_alias: Dict[str, str], + self, expr: GroupingSet, expr_to_alias: Dict[str, str] ) -> str: return self.analyze( FunctionExpression( @@ -659,10 +603,7 @@ def window_frame_boundary(self, offset: str) -> str: return offset def to_sql_try_avoid_cast( - self, - expr: Expression, - expr_to_alias: Dict[str, str], - parse_local_name=False, + self, expr: Expression, expr_to_alias: Dict[str, str], parse_local_name=False ) -> str: # if expression is a numeric literal, return the number without casting, # otherwise process as normal @@ -672,34 +613,23 @@ def to_sql_try_avoid_cast( return self.analyze(expr, expr_to_alias, parse_local_name) def resolve( - self, - logical_plan: LogicalPlan, - expr_to_alias: Optional[Dict[str, str]] = None, + self, logical_plan: LogicalPlan, expr_to_alias: Optional[Dict[str, str]] = None ) -> MockExecutionPlan: self.subquery_plans = [] if expr_to_alias is None: expr_to_alias = {} - - result = self.do_resolve( - logical_plan, - expr_to_alias, - ) + result = self.do_resolve(logical_plan, expr_to_alias) return result def do_resolve( - self, - logical_plan: LogicalPlan, - expr_to_alias: Dict[str, str], + self, logical_plan: LogicalPlan, expr_to_alias: Dict[str, str] ) -> MockExecutionPlan: resolved_children = {} expr_to_alias_maps = {} for c in logical_plan.children: _expr_to_alias = {} - resolved_children[c] = self.resolve( - c, - _expr_to_alias, - ) + resolved_children[c] = self.resolve(c, _expr_to_alias) expr_to_alias_maps[c] = _expr_to_alias # get counts of expr_to_alias keys @@ -713,9 +643,7 @@ def do_resolve( expr_to_alias.update({p: q for p, q in v.items() if counts[p] < 2}) return self.do_resolve_with_resolved_children( - logical_plan, - resolved_children, - expr_to_alias, + logical_plan, resolved_children, expr_to_alias ) def do_resolve_with_resolved_children( @@ -810,14 +738,8 @@ def do_resolve_with_resolved_children( logical_plan.child, SnowflakePlan ) and isinstance(logical_plan.child.source_plan, Sort) return self.plan_builder.limit( - self.to_sql_try_avoid_cast( - logical_plan.limit_expr, - expr_to_alias, - ), - self.to_sql_try_avoid_cast( - logical_plan.offset_expr, - expr_to_alias, - ), + self.to_sql_try_avoid_cast(logical_plan.limit_expr, expr_to_alias), + self.to_sql_try_avoid_cast(logical_plan.offset_expr, expr_to_alias), resolved_children[logical_plan.child], on_top_of_order_by, logical_plan, @@ -846,10 +768,7 @@ def do_resolve_with_resolved_children( query=resolved_children[logical_plan.child], stage_location=logical_plan.stage_location, source_plan=logical_plan, - partition_by=self.analyze( - 
logical_plan.partition_by, - expr_to_alias, - ) + partition_by=self.analyze(logical_plan.partition_by, expr_to_alias) if logical_plan.partition_by else None, file_format_name=logical_plan.file_format_name, From 3eade1a2e1f848387315eff8f1bd40ceb0abf816 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Mon, 30 Sep 2024 10:54:43 -0700 Subject: [PATCH 53/62] min-diff --- src/snowflake/snowpark/_internal/compiler/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py index 60c2952f303..82c2b090487 100644 --- a/src/snowflake/snowpark/_internal/compiler/utils.py +++ b/src/snowflake/snowpark/_internal/compiler/utils.py @@ -39,9 +39,7 @@ TreeNode = Union[SnowflakePlan, Selectable] -def create_query_generator( - plan: SnowflakePlan -) -> QueryGenerator: +def create_query_generator(plan: SnowflakePlan) -> QueryGenerator: """ Helper function to construct the query generator for a given valid SnowflakePlan. """ From 1fa6ad2b6be696c84d51bdb99457426c549a6a3e Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 2 Oct 2024 15:26:50 -0700 Subject: [PATCH 54/62] add warnings --- .../_internal/compiler/plan_compiler.py | 28 +++++++++++-------- src/snowflake/snowpark/session.py | 6 ++++ tests/integ/test_multithreading.py | 18 ++++++++++++ 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 2ec32343195..441bcbe6521 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -49,6 +49,13 @@ class PlanCompiler: def __init__(self, plan: SnowflakePlan) -> None: self._plan = plan + session = plan.session + self.cte_optimization_enabled = session.cte_optimization_enabled + self.large_query_breakdown_enabled = session.large_query_breakdown_enabled + self.large_query_breakdown_complexity_bounds = ( + session.large_query_breakdown_complexity_bounds + ) + self.query_compilation_stage_enabled = session._query_compilation_stage_enabled def should_start_query_compilation(self) -> bool: """ @@ -68,11 +75,8 @@ def should_start_query_compilation(self) -> bool: return ( not isinstance(current_session._conn, MockServerConnection) and (self._plan.source_plan is not None) - and current_session._query_compilation_stage_enabled - and ( - current_session.cte_optimization_enabled - or current_session.large_query_breakdown_enabled - ) + and self.query_compilation_stage_enabled + and (self.cte_optimization_enabled or self.large_query_breakdown_enabled) ) def compile(self) -> Dict[PlanQueryType, List[Query]]: @@ -94,7 +98,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: # 3. 
apply each optimizations if needed # CTE optimization cte_start_time = time.time() - if session.cte_optimization_enabled: + if self.cte_optimization_enabled: repeated_subquery_eliminator = RepeatedSubqueryElimination( logical_plans, query_generator ) @@ -109,12 +113,12 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: plot_plan_if_enabled(plan, f"cte_optimized_plan_{i}") # Large query breakdown - if session.large_query_breakdown_enabled: + if self.large_query_breakdown_enabled: large_query_breakdown = LargeQueryBreakdown( - self._plan.session, + session, query_generator, logical_plans, - session.large_query_breakdown_complexity_bounds, + self.large_query_breakdown_complexity_bounds, ) logical_plans = large_query_breakdown.apply() @@ -135,9 +139,9 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: large_query_breakdown_time = large_query_breakdown_end_time - cte_end_time total_time = time.time() - start_time summary_value = { - TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled, - TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled, - CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds, + TelemetryField.CTE_OPTIMIZATION_ENABLED.value: self.cte_optimization_enabled, + TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: self.large_query_breakdown_enabled, + CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: self.large_query_breakdown_complexity_bounds, CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time, diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 1e581d7f2a1..54a01532847 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -796,6 +796,12 @@ def sql_simplifier_enabled(self, value: bool) -> None: @cte_optimization_enabled.setter @experimental_parameter(version="1.15.0") def cte_optimization_enabled(self, value: bool) -> None: + if threading.active_count() > 1: + # TODO (SNOW-1541096): Remove the limitation once old cte implementation is removed. + _logger.warning( + "Setting cte_optimization_enabled is not currently thread-safe. 
Ignoring the update" + ) + return with self._lock: if value: self._conn._telemetry_client.send_cte_optimization_telemetry( diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index ea3ae1730e9..39caae0b0fc 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -3,6 +3,7 @@ # import hashlib +import logging import os import tempfile from concurrent.futures import ThreadPoolExecutor, as_completed @@ -502,3 +503,20 @@ def finish(self): with ThreadPoolExecutor(max_workers=10) as executor: for i in range(10): executor.submit(register_and_test_udaf, session, i) + + +def test_concurrent_update_on_cte_optimization_enabled(session, caplog): + def run_cte_optimization(session_, thread_id): + if thread_id % 2 == 0: + session_.cte_optimization_enabled = True + else: + session_.cte_optimization_enabled = False + + caplog.clear() + with caplog.at_level(logging.WARNING): + with ThreadPoolExecutor(max_workers=5) as executor: + for i in range(5): + executor.submit(run_cte_optimization, session, i) + assert ( + "Setting cte_optimization_enabled is not currently thread-safe" in caplog.text + ) From f9948424382b107819241c4f465a5ae8d514d9e7 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 3 Oct 2024 12:00:43 -0700 Subject: [PATCH 55/62] address feedback --- src/snowflake/snowpark/session.py | 32 ++++++++++++++++++++++++++++ tests/integ/test_multithreading.py | 34 ++++++++++++++++++++---------- 2 files changed, 55 insertions(+), 11 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 54a01532847..ffebf943f55 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -781,6 +781,13 @@ def custom_package_usage_config(self) -> Dict: @sql_simplifier_enabled.setter def sql_simplifier_enabled(self, value: bool) -> None: + if threading.active_count() > 1: + _logger.warning( + "Setting sql_simplifier_enabled is not currently thread-safe. " + "Ignoring the update" + ) + return + with self._lock: self._conn._telemetry_client.send_sql_simplifier_telemetry( self._session_id, value @@ -802,6 +809,7 @@ def cte_optimization_enabled(self, value: bool) -> None: "Setting cte_optimization_enabled is not currently thread-safe. Ignoring the update" ) return + with self._lock: if value: self._conn._telemetry_client.send_cte_optimization_telemetry( @@ -813,6 +821,12 @@ def cte_optimization_enabled(self, value: bool) -> None: @experimental_parameter(version="1.20.0") def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: """Set the value for eliminate_numeric_sql_value_cast_enabled""" + if threading.active_count() > 1: + _logger.warning( + "Setting eliminate_numeric_sql_value_cast_enabled is not currently thread-safe. " + "Ignoring the update" + ) + return if value in [True, False]: with self._lock: @@ -829,6 +843,13 @@ def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: @experimental_parameter(version="1.21.0") def auto_clean_up_temp_table_enabled(self, value: bool) -> None: """Set the value for auto_clean_up_temp_table_enabled""" + if threading.active_count() > 1: + _logger.warning( + "Setting auto_clean_up_temp_table_enabled is not currently thread-safe. 
" + "Ignoring the update" + ) + return + if value in [True, False]: self._conn._telemetry_client.send_auto_clean_up_temp_table_telemetry( self._session_id, value @@ -847,6 +868,11 @@ def large_query_breakdown_enabled(self, value: bool) -> None: materialize the partitions, and then combine them to execute the query to improve overall performance. """ + if threading.active_count() > 1: + _logger.warning( + "Setting large_query_breakdown_enabled is not currently thread-safe. Ignoring the update" + ) + return if value in [True, False]: with self._lock: @@ -862,6 +888,12 @@ def large_query_breakdown_enabled(self, value: bool) -> None: @large_query_breakdown_complexity_bounds.setter def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None: """Set the lower and upper bounds for the complexity score used in large query breakdown optimization.""" + if threading.active_count() > 1: + _logger.warning( + "Setting large_query_breakdown_complexity_bounds is not currently thread-safe. " + "Ignoring the update" + ) + return if len(value) != 2: raise ValueError( diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 39caae0b0fc..38130ba60e1 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -505,18 +505,30 @@ def finish(self): executor.submit(register_and_test_udaf, session, i) -def test_concurrent_update_on_cte_optimization_enabled(session, caplog): - def run_cte_optimization(session_, thread_id): - if thread_id % 2 == 0: - session_.cte_optimization_enabled = True - else: - session_.cte_optimization_enabled = False +@pytest.mark.parametrize( + "config,value", + [ + ("cte_optimization_enabled", True), + ("sql_simplifier_enabled", True), + ("eliminate_numeric_sql_value_cast_enabled", True), + ("auto_clean_up_temp_table_enabled", True), + ("large_query_breakdown_enabled", True), + ("large_query_breakdown_complexity_bounds", (20, 30)), + ], +) +def test_concurrent_update_on_sensitive_configs(session, config, value, caplog): + def change_config_value(session_): + session_.conf.set(config, value) caplog.clear() + + # check everything works find outside multiple threads + with caplog.at_level(logging.WARNING): + change_config_value(session) + assert f"Setting {config} is not currently thread-safe" not in caplog.text + with caplog.at_level(logging.WARNING): with ThreadPoolExecutor(max_workers=5) as executor: - for i in range(5): - executor.submit(run_cte_optimization, session, i) - assert ( - "Setting cte_optimization_enabled is not currently thread-safe" in caplog.text - ) + for _ in range(5): + executor.submit(change_config_value, session) + assert f"Setting {config} is not currently thread-safe" in caplog.text From 4621836123404a1d5a8daac76adcb3b93a9a2fc2 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 3 Oct 2024 13:27:45 -0700 Subject: [PATCH 56/62] address feedback --- .../_internal/compiler/plan_compiler.py | 26 ++++------ src/snowflake/snowpark/_internal/utils.py | 9 ++++ src/snowflake/snowpark/session.py | 50 ++++++------------- tests/integ/test_multithreading.py | 6 --- 4 files changed, 35 insertions(+), 56 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 441bcbe6521..8bb3a8c4557 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -49,13 +49,6 @@ class PlanCompiler: def __init__(self, plan: SnowflakePlan) -> 
None: self._plan = plan - session = plan.session - self.cte_optimization_enabled = session.cte_optimization_enabled - self.large_query_breakdown_enabled = session.large_query_breakdown_enabled - self.large_query_breakdown_complexity_bounds = ( - session.large_query_breakdown_complexity_bounds - ) - self.query_compilation_stage_enabled = session._query_compilation_stage_enabled def should_start_query_compilation(self) -> bool: """ @@ -75,8 +68,11 @@ def should_start_query_compilation(self) -> bool: return ( not isinstance(current_session._conn, MockServerConnection) and (self._plan.source_plan is not None) - and self.query_compilation_stage_enabled - and (self.cte_optimization_enabled or self.large_query_breakdown_enabled) + and current_session._query_compilation_stage_enabled + and ( + current_session.cte_optimization_enabled + or current_session.large_query_breakdown_enabled + ) ) def compile(self) -> Dict[PlanQueryType, List[Query]]: @@ -98,7 +94,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: # 3. apply each optimizations if needed # CTE optimization cte_start_time = time.time() - if self.cte_optimization_enabled: + if session.cte_optimization_enabled: repeated_subquery_eliminator = RepeatedSubqueryElimination( logical_plans, query_generator ) @@ -113,12 +109,12 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: plot_plan_if_enabled(plan, f"cte_optimized_plan_{i}") # Large query breakdown - if self.large_query_breakdown_enabled: + if session.large_query_breakdown_enabled: large_query_breakdown = LargeQueryBreakdown( session, query_generator, logical_plans, - self.large_query_breakdown_complexity_bounds, + session.large_query_breakdown_complexity_bounds, ) logical_plans = large_query_breakdown.apply() @@ -139,9 +135,9 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: large_query_breakdown_time = large_query_breakdown_end_time - cte_end_time total_time = time.time() - start_time summary_value = { - TelemetryField.CTE_OPTIMIZATION_ENABLED.value: self.cte_optimization_enabled, - TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: self.large_query_breakdown_enabled, - CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: self.large_query_breakdown_complexity_bounds, + TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled, + TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled, + CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds, CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time, diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index aff2c6d7ad1..359ea4f1693 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -17,6 +17,7 @@ import random import re import string +import threading import traceback import zipfile from enum import Enum @@ -297,6 +298,14 @@ def normalize_path(path: str, is_local: bool) -> str: return f"'{path}'" +def warn_session_config_update_in_multithreaded_mode(config): + if threading.active_count() > 1: + logger.warning( + f"Session configuration update for {config} in multithreaded mode is not thread-safe. " + "Please update the session configuration before starting the threads." 
+ ) + + def normalize_remote_file_or_dir(name: str) -> str: return normalize_path(name, is_local=False) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index ffebf943f55..fb48dddf71b 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -114,6 +114,7 @@ unwrap_single_quote, unwrap_stage_location_single_quote, validate_object_name, + warn_session_config_update_in_multithreaded_mode, warning, zip_file_or_directory_to_stream, ) @@ -781,12 +782,7 @@ def custom_package_usage_config(self) -> Dict: @sql_simplifier_enabled.setter def sql_simplifier_enabled(self, value: bool) -> None: - if threading.active_count() > 1: - _logger.warning( - "Setting sql_simplifier_enabled is not currently thread-safe. " - "Ignoring the update" - ) - return + warn_session_config_update_in_multithreaded_mode("sql_simplifier_enabled") with self._lock: self._conn._telemetry_client.send_sql_simplifier_telemetry( @@ -803,12 +799,7 @@ def sql_simplifier_enabled(self, value: bool) -> None: @cte_optimization_enabled.setter @experimental_parameter(version="1.15.0") def cte_optimization_enabled(self, value: bool) -> None: - if threading.active_count() > 1: - # TODO (SNOW-1541096): Remove the limitation once old cte implementation is removed. - _logger.warning( - "Setting cte_optimization_enabled is not currently thread-safe. Ignoring the update" - ) - return + warn_session_config_update_in_multithreaded_mode("cte_optimization_enabled") with self._lock: if value: @@ -821,12 +812,9 @@ def cte_optimization_enabled(self, value: bool) -> None: @experimental_parameter(version="1.20.0") def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: """Set the value for eliminate_numeric_sql_value_cast_enabled""" - if threading.active_count() > 1: - _logger.warning( - "Setting eliminate_numeric_sql_value_cast_enabled is not currently thread-safe. " - "Ignoring the update" - ) - return + warn_session_config_update_in_multithreaded_mode( + "eliminate_numeric_sql_value_cast_enabled" + ) if value in [True, False]: with self._lock: @@ -843,12 +831,9 @@ def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: @experimental_parameter(version="1.21.0") def auto_clean_up_temp_table_enabled(self, value: bool) -> None: """Set the value for auto_clean_up_temp_table_enabled""" - if threading.active_count() > 1: - _logger.warning( - "Setting auto_clean_up_temp_table_enabled is not currently thread-safe. " - "Ignoring the update" - ) - return + warn_session_config_update_in_multithreaded_mode( + "auto_clean_up_temp_table_enabled" + ) if value in [True, False]: self._conn._telemetry_client.send_auto_clean_up_temp_table_telemetry( @@ -868,11 +853,9 @@ def large_query_breakdown_enabled(self, value: bool) -> None: materialize the partitions, and then combine them to execute the query to improve overall performance. """ - if threading.active_count() > 1: - _logger.warning( - "Setting large_query_breakdown_enabled is not currently thread-safe. 
Ignoring the update" - ) - return + warn_session_config_update_in_multithreaded_mode( + "large_query_breakdown_enabled" + ) if value in [True, False]: with self._lock: @@ -888,12 +871,9 @@ def large_query_breakdown_enabled(self, value: bool) -> None: @large_query_breakdown_complexity_bounds.setter def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None: """Set the lower and upper bounds for the complexity score used in large query breakdown optimization.""" - if threading.active_count() > 1: - _logger.warning( - "Setting large_query_breakdown_complexity_bounds is not currently thread-safe. " - "Ignoring the update" - ) - return + warn_session_config_update_in_multithreaded_mode( + "large_query_breakdown_complexity_bounds" + ) if len(value) != 2: raise ValueError( diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 38130ba60e1..8e5c543f695 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -521,12 +521,6 @@ def change_config_value(session_): session_.conf.set(config, value) caplog.clear() - - # check everything works find outside multiple threads - with caplog.at_level(logging.WARNING): - change_config_value(session) - assert f"Setting {config} is not currently thread-safe" not in caplog.text - with caplog.at_level(logging.WARNING): with ThreadPoolExecutor(max_workers=5) as executor: for _ in range(5): From e1c68f30c852528cf3cda56c9e091c3b19f2b96f Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 3 Oct 2024 14:40:32 -0700 Subject: [PATCH 57/62] fix string --- tests/integ/test_multithreading.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 8e5c543f695..73899160e5b 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -525,4 +525,7 @@ def change_config_value(session_): with ThreadPoolExecutor(max_workers=5) as executor: for _ in range(5): executor.submit(change_config_value, session) - assert f"Setting {config} is not currently thread-safe" in caplog.text + assert ( + f"Session configuration update for {config} in multithreaded mode is not thread-safe" + in caplog.text + ) From e5b48dd296cac644bcebabdc8f460c705c7616f2 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 3 Oct 2024 15:05:59 -0700 Subject: [PATCH 58/62] ignore on multi-thread --- src/snowflake/snowpark/_internal/utils.py | 4 +++- src/snowflake/snowpark/session.py | 26 ++++++++++++++--------- tests/integ/test_multithreading.py | 6 ++++++ 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index 359ea4f1693..06998d18131 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -298,12 +298,14 @@ def normalize_path(path: str, is_local: bool) -> str: return f"'{path}'" -def warn_session_config_update_in_multithreaded_mode(config): +def warn_session_config_update_in_multithreaded_mode(config) -> bool: if threading.active_count() > 1: logger.warning( f"Session configuration update for {config} in multithreaded mode is not thread-safe. " "Please update the session configuration before starting the threads." 
) + return True + return False def normalize_remote_file_or_dir(name: str) -> str: diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index fb48dddf71b..f104f369c5f 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -782,7 +782,8 @@ def custom_package_usage_config(self) -> Dict: @sql_simplifier_enabled.setter def sql_simplifier_enabled(self, value: bool) -> None: - warn_session_config_update_in_multithreaded_mode("sql_simplifier_enabled") + if warn_session_config_update_in_multithreaded_mode("sql_simplifier_enabled"): + return with self._lock: self._conn._telemetry_client.send_sql_simplifier_telemetry( @@ -799,7 +800,8 @@ def sql_simplifier_enabled(self, value: bool) -> None: @cte_optimization_enabled.setter @experimental_parameter(version="1.15.0") def cte_optimization_enabled(self, value: bool) -> None: - warn_session_config_update_in_multithreaded_mode("cte_optimization_enabled") + if warn_session_config_update_in_multithreaded_mode("cte_optimization_enabled"): + return with self._lock: if value: @@ -812,9 +814,10 @@ def cte_optimization_enabled(self, value: bool) -> None: @experimental_parameter(version="1.20.0") def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: """Set the value for eliminate_numeric_sql_value_cast_enabled""" - warn_session_config_update_in_multithreaded_mode( + if warn_session_config_update_in_multithreaded_mode( "eliminate_numeric_sql_value_cast_enabled" - ) + ): + return if value in [True, False]: with self._lock: @@ -831,9 +834,10 @@ def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: @experimental_parameter(version="1.21.0") def auto_clean_up_temp_table_enabled(self, value: bool) -> None: """Set the value for auto_clean_up_temp_table_enabled""" - warn_session_config_update_in_multithreaded_mode( + if warn_session_config_update_in_multithreaded_mode( "auto_clean_up_temp_table_enabled" - ) + ): + return if value in [True, False]: self._conn._telemetry_client.send_auto_clean_up_temp_table_telemetry( @@ -853,9 +857,10 @@ def large_query_breakdown_enabled(self, value: bool) -> None: materialize the partitions, and then combine them to execute the query to improve overall performance. 
""" - warn_session_config_update_in_multithreaded_mode( + if warn_session_config_update_in_multithreaded_mode( "large_query_breakdown_enabled" - ) + ): + return if value in [True, False]: with self._lock: @@ -871,9 +876,10 @@ def large_query_breakdown_enabled(self, value: bool) -> None: @large_query_breakdown_complexity_bounds.setter def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None: """Set the lower and upper bounds for the complexity score used in large query breakdown optimization.""" - warn_session_config_update_in_multithreaded_mode( + if warn_session_config_update_in_multithreaded_mode( "large_query_breakdown_complexity_bounds" - ) + ): + return if len(value) != 2: raise ValueError( diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 73899160e5b..b664a6515c9 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -521,6 +521,12 @@ def change_config_value(session_): session_.conf.set(config, value) caplog.clear() + change_config_value(session) + assert ( + f"Session configuration update for {config} in multithreaded mode is not thread-safe" + not in caplog.text + ) + with caplog.at_level(logging.WARNING): with ThreadPoolExecutor(max_workers=5) as executor: for _ in range(5): From 496e2beefbcf99a5a6c15ede6c91a6b9c59ec067 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 3 Oct 2024 15:56:43 -0700 Subject: [PATCH 59/62] undo ignore --- src/snowflake/snowpark/_internal/utils.py | 4 +--- src/snowflake/snowpark/session.py | 26 +++++++++-------------- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index 06998d18131..1c8decb677b 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -298,14 +298,12 @@ def normalize_path(path: str, is_local: bool) -> str: return f"'{path}'" -def warn_session_config_update_in_multithreaded_mode(config) -> bool: +def warn_session_config_update_in_multithreaded_mode(config) -> None: if threading.active_count() > 1: logger.warning( f"Session configuration update for {config} in multithreaded mode is not thread-safe. " "Please update the session configuration before starting the threads." 
) - return True - return False def normalize_remote_file_or_dir(name: str) -> str: diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index f104f369c5f..fb48dddf71b 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -782,8 +782,7 @@ def custom_package_usage_config(self) -> Dict: @sql_simplifier_enabled.setter def sql_simplifier_enabled(self, value: bool) -> None: - if warn_session_config_update_in_multithreaded_mode("sql_simplifier_enabled"): - return + warn_session_config_update_in_multithreaded_mode("sql_simplifier_enabled") with self._lock: self._conn._telemetry_client.send_sql_simplifier_telemetry( @@ -800,8 +799,7 @@ def sql_simplifier_enabled(self, value: bool) -> None: @cte_optimization_enabled.setter @experimental_parameter(version="1.15.0") def cte_optimization_enabled(self, value: bool) -> None: - if warn_session_config_update_in_multithreaded_mode("cte_optimization_enabled"): - return + warn_session_config_update_in_multithreaded_mode("cte_optimization_enabled") with self._lock: if value: @@ -814,10 +812,9 @@ def cte_optimization_enabled(self, value: bool) -> None: @experimental_parameter(version="1.20.0") def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: """Set the value for eliminate_numeric_sql_value_cast_enabled""" - if warn_session_config_update_in_multithreaded_mode( + warn_session_config_update_in_multithreaded_mode( "eliminate_numeric_sql_value_cast_enabled" - ): - return + ) if value in [True, False]: with self._lock: @@ -834,10 +831,9 @@ def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: @experimental_parameter(version="1.21.0") def auto_clean_up_temp_table_enabled(self, value: bool) -> None: """Set the value for auto_clean_up_temp_table_enabled""" - if warn_session_config_update_in_multithreaded_mode( + warn_session_config_update_in_multithreaded_mode( "auto_clean_up_temp_table_enabled" - ): - return + ) if value in [True, False]: self._conn._telemetry_client.send_auto_clean_up_temp_table_telemetry( @@ -857,10 +853,9 @@ def large_query_breakdown_enabled(self, value: bool) -> None: materialize the partitions, and then combine them to execute the query to improve overall performance. 
""" - if warn_session_config_update_in_multithreaded_mode( + warn_session_config_update_in_multithreaded_mode( "large_query_breakdown_enabled" - ): - return + ) if value in [True, False]: with self._lock: @@ -876,10 +871,9 @@ def large_query_breakdown_enabled(self, value: bool) -> None: @large_query_breakdown_complexity_bounds.setter def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None: """Set the lower and upper bounds for the complexity score used in large query breakdown optimization.""" - if warn_session_config_update_in_multithreaded_mode( + warn_session_config_update_in_multithreaded_mode( "large_query_breakdown_complexity_bounds" - ): - return + ) if len(value) != 2: raise ValueError( From 980d3b7388cf4b2cc5c18aaf7e2ca92147e25d4c Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 3 Oct 2024 17:22:00 -0700 Subject: [PATCH 60/62] update warning message --- src/snowflake/snowpark/_internal/utils.py | 3 ++- tests/integ/test_multithreading.py | 9 ++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index 1c8decb677b..912e341f877 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -301,7 +301,8 @@ def normalize_path(path: str, is_local: bool) -> str: def warn_session_config_update_in_multithreaded_mode(config) -> None: if threading.active_count() > 1: logger.warning( - f"Session configuration update for {config} in multithreaded mode is not thread-safe. " + "You might have more than one threads sharing the Session object trying to update " + f"{config}. This is currently not thread-safe and may cause unexpected behavior. " "Please update the session configuration before starting the threads." 
) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index b664a6515c9..f65a76dc79b 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -27,7 +27,7 @@ from snowflake.snowpark.functions import lit from snowflake.snowpark.row import Row -from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils +from tests.utils import IS_IN_STORED_PROC, IS_LINUX, IS_WINDOWS, TestFiles, Utils def test_concurrent_select_queries(session): @@ -505,6 +505,9 @@ def finish(self): executor.submit(register_and_test_udaf, session, i) +@pytest.mark.skipif( + IS_LINUX or IS_WINDOWS, reason="Linux and Windows behave badly for this test" +) @pytest.mark.parametrize( "config,value", [ @@ -523,7 +526,7 @@ def change_config_value(session_): caplog.clear() change_config_value(session) assert ( - f"Session configuration update for {config} in multithreaded mode is not thread-safe" + f"You might have more than one threads sharing the Session object trying to update {config}" not in caplog.text ) @@ -532,6 +535,6 @@ def change_config_value(session_): for _ in range(5): executor.submit(change_config_value, session) assert ( - f"Session configuration update for {config} in multithreaded mode is not thread-safe" + f"You might have more than one threads sharing the Session object trying to update {config}" in caplog.text ) From 54a6b5ddaa12b187ce4e83511fc324ce9aa0497b Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 4 Oct 2024 07:32:38 -0700 Subject: [PATCH 61/62] address comments --- src/snowflake/snowpark/_internal/utils.py | 4 ++-- src/snowflake/snowpark/session.py | 15 +++++---------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index 912e341f877..8783faa39d6 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -302,8 +302,8 @@ def warn_session_config_update_in_multithreaded_mode(config) -> None: if threading.active_count() > 1: logger.warning( "You might have more than one threads sharing the Session object trying to update " - f"{config}. This is currently not thread-safe and may cause unexpected behavior. " - "Please update the session configuration before starting the threads." + f"{config}. Updating this while other tasks are running can potentially cause " + "unexpected behavior. Please update the session configuration before starting the threads." 
) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index fb48dddf71b..3a4eacf25ae 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -545,6 +545,11 @@ def __init__( self._udtf_registration = UDTFRegistration(self) self._udaf_registration = UDAFRegistration(self) + self._plan_builder = ( + SnowflakePlanBuilder(self) + if isinstance(self._conn, ServerConnection) + else MockSnowflakePlanBuilder(self) + ) self._last_action_id = 0 self._last_canceled_id = 0 self._use_scoped_temp_objects: bool = ( @@ -639,16 +644,6 @@ def _analyzer(self) -> Analyzer: ) return self._thread_store.analyzer - @property - def _plan_builder(self): - if not hasattr(self._thread_store, "plan_builder"): - self._thread_store.plan_builder = ( - SnowflakePlanBuilder(self) - if isinstance(self._conn, ServerConnection) - else MockSnowflakePlanBuilder(self) - ) - return self._thread_store.plan_builder - def close(self) -> None: """Close this session.""" if is_in_stored_procedure(): From 67609e8395da0b599aae28af54718d48232ecc59 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 4 Oct 2024 09:53:07 -0700 Subject: [PATCH 62/62] address comments --- tests/integ/test_multithreading.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index f65a76dc79b..14b9c14578a 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -506,7 +506,8 @@ def finish(self): @pytest.mark.skipif( - IS_LINUX or IS_WINDOWS, reason="Linux and Windows behave badly for this test" + IS_LINUX or IS_WINDOWS, + reason="Linux and Windows test show multiple active threads when no threadpool is enabled", ) @pytest.mark.parametrize( "config,value",