From 56fb566a9fa89184e705263bde96c18d697697b4 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 11:44:26 -0700 Subject: [PATCH 01/20] init --- src/snowflake/snowpark/_internal/udf_utils.py | 16 ++++++++++------ src/snowflake/snowpark/session.py | 8 ++++++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index b79fcdcf9c9..b4be8cca9c3 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -980,8 +980,11 @@ def add_snowpark_package_to_sproc_packages( if packages is None: if session is None: packages = [this_package] - elif package_name not in session._packages: - packages = list(session._packages.values()) + [this_package] + else: + with session._lock: + session_packages = session._packages.copy() + if package_name not in session_packages: + packages = list(session_packages.values()) + [this_package] else: package_names = [p if isinstance(p, str) else p.__name__ for p in packages] if not any(p.startswith(package_name) for p in package_names): @@ -1247,10 +1250,11 @@ def create_python_udf_or_sp( comment: Optional[str] = None, native_app_params: Optional[Dict[str, Any]] = None, ) -> None: - if session is not None and session._runtime_version_from_requirement: - runtime_version = session._runtime_version_from_requirement - else: - runtime_version = f"{sys.version_info[0]}.{sys.version_info[1]}" + with session._lock: + if session is not None and session._runtime_version_from_requirement: + runtime_version = session._runtime_version_from_requirement + else: + runtime_version = f"{sys.version_info[0]}.{sys.version_info[1]}" if replace and if_not_exists: raise ValueError("options replace and if_not_exists are incompatible") diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index a04e381e985..5e6ad560592 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -13,6 +13,7 @@ import re import sys import tempfile +import threading import warnings from array import array from functools import reduce @@ -499,6 +500,8 @@ def __init__( if len(_active_sessions) >= 1 and is_in_stored_procedure(): raise SnowparkClientExceptionMessages.DONT_CREATE_SESSION_IN_SP() self._conn = conn + self._thread_store = threading.local() + self._lock = threading.RLock() self._query_tag = None self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {} self._packages: Dict[str, str] = {} @@ -3007,8 +3010,9 @@ def get_fully_qualified_name_if_possible(self, name: str) -> str: """ Returns the fully qualified object name if current database/schema exists, otherwise returns the object name """ - database = self.get_current_database() - schema = self.get_current_schema() + with self._lock: + database = self.get_current_database() + schema = self.get_current_schema() if database and schema: return f"{database}.{schema}.{name}" From 66003d19e785d363d2a001362bacde23ab08a343 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 13:28:45 -0700 Subject: [PATCH 02/20] make udf/sproc related files thread-safe --- src/snowflake/snowpark/_internal/udf_utils.py | 6 +- src/snowflake/snowpark/session.py | 81 +++++++++++-------- 2 files changed, 49 insertions(+), 38 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index b4be8cca9c3..4012ad9ff6e 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ 
b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1089,7 +1089,7 @@ def resolve_imports_and_packages( else session.get_session_stage(statement_params=statement_params) ) - if session: + all_urls = [] if imports: udf_level_imports = {} for udf_import in imports: @@ -1117,10 +1117,6 @@ def resolve_imports_and_packages( upload_and_import_stage, statement_params=statement_params, ) - else: - all_urls = [] - else: - all_urls = [] dest_prefix = get_udf_upload_prefix(udf_name) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 5e6ad560592..af8c06ee3a6 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -809,7 +809,8 @@ def get_imports(self) -> List[str]: Returns a list of imports added for user defined functions (UDFs). This list includes any Python or zip files that were added automatically by the library. """ - return list(self._import_paths.keys()) + with self._lock: + return list(self._import_paths.keys()) def add_import( self, @@ -890,7 +891,8 @@ def add_import( path, checksum, leading_path = self._resolve_import_path( path, import_path, chunk_size, whole_file_hash ) - self._import_paths[path] = (checksum, leading_path) + with self._lock: + self._import_paths[path] = (checksum, leading_path) def remove_import(self, path: str) -> None: """ @@ -917,10 +919,11 @@ def remove_import(self, path: str) -> None: if not trimmed_path.startswith(STAGE_PREFIX) else trimmed_path ) - if abs_path not in self._import_paths: - raise KeyError(f"{abs_path} is not found in the existing imports") - else: - self._import_paths.pop(abs_path) + with self._lock: + if abs_path not in self._import_paths: + raise KeyError(f"{abs_path} is not found in the existing imports") + else: + self._import_paths.pop(abs_path) def clear_imports(self) -> None: """ @@ -929,7 +932,8 @@ def clear_imports(self) -> None: if isinstance(self._conn, MockServerConnection): self.udf._clear_session_imports() self.sproc._clear_session_imports() - self._import_paths.clear() + with self._lock: + self._import_paths.clear() def _resolve_import_path( self, @@ -1020,7 +1024,8 @@ def _resolve_imports( upload_and_import_stage ) - import_paths = udf_level_import_paths or self._import_paths + with self._lock: + import_paths = udf_level_import_paths or self._import_paths.copy() for path, (prefix, leading_path) in import_paths.items(): # stage file if path.startswith(STAGE_PREFIX): @@ -1102,7 +1107,8 @@ def get_packages(self) -> Dict[str, str]: The key of this ``dict`` is the package name and the value of this ``dict`` is the corresponding requirement specifier. """ - return self._packages.copy() + with self._lock: + return self._packages.copy() def add_packages( self, *packages: Union[str, ModuleType, Iterable[Union[str, ModuleType]]] @@ -1193,16 +1199,18 @@ def remove_package(self, package: str) -> None: 0 """ package_name = pkg_resources.Requirement.parse(package).key - if package_name in self._packages: - self._packages.pop(package_name) - else: - raise ValueError(f"{package_name} is not in the package list") + with self._lock: + if package_name in self._packages: + self._packages.pop(package_name) + else: + raise ValueError(f"{package_name} is not in the package list") def clear_packages(self) -> None: """ Clears all third-party packages of a user-defined function (UDF). 
""" - self._packages.clear() + with self._lock: + self._packages.clear() def add_requirements(self, file_path: str) -> None: """ @@ -1550,7 +1558,8 @@ def _resolve_packages( if isinstance(self._conn, MockServerConnection): # in local testing we don't resolve the packages, we just return what is added errors = [] - result_dict = self._packages.copy() + with self._lock: + result_dict = self._packages.copy() for pkg_name, _, pkg_req in package_dict.values(): if pkg_name in result_dict and str(pkg_req) != result_dict[pkg_name]: errors.append( @@ -1566,7 +1575,8 @@ def _resolve_packages( elif len(errors) > 0: raise RuntimeError(errors) - self._packages.update(result_dict) + with self._lock: + self._packages.update(result_dict) return list(result_dict.values()) package_table = "information_schema.packages" @@ -1581,9 +1591,12 @@ def _resolve_packages( # 'python-dateutil': 'python-dateutil==2.8.2'} # Add to packages dictionary. Make a copy of existing packages # dictionary to avoid modifying it during intermediate steps. - result_dict = ( - existing_packages_dict.copy() if existing_packages_dict is not None else {} - ) + with self._lock: + result_dict = ( + existing_packages_dict.copy() + if existing_packages_dict is not None + else {} + ) # Retrieve list of dependencies that need to be added dependency_packages = self._get_dependency_packages( @@ -1616,8 +1629,9 @@ def _resolve_packages( if include_pandas: extra_modules.append("pandas") - if existing_packages_dict is not None: - existing_packages_dict.update(result_dict) + with self._lock: + if existing_packages_dict is not None: + existing_packages_dict.update(result_dict) return list(result_dict.values()) + self._get_req_identifiers_list( extra_modules, result_dict ) @@ -2300,18 +2314,19 @@ def get_session_stage( Therefore, if you switch database or schema during the session, the stage will not be re-created in the new database or schema, and still references the stage in the old database or schema. 
""" - if not self._session_stage: - full_qualified_stage_name = self.get_fully_qualified_name_if_possible( - random_name_for_temp_object(TempObjectType.STAGE) - ) - self._run_query( - f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \ - stage if not exists {full_qualified_stage_name}", - is_ddl_on_temp_object=True, - statement_params=statement_params, - ) - # set the value after running the query to ensure atomicity - self._session_stage = full_qualified_stage_name + with self._lock: + if not self._session_stage: + full_qualified_stage_name = self.get_fully_qualified_name_if_possible( + random_name_for_temp_object(TempObjectType.STAGE) + ) + self._run_query( + f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \ + stage if not exists {full_qualified_stage_name}", + is_ddl_on_temp_object=True, + statement_params=statement_params, + ) + # set the value after running the query to ensure atomicity + self._session_stage = full_qualified_stage_name return f"{STAGE_PREFIX}{self._session_stage}" def _write_modin_pandas_helper( From e75dde123dbcf33c8071236436495c0214be62fc Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 14:52:58 -0700 Subject: [PATCH 03/20] init --- .../snowpark/_internal/server_connection.py | 11 ++++++++++- src/snowflake/snowpark/session.py | 13 ++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 56445fc31b2..0ab0ae490fa 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -8,6 +8,7 @@ import inspect import os import sys +import threading import time from logging import getLogger from typing import ( @@ -154,6 +155,8 @@ def __init__( options: Dict[str, Union[int, str]], conn: Optional[SnowflakeConnection] = None, ) -> None: + self._lock = threading.RLock() + self._thread_stored = threading.local() self._lower_case_parameters = {k.lower(): v for k, v in options.items()} self._add_application_parameters() self._conn = conn if conn else connect(**self._lower_case_parameters) @@ -170,8 +173,8 @@ def __init__( if "password" in self._lower_case_parameters: self._lower_case_parameters["password"] = None - self._cursor = self._conn.cursor() self._telemetry_client = TelemetryClient(self._conn) + # TODO: protect _query_listener self._query_listener: Set[QueryHistory] = set() # The session in this case refers to a Snowflake session, not a # Snowpark session @@ -183,6 +186,12 @@ def __init__( "_skip_upload_on_content_match" in signature.parameters ) + @property + def _cursor(self) -> SnowflakeCursor: + if not hasattr(self._thread_stored, "cursor"): + self._thread_stored.cursor = self._conn.cursor() + return self._thread_stored.cursor + def _add_application_parameters(self) -> None: if PARAM_APPLICATION not in self._lower_case_parameters: # Mirrored from snowflake-connector-python/src/snowflake/connector/connection.py#L295 diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 607cd047f2b..1c436360fb4 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -540,9 +540,6 @@ def __init__( ) self._file = FileOperation(self) self._lineage = Lineage(self) - self._analyzer = ( - Analyzer(self) if isinstance(conn, ServerConnection) else MockAnalyzer(self) - ) self._sql_simplifier_enabled: bool = ( self._conn._get_client_side_session_parameter( 
_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING, True @@ -608,6 +605,16 @@ def _generate_new_action_id(self) -> int: self._last_action_id += 1 return self._last_action_id + @property + def _analyzer(self) -> Analyzer: + if not hasattr(self._thread_store, "analyzer"): + self._thread_store.analyzer = ( + Analyzer(self) + if isinstance(self._conn, ServerConnection) + else MockAnalyzer(self) + ) + return self._thread_store.analyzer + def close(self) -> None: """Close this session.""" if is_in_stored_procedure(): From 68a8c1c7186df5892a664bc297341a97e5461a3c Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 15:22:53 -0700 Subject: [PATCH 04/20] make query listener thread-safe --- .../snowpark/_internal/server_connection.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index 0ab0ae490fa..7d51f1189c4 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -174,7 +174,6 @@ def __init__( if "password" in self._lower_case_parameters: self._lower_case_parameters["password"] = None self._telemetry_client = TelemetryClient(self._conn) - # TODO: protect _query_listener self._query_listener: Set[QueryHistory] = set() # The session in this case refers to a Snowflake session, not a # Snowpark session @@ -219,10 +218,12 @@ def _add_application_parameters(self) -> None: ] = get_version() def add_query_listener(self, listener: QueryHistory) -> None: - self._query_listener.add(listener) + with self._lock: + self._query_listener.add(listener) def remove_query_listener(self, listener: QueryHistory) -> None: - self._query_listener.remove(listener) + with self._lock: + self._query_listener.remove(listener) def close(self) -> None: if self._conn: @@ -369,8 +370,9 @@ def upload_stream( raise ex def notify_query_listeners(self, query_record: QueryRecord) -> None: - for listener in self._query_listener: - listener._add_query(query_record) + with self._lock: + for listener in self._query_listener: + listener._add_query(query_record) def execute_and_notify_query_listener( self, query: str, **kwargs: Any From 31a57343e348115f4944cb94d49f02d0c58b3288 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 16:21:57 -0700 Subject: [PATCH 05/20] Fix query_tag and last_action_id --- src/snowflake/snowpark/session.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 1c436360fb4..004f0c65361 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -602,8 +602,9 @@ def __str__(self): ) def _generate_new_action_id(self) -> int: - self._last_action_id += 1 - return self._last_action_id + with self._lock: + self._last_action_id += 1 + return self._last_action_id @property def _analyzer(self) -> Analyzer: @@ -805,7 +806,8 @@ def cancel_all(self) -> None: This does not affect any action methods called in the future. 
""" _logger.info("Canceling all running queries") - self._last_canceled_id = self._last_action_id + with self._lock: + self._last_canceled_id = self._last_action_id if not isinstance(self._conn, MockServerConnection): self._conn.run_query( f"select system$cancel_all_queries({self._session_id})" @@ -1910,11 +1912,12 @@ def query_tag(self) -> Optional[str]: @query_tag.setter def query_tag(self, tag: str) -> None: - if tag: - self._conn.run_query(f"alter session set query_tag = {str_to_sql(tag)}") - else: - self._conn.run_query("alter session unset query_tag") - self._query_tag = tag + with self._lock: + if tag: + self._conn.run_query(f"alter session set query_tag = {str_to_sql(tag)}") + else: + self._conn.run_query("alter session unset query_tag") + self._query_tag = tag def _get_remote_query_tag(self) -> None: """ From b4dadda2ad2392c6a5d0bbf898cbeb5139021b87 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 16:56:57 -0700 Subject: [PATCH 06/20] core updates done --- src/snowflake/snowpark/dataframe.py | 5 +++-- src/snowflake/snowpark/session.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 25ce987fd78..20cf07334f5 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -2027,12 +2027,13 @@ def _union_by_name_internal( ] names = right_project_list + not_found_attrs - if self._session.sql_simplifier_enabled and other._select_statement: + sql_simplifier_enabled = self._session.sql_simplifier_enabled + if sql_simplifier_enabled and other._select_statement: right_child = self._with_plan(other._select_statement.select(names)) else: right_child = self._with_plan(Project(names, other._plan)) - if self._session.sql_simplifier_enabled: + if sql_simplifier_enabled: df = self._with_plan( self._select_statement.set_operator( right_child._select_statement diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 004f0c65361..0a9fdd8a83d 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3020,8 +3020,9 @@ def get_fully_qualified_name_if_possible(self, name: str) -> str: """ Returns the fully qualified object name if current database/schema exists, otherwise returns the object name """ - database = self.get_current_database() - schema = self.get_current_schema() + with self._lock: + database = self.get_current_database() + schema = self.get_current_schema() if database and schema: return f"{database}.{schema}.{name}" From b8c64960456cdd77789b88b14378865c1c41de0e Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 12 Sep 2024 11:00:11 -0700 Subject: [PATCH 07/20] Add tests --- src/snowflake/snowpark/session.py | 25 +++---- tests/integ/test_multithreading.py | 109 +++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 12 deletions(-) create mode 100644 tests/integ/test_multithreading.py diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 0a9fdd8a83d..3a099560bc4 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -2310,18 +2310,19 @@ def get_session_stage( Therefore, if you switch database or schema during the session, the stage will not be re-created in the new database or schema, and still references the stage in the old database or schema. 
""" - if not self._session_stage: - full_qualified_stage_name = self.get_fully_qualified_name_if_possible( - random_name_for_temp_object(TempObjectType.STAGE) - ) - self._run_query( - f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \ - stage if not exists {full_qualified_stage_name}", - is_ddl_on_temp_object=True, - statement_params=statement_params, - ) - # set the value after running the query to ensure atomicity - self._session_stage = full_qualified_stage_name + with self._lock: + if not self._session_stage: + full_qualified_stage_name = self.get_fully_qualified_name_if_possible( + random_name_for_temp_object(TempObjectType.STAGE) + ) + self._run_query( + f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \ + stage if not exists {full_qualified_stage_name}", + is_ddl_on_temp_object=True, + statement_params=statement_params, + ) + # set the value after running the query to ensure atomicity + self._session_stage = full_qualified_stage_name return f"{STAGE_PREFIX}{self._session_stage}" def _write_modin_pandas_helper( diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py new file mode 100644 index 00000000000..d2a966217a6 --- /dev/null +++ b/tests/integ/test_multithreading.py @@ -0,0 +1,109 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +from concurrent.futures import ThreadPoolExecutor, as_completed +from unittest.mock import patch + +import pytest + +from snowflake.snowpark.functions import lit +from snowflake.snowpark.row import Row +from tests.utils import IS_IN_STORED_PROC, Utils + + +def test_concurrent_select_queries(session): + def run_select(session_, thread_id): + df = session_.sql(f"SELECT {thread_id} as A") + assert df.collect()[0][0] == thread_id + + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(run_select, session, i) + + +def test_concurrent_dataframe_operations(session): + try: + table_name = Utils.random_table_name() + data = [(i, 11 * i) for i in range(10)] + df = session.create_dataframe(data, ["A", "B"]) + df.write.save_as_table(table_name, table_type="temporary") + + def run_dataframe_operation(session_, thread_id): + df = session_.table(table_name) + df = df.filter(df.a == lit(thread_id)) + df = df.with_column("C", df.b + 100 * df.a) + df = df.rename(df.a, "D").limit(1) + return df + + dfs = [] + with ThreadPoolExecutor(max_workers=10) as executor: + df_futures = [ + executor.submit(run_dataframe_operation, session, i) for i in range(10) + ] + + for future in as_completed(df_futures): + dfs.append(future.result()) + + main_df = dfs[0] + for df in dfs[1:]: + main_df = main_df.union(df) + + Utils.check_answer( + main_df, [Row(D=i, B=11 * i, C=11 * i + 100 * i) for i in range(10)] + ) + + finally: + Utils.drop_table(session, table_name) + + +def test_query_listener(session): + def run_select(session_, thread_id): + session_.sql(f"SELECT {thread_id} as A").collect() + + with session.query_history() as history: + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(run_select, session, i) + + queries_sent = [query.sql_text for query in history.queries] + assert len(queries_sent) == 10 + for i in range(10): + assert f"SELECT {i} as A" in queries_sent + + +@pytest.mark.skipif( + IS_IN_STORED_PROC, reason="show parameters is not supported in stored procedure" +) +def test_query_tagging(session): + def set_query_tag(session_, thread_id): + session_.query_tag = f"tag_{thread_id}" 
+ + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(set_query_tag, session, i) + + actual_query_tag = session.sql("SHOW PARAMETERS LIKE 'QUERY_TAG'").collect()[0][1] + assert actual_query_tag == session.query_tag + + +def test_session_stage_created_once(session): + with patch.object( + session._conn, "run_query", wraps=session._conn.run_query + ) as patched_run_query: + with ThreadPoolExecutor(max_workers=10) as executor: + for _ in range(10): + executor.submit(session.get_session_stage) + + assert patched_run_query.call_count == 1 + + +def test_action_ids_are_unique(session): + with ThreadPoolExecutor(max_workers=10) as executor: + action_ids = set() + futures = [executor.submit(session._generate_new_action_id) for _ in range(10)] + + for future in as_completed(futures): + action_ids.add(future.result()) + + assert len(action_ids) == 10 From f39837e29ae79e1e26536cda2577349c45033017 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 12 Sep 2024 12:01:12 -0700 Subject: [PATCH 08/20] Fix local tests --- tests/integ/test_multithreading.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index d2a966217a6..164a4b7b590 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -57,6 +57,11 @@ def run_dataframe_operation(session_, thread_id): Utils.drop_table(session, table_name) +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="SQL query and query listeners are not supported", + run=False, +) def test_query_listener(session): def run_select(session_, thread_id): session_.sql(f"SELECT {thread_id} as A").collect() @@ -72,6 +77,11 @@ def run_select(session_, thread_id): assert f"SELECT {i} as A" in queries_sent +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="Query tag is a SQL feature", + run=False, +) @pytest.mark.skipif( IS_IN_STORED_PROC, reason="show parameters is not supported in stored procedure" ) @@ -87,6 +97,11 @@ def set_query_tag(session_, thread_id): assert actual_query_tag == session.query_tag +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="SQL query is not supported", + run=False, +) def test_session_stage_created_once(session): with patch.object( session._conn, "run_query", wraps=session._conn.run_query From 37c041936331de5c97fe32702f3cae8aff83ccd0 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 12 Sep 2024 15:06:48 -0700 Subject: [PATCH 09/20] add file IO tests --- tests/integ/conftest.py | 23 ++++++- .../integ/scala/test_file_operation_suite.py | 17 +---- tests/integ/test_multithreading.py | 69 ++++++++++++++++++- 3 files changed, 91 insertions(+), 18 deletions(-) diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index ec619605e66..319abb137f4 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -13,7 +13,13 @@ from snowflake.snowpark.exceptions import SnowparkSQLException from snowflake.snowpark.mock._connection import MockServerConnection from tests.parameters import CONNECTION_PARAMETERS -from tests.utils import TEST_SCHEMA, Utils, running_on_jenkins, running_on_public_ci +from tests.utils import ( + TEST_SCHEMA, + TestFiles, + Utils, + running_on_jenkins, + running_on_public_ci, +) def print_help() -> None: @@ -235,3 +241,18 @@ def temp_schema(connection, session, local_testing_mode) -> None: ) yield temp_schema_name cursor.execute(f"DROP SCHEMA IF 
EXISTS {temp_schema_name}") + + +@pytest.fixture(scope="module") +def temp_stage(session, resources_path, local_testing_mode): + tmp_stage_name = Utils.random_stage_name() + test_files = TestFiles(resources_path) + + if not local_testing_mode: + Utils.create_stage(session, tmp_stage_name, is_temporary=True) + Utils.upload_to_stage( + session, tmp_stage_name, test_files.test_file_parquet, compress=False + ) + yield tmp_stage_name + if not local_testing_mode: + Utils.drop_stage(session, tmp_stage_name) diff --git a/tests/integ/scala/test_file_operation_suite.py b/tests/integ/scala/test_file_operation_suite.py index 2dc424dde09..82a1722a729 100644 --- a/tests/integ/scala/test_file_operation_suite.py +++ b/tests/integ/scala/test_file_operation_suite.py @@ -14,7 +14,7 @@ SnowparkSQLException, SnowparkUploadFileException, ) -from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, TestFiles, Utils +from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, Utils def random_alphanumeric_name(): @@ -74,21 +74,6 @@ def path4(temp_source_directory): yield filename -@pytest.fixture(scope="module") -def temp_stage(session, resources_path, local_testing_mode): - tmp_stage_name = Utils.random_stage_name() - test_files = TestFiles(resources_path) - - if not local_testing_mode: - Utils.create_stage(session, tmp_stage_name, is_temporary=True) - Utils.upload_to_stage( - session, tmp_stage_name, test_files.test_file_parquet, compress=False - ) - yield tmp_stage_name - if not local_testing_mode: - Utils.drop_stage(session, tmp_stage_name) - - def test_put_with_one_file( session, temp_stage, path1, path2, path3, local_testing_mode ): diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 164a4b7b590..10fcc6ef70d 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -2,6 +2,9 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# +import hashlib +import os +import tempfile from concurrent.futures import ThreadPoolExecutor, as_completed from unittest.mock import patch @@ -9,7 +12,7 @@ from snowflake.snowpark.functions import lit from snowflake.snowpark.row import Row -from tests.utils import IS_IN_STORED_PROC, Utils +from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils def test_concurrent_select_queries(session): @@ -122,3 +125,67 @@ def test_action_ids_are_unique(session): action_ids.add(future.result()) assert len(action_ids) == 10 + + +@pytest.mark.parametrize("use_stream", [True, False]) +def test_file_io(session, resources_path, temp_stage, use_stream): + stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}" + stage_with_prefix = f"@{temp_stage}/{stage_prefix}/" + test_files = TestFiles(resources_path) + + resources_files = [ + test_files.test_file_csv, + test_files.test_file2_csv, + test_files.test_file_json, + test_files.test_file_csv_header, + test_files.test_file_csv_colon, + test_files.test_file_csv_quotes, + test_files.test_file_csv_special_format, + test_files.test_file_json_special_format, + test_files.test_file_csv_quotes_special, + test_files.test_concat_file1_csv, + test_files.test_concat_file2_csv, + ] + + def get_file_hash(fd): + return hashlib.md5(fd.read()).hexdigest() + + def put_and_get_file(upload_file_path, download_dir): + if use_stream: + with open(upload_file_path, "rb") as fd: + results = session.file.put_stream( + fd, stage_with_prefix, auto_compress=False, overwrite=False + ) + else: + results = session.file.put( + upload_file_path, + stage_with_prefix, + auto_compress=False, + overwrite=False, + ) + # assert file is uploaded successfully + assert len(results) == 1 + assert results[0].status == "UPLOADED" + + stage_file_name = f"{stage_with_prefix}{os.path.basename(upload_file_path)}" + if use_stream: + fd = session.file.get_stream(stage_file_name, download_dir) + with open(upload_file_path, "rb") as upload_fd: + assert get_file_hash(upload_fd) == get_file_hash(fd) + + else: + results = session.file.get(stage_file_name, download_dir) + # assert file is downloaded successfully + assert len(results) == 1 + assert results[0].status == "DOWNLOADED" + download_file_path = results[0].file + # assert two files are identical + with open(upload_file_path, "rb") as upload_fd, open( + download_file_path, "rb" + ) as download_fd: + assert get_file_hash(upload_fd) == get_file_hash(download_fd) + + with tempfile.TemporaryDirectory() as download_dir: + with ThreadPoolExecutor(max_workers=10) as executor: + for file_path in resources_files: + executor.submit(put_and_get_file, file_path, download_dir) From a083989df8663f87d1b5fc0fed62b75d5fa1b576 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 12 Sep 2024 17:20:04 -0700 Subject: [PATCH 10/20] make session._runtime_version_from_requirement safe --- src/snowflake/snowpark/_internal/udf_utils.py | 24 +++++++------------ src/snowflake/snowpark/stored_procedure.py | 11 ++++++--- src/snowflake/snowpark/udaf.py | 9 +++++-- src/snowflake/snowpark/udf.py | 9 +++++-- src/snowflake/snowpark/udtf.py | 9 +++++-- 5 files changed, 37 insertions(+), 25 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 4012ad9ff6e..2c36c8e6091 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1122,13 +1122,10 @@ def resolve_imports_and_packages( # Upload closure to stage if it is beyond inline closure size limit handler = 
inline_code = upload_file_stage_location = None - custom_python_runtime_version_allowed = False + # As cloudpickle is being used, we cannot allow a custom runtime + custom_python_runtime_version_allowed = not isinstance(func, Callable) if session is not None: if isinstance(func, Callable): - custom_python_runtime_version_allowed = ( - False # As cloudpickle is being used, we cannot allow a custom runtime - ) - # generate a random name for udf py file # and we compress it first then upload it udf_file_name_base = f"udf_py_{random_number()}" @@ -1173,7 +1170,6 @@ def resolve_imports_and_packages( upload_file_stage_location = None handler = _DEFAULT_HANDLER_NAME else: - custom_python_runtime_version_allowed = True udf_file_name = os.path.basename(func[0]) # for a compressed file, it might have multiple extensions # and we should remove all extensions @@ -1198,11 +1194,6 @@ def resolve_imports_and_packages( skip_upload_on_content_match=skip_upload_on_content_match, ) all_urls.append(upload_file_stage_location) - else: - if isinstance(func, Callable): - custom_python_runtime_version_allowed = False - else: - custom_python_runtime_version_allowed = True # build imports and packages string all_imports = ",".join( @@ -1245,12 +1236,13 @@ def create_python_udf_or_sp( statement_params: Optional[Dict[str, str]] = None, comment: Optional[str] = None, native_app_params: Optional[Dict[str, Any]] = None, + runtime_version: Optional[str] = None, ) -> None: - with session._lock: - if session is not None and session._runtime_version_from_requirement: - runtime_version = session._runtime_version_from_requirement - else: - runtime_version = f"{sys.version_info[0]}.{sys.version_info[1]}" + runtime_version = ( + f"{sys.version_info[0]}.{sys.version_info[1]}" + if not runtime_version + else runtime_version + ) if replace and if_not_exists: raise ValueError("options replace and if_not_exists are incompatible") diff --git a/src/snowflake/snowpark/stored_procedure.py b/src/snowflake/snowpark/stored_procedure.py index 92e63f6a76a..c2fcb024196 100644 --- a/src/snowflake/snowpark/stored_procedure.py +++ b/src/snowflake/snowpark/stored_procedure.py @@ -820,11 +820,15 @@ def _do_register_sp( force_inline_code=force_inline_code, ) - if (not custom_python_runtime_version_allowed) and (self._session is not None): - check_python_runtime_version( + runtime_version_from_requirement = None + if self._session is not None: + runtime_version_from_requirement = ( self._session._runtime_version_from_requirement ) + if not custom_python_runtime_version_allowed: + check_python_runtime_version(runtime_version_from_requirement) + anonymous_sp_sql = None if anonymous: anonymous_sp_sql = generate_anonymous_python_sp_sql( @@ -838,7 +842,7 @@ def _do_register_sp( raw_imports=imports, inline_python_code=code, strict=strict, - runtime_version=self._session._runtime_version_from_requirement, + runtime_version=runtime_version_from_requirement, external_access_integrations=external_access_integrations, secrets=secrets, native_app_params=native_app_params, @@ -870,6 +874,7 @@ def _do_register_sp( statement_params=statement_params, comment=comment, native_app_params=native_app_params, + runtime_version=runtime_version_from_requirement, ) # an exception might happen during registering a stored procedure # (e.g., a dependency might not be found on the stage), diff --git a/src/snowflake/snowpark/udaf.py b/src/snowflake/snowpark/udaf.py index fdc3d555281..889b5b62915 100644 --- a/src/snowflake/snowpark/udaf.py +++ b/src/snowflake/snowpark/udaf.py @@ 
-680,11 +680,15 @@ def _do_register_udaf( is_permanent=is_permanent, ) - if (not custom_python_runtime_version_allowed) and (self._session is not None): - check_python_runtime_version( + runtime_version_from_requirement = None + if self._session is not None: + runtime_version_from_requirement = ( self._session._runtime_version_from_requirement ) + if not custom_python_runtime_version_allowed: + check_python_runtime_version(runtime_version_from_requirement) + raised = False try: create_python_udf_or_sp( @@ -710,6 +714,7 @@ def _do_register_udaf( statement_params=statement_params, comment=comment, native_app_params=native_app_params, + runtime_version=runtime_version_from_requirement, ) # an exception might happen during registering a udaf # (e.g., a dependency might not be found on the stage), diff --git a/src/snowflake/snowpark/udf.py b/src/snowflake/snowpark/udf.py index 74278e48d30..b71a40263fd 100644 --- a/src/snowflake/snowpark/udf.py +++ b/src/snowflake/snowpark/udf.py @@ -873,11 +873,15 @@ def _do_register_udf( is_permanent=is_permanent, ) - if (not custom_python_runtime_version_allowed) and (self._session is not None): - check_python_runtime_version( + runtime_version_from_requirement = None + if self._session is not None: + runtime_version_from_requirement = ( self._session._runtime_version_from_requirement ) + if not custom_python_runtime_version_allowed: + check_python_runtime_version(runtime_version_from_requirement) + raised = False try: create_python_udf_or_sp( @@ -905,6 +909,7 @@ def _do_register_udf( statement_params=statement_params, comment=comment, native_app_params=native_app_params, + runtime_version=runtime_version_from_requirement, ) # an exception might happen during registering a udf # (e.g., a dependency might not be found on the stage), diff --git a/src/snowflake/snowpark/udtf.py b/src/snowflake/snowpark/udtf.py index 03b71aa5f75..856007cdb64 100644 --- a/src/snowflake/snowpark/udtf.py +++ b/src/snowflake/snowpark/udtf.py @@ -934,11 +934,15 @@ def _do_register_udtf( is_permanent=is_permanent, ) - if (not custom_python_runtime_version_allowed) and (self._session is not None): - check_python_runtime_version( + runtime_version_from_requirement = None + if self._session is not None: + runtime_version_from_requirement = ( self._session._runtime_version_from_requirement ) + if not custom_python_runtime_version_allowed: + check_python_runtime_version(runtime_version_from_requirement) + raised = False try: create_python_udf_or_sp( @@ -966,6 +970,7 @@ def _do_register_udtf( statement_params=statement_params, comment=comment, native_app_params=native_app_params, + runtime_version=runtime_version_from_requirement, ) # an exception might happen during registering a udtf # (e.g., a dependency might not be found on the stage), From 947d3847498e5749f759bff597d93dea40561b37 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 13 Sep 2024 10:35:04 -0700 Subject: [PATCH 11/20] add sp/udf concurrent tests --- src/snowflake/snowpark/session.py | 2 +- tests/integ/test_multithreading.py | 218 +++++++++++++++++++++++++++++ 2 files changed, 219 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 874ff5b72b6..a233ce26683 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -944,8 +944,8 @@ def clear_imports(self) -> None: with self._lock: self._import_paths.clear() + @staticmethod def _resolve_import_path( - self, path: str, import_path: Optional[str] = None, chunk_size: int = 8192, diff 
--git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 10fcc6ef70d..dd2c1da54e5 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -10,6 +10,18 @@ import pytest +from snowflake.snowpark.types import IntegerType + +try: + import dateutil + + # six is the dependency of dateutil + import six + + is_dateutil_available = True +except ImportError: + is_dateutil_available = False + from snowflake.snowpark.functions import lit from snowflake.snowpark.row import Row from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils @@ -189,3 +201,209 @@ def put_and_get_file(upload_file_path, download_dir): with ThreadPoolExecutor(max_workers=10) as executor: for file_path in resources_files: executor.submit(put_and_get_file, file_path, download_dir) + + +def test_concurrent_add_packages(session): + # this is a list of packages available in snowflake anaconda. If this + # test fails due to packages not being available, please update the list + package_list = { + "graphviz", + "numpy", + "pandas", + "scipy", + "scikit-learn", + "matplotlib", + } + + try: + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [ + executor.submit(session.add_packages, package) + for package in package_list + ] + + for future in as_completed(futures): + future.result() + + assert session.get_packages() == { + package: package for package in package_list + } + finally: + session.clear_packages() + + +def test_concurrent_remove_package(session): + def remove_package(session_, package_name): + try: + session_.remove_package(package_name) + return True + except ValueError: + return False + except Exception as e: + raise e + + try: + session.add_packages("numpy") + with ThreadPoolExecutor(max_workers=10) as executor: + + futures = [ + executor.submit(remove_package, session, "numpy") for _ in range(10) + ] + success_count, failure_count = 0, 0 + for future in as_completed(futures): + if future.result(): + success_count += 1 + else: + failure_count += 1 + + # assert that only one thread was able to remove the package + assert success_count == 1 + assert failure_count == 9 + finally: + session.clear_packages() + + +@pytest.mark.skipif(not is_dateutil_available, reason="dateutil is not available") +def test_concurrent_add_import(session, resources_path): + test_files = TestFiles(resources_path) + import_files = [ + test_files.test_udf_py_file, + os.path.relpath(test_files.test_udf_py_file), + test_files.test_udf_directory, + os.path.relpath(test_files.test_udf_directory), + six.__file__, + os.path.relpath(six.__file__), + os.path.dirname(dateutil.__file__), + ] + try: + with ThreadPoolExecutor(max_workers=10) as executor: + for file in import_files: + executor.submit( + session.add_import, + file, + ) + + assert set(session.get_imports()) == { + os.path.abspath(file) for file in import_files + } + finally: + session.clear_imports() + + +def test_concurrent_remove_import(session, resources_path): + test_files = TestFiles(resources_path) + + def remove_import(session_, import_file): + try: + session_.remove_import(import_file) + return True + except KeyError: + return False + except Exception as e: + raise e + + try: + session.add_import(test_files.test_udf_py_file) + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [ + executor.submit(remove_import, session, test_files.test_udf_py_file) + for _ in range(10) + ] + + success_count, failure_count = 0, 0 + for future in as_completed(futures): + if future.result(): + success_count += 1 + 
else: + failure_count += 1 + + # assert that only one thread was able to remove the import + assert success_count == 1 + assert failure_count == 9 + finally: + session.clear_imports() + + +def test_concurrent_sp_register(session, tmpdir): + try: + session.add_packages("snowflake-snowpark-python") + + def register_and_test_sp(session_, thread_id): + prefix = Utils.random_alphanumeric_str(10) + sp_file_path = os.path.join(tmpdir, f"{prefix}_add_{thread_id}.py") + sproc_body = f""" +from snowflake.snowpark import Session +from snowflake.snowpark.functions import ( + col, + lit +) +def add_{thread_id}(session_: Session, x: int) -> int: + return ( + session_.create_dataframe([[x, ]], schema=["x"]) + .select(col("x") + lit({thread_id})) + .collect()[0][0] + ) +""" + with open(sp_file_path, "w") as f: + f.write(sproc_body) + f.flush() + + add_sp_from_file = session_.sproc.register_from_file( + sp_file_path, f"add_{thread_id}" + ) + add_sp = session_.sproc.register( + lambda sess_, x: sess_.sql(f"select {x} + {thread_id}").collect()[0][0], + return_type=IntegerType(), + input_types=[IntegerType()], + ) + + assert add_sp_from_file(1) == thread_id + 1 + assert add_sp(1) == thread_id + 1 + + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(register_and_test_sp, session, i) + finally: + session.clear_packages() + + +def test_concurrent_udf_register(session, tmpdir): + df = session.range(-5, 5).to_df("a") + + def register_and_test_udf(session_, thread_id): + prefix = Utils.random_alphanumeric_str(10) + file_path = os.path.join(tmpdir, f"{prefix}_add_{thread_id}.py") + with open(file_path, "w") as f: + func = f""" +def add_{thread_id}(x: int) -> int: + return x + {thread_id} +""" + f.write(func) + f.flush() + add_i_udf_from_file = session_.udf.register_from_file( + file_path, f"add_{thread_id}" + ) + add_i_udf = session_.udf.register( + lambda x: x + thread_id, + return_type=IntegerType(), + input_types=[IntegerType()], + ) + + Utils.check_answer( + df.select(add_i_udf(df.a), add_i_udf_from_file(df.a)), + [(thread_id + i, thread_id + i) for i in range(-5, 5)], + ) + + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(register_and_test_udf, session, i) + + +@pytest.mark.parametrize("from_file", [True, False]) +def test_concurrent_udtf_register(session, from_file): + pass + + +@pytest.mark.parametrize("from_file", [True, False]) +def test_concurrent_udaf_register(session, from_file): + pass From fd51720ae43bc8112082c576707b35eddb7fd8a4 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 13 Sep 2024 10:47:34 -0700 Subject: [PATCH 12/20] fix broken test --- src/snowflake/snowpark/_internal/udf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 2c36c8e6091..07635c8de8a 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1076,6 +1076,7 @@ def resolve_imports_and_packages( ) ) + all_urls = [] if session is not None: import_only_stage = ( unwrap_stage_location_single_quote(stage_location) @@ -1089,7 +1090,6 @@ def resolve_imports_and_packages( else session.get_session_stage(statement_params=statement_params) ) - all_urls = [] if imports: udf_level_imports = {} for udf_import in imports: From 3077853b6e89bf19b9352fbee02a90c5c59c5023 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 13 Sep 2024 13:43:28 -0700 Subject: [PATCH 13/20] add 
udtf/udaf tests --- tests/integ/test_multithreading.py | 101 +++++++++++++++++++++++++++-- 1 file changed, 95 insertions(+), 6 deletions(-) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index dd2c1da54e5..9f8b5bef293 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -6,10 +6,12 @@ import os import tempfile from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Tuple # noqa: F401 from unittest.mock import patch import pytest +from snowflake.snowpark.session import Session from snowflake.snowpark.types import IntegerType try: @@ -399,11 +401,98 @@ def add_{thread_id}(x: int) -> int: executor.submit(register_and_test_udf, session, i) -@pytest.mark.parametrize("from_file", [True, False]) -def test_concurrent_udtf_register(session, from_file): - pass +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="UDTFs is not supported in local testing mode", + run=False, +) +def test_concurrent_udtf_register(session, tmpdir): + def register_and_test_udtf(session_, thread_id): + udtf_body = f""" +from typing import List, Tuple + +class UDTFEcho: + def process( + self, + num: int, + ) -> List[Tuple[int]]: + return [(num + {thread_id},)] +""" + prefix = Utils.random_alphanumeric_str(10) + file_path = os.path.join(tmpdir, f"{prefix}_udtf_echo_{thread_id}.py") + with open(file_path, "w") as f: + f.write(udtf_body) + f.flush() + + d = {} + exec(udtf_body, {**globals(), **locals()}, d) + echo_udtf_from_file = session_.udtf.register_from_file( + file_path, "UDTFEcho", output_schema=["num"] + ) + echo_udtf = session_.udtf.register(d["UDTFEcho"], output_schema=["num"]) + + df_local = session.table_function(echo_udtf(lit(1))) + df_from_file = session.table_function(echo_udtf_from_file(lit(1))) + assert df_local.collect() == [(thread_id + 1,)] + assert df_from_file.collect() == [(thread_id + 1,)] + + with ThreadPoolExecutor(max_workers=10) as executor: + for i in range(10): + executor.submit(register_and_test_udtf, session, i) + + +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="UDAFs is not supported in local testing mode", + run=False, +) +def test_concurrent_udaf_register(session: Session, tmpdir): + df = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b") + def register_and_test_udaf(session_, thread_id): + udaf_body = f""" +class OffsetSumUDAFHandler: + def __init__(self) -> None: + self._sum = 0 -@pytest.mark.parametrize("from_file", [True, False]) -def test_concurrent_udaf_register(session, from_file): - pass + @property + def aggregate_state(self): + return self._sum + + def accumulate(self, input_value): + self._sum += input_value + + def merge(self, other_sum): + self._sum += other_sum + + def finish(self): + return self._sum + {thread_id} + """ + prefix = Utils.random_alphanumeric_str(10) + file_path = os.path.join(tmpdir, f"{prefix}_udaf_{thread_id}.py") + with open(file_path, "w") as f: + f.write(udaf_body) + f.flush() + d = {} + exec(udaf_body, {**globals(), **locals()}, d) + + offset_sum_udaf_from_file = session_.udaf.register_from_file( + file_path, + "OffsetSumUDAFHandler", + return_type=IntegerType(), + input_types=[IntegerType()], + ) + offset_sum_udaf = session_.udaf.register( + d["OffsetSumUDAFHandler"], + return_type=IntegerType(), + input_types=[IntegerType()], + ) + + Utils.check_answer( + df.agg(offset_sum_udaf_from_file(df.a)), [Row(6 + thread_id)] + ) + 
        Utils.check_answer(df.agg(offset_sum_udaf(df.a)), [Row(6 + thread_id)])
+
+    with ThreadPoolExecutor(max_workers=10) as executor:
+        for i in range(10):
+            executor.submit(register_and_test_udaf, session, i)

From 65c3186437f17ad29af91096cec5be29424bd6f4 Mon Sep 17 00:00:00 2001
From: Afroz Alam
Date: Fri, 13 Sep 2024 13:51:53 -0700
Subject: [PATCH 14/20] fix broken test

---
 tests/unit/test_udf_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py
index 12a49443539..c23755c14a3 100644
--- a/tests/unit/test_udf_utils.py
+++ b/tests/unit/test_udf_utils.py
@@ -5,6 +5,7 @@
 import logging
 import os
 import pickle
+import threading
 from unittest import mock

 import pytest
@@ -249,6 +250,7 @@ def test_add_snowpark_package_to_sproc_packages_to_session():
         "random_package_one": "random_package_one",
         "random_package_two": "random_package_two",
     }
+    fake_session._lock = threading.RLock()
     result = add_snowpark_package_to_sproc_packages(session=fake_session, packages=None)
     major, minor, patch = VERSION

From 1c83ef232d297d366b5ae468e07760af4f0c352e Mon Sep 17 00:00:00 2001
From: Afroz Alam
Date: Tue, 17 Sep 2024 14:02:28 -0700
Subject: [PATCH 15/20] use _package_lock to protect Session._packages

---
 src/snowflake/snowpark/_internal/udf_utils.py |   2 +-
 src/snowflake/snowpark/session.py             | 121 +++++++++---------
 2 files changed, 63 insertions(+), 60 deletions(-)

diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py
index 07635c8de8a..bf4ce0af9b8 100644
--- a/src/snowflake/snowpark/_internal/udf_utils.py
+++ b/src/snowflake/snowpark/_internal/udf_utils.py
@@ -981,7 +981,7 @@ def add_snowpark_package_to_sproc_packages(
         if session is None:
             packages = [this_package]
         else:
-            with session._lock:
+            with session._package_lock:
                 session_packages = session._packages.copy()
                 if package_name not in session_packages:
                     packages = list(session_packages.values()) + [this_package]
diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py
index a233ce26683..0cda2dc5185 100644
--- a/src/snowflake/snowpark/session.py
+++ b/src/snowflake/snowpark/session.py
@@ -502,6 +502,11 @@ def __init__(
         self._conn = conn
         self._thread_store = threading.local()
         self._lock = threading.RLock()
+
+        # this lock is used to protect _packages. We introduce a new lock because add_packages
+        # launches a query to snowflake to get all versions of packages available in snowflake.
+        # This query can be slow and would block other threads waiting on _lock.
+        self._package_lock = threading.RLock()
         self._query_tag = None
         self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
         self._packages: Dict[str, str] = {}
@@ -1116,7 +1121,7 @@ def get_packages(self) -> Dict[str, str]:
         The key of this ``dict`` is the package name and the value of this ``dict``
         is the corresponding requirement specifier.
         """
-        with self._lock:
+        with self._package_lock:
             return self._packages.copy()

     def add_packages(
@@ -1208,7 +1213,7 @@ def remove_package(self, package: str) -> None:
             0
         """
         package_name = pkg_resources.Requirement.parse(package).key
-        with self._lock:
+        with self._package_lock:
             if package_name in self._packages:
                 self._packages.pop(package_name)
             else:
@@ -1218,7 +1223,7 @@ def clear_packages(self) -> None:
         """
         Clears all third-party packages of a user-defined function (UDF).
""" - with self._lock: + with self._package_lock: self._packages.clear() def add_requirements(self, file_path: str) -> None: @@ -1567,25 +1572,26 @@ def _resolve_packages( if isinstance(self._conn, MockServerConnection): # in local testing we don't resolve the packages, we just return what is added errors = [] - with self._lock: - result_dict = self._packages.copy() - for pkg_name, _, pkg_req in package_dict.values(): - if pkg_name in result_dict and str(pkg_req) != result_dict[pkg_name]: - errors.append( - ValueError( - f"Cannot add package '{str(pkg_req)}' because {result_dict[pkg_name]} " - "is already added." + with self._package_lock: + result_dict = self._packages + for pkg_name, _, pkg_req in package_dict.values(): + if ( + pkg_name in result_dict + and str(pkg_req) != result_dict[pkg_name] + ): + errors.append( + ValueError( + f"Cannot add package '{str(pkg_req)}' because {result_dict[pkg_name]} " + "is already added." + ) ) - ) - else: - result_dict[pkg_name] = str(pkg_req) - if len(errors) == 1: - raise errors[0] - elif len(errors) > 0: - raise RuntimeError(errors) - - with self._lock: - self._packages.update(result_dict) + else: + result_dict[pkg_name] = str(pkg_req) + if len(errors) == 1: + raise errors[0] + elif len(errors) > 0: + raise RuntimeError(errors) + return list(result_dict.values()) package_table = "information_schema.packages" @@ -1600,50 +1606,47 @@ def _resolve_packages( # 'python-dateutil': 'python-dateutil==2.8.2'} # Add to packages dictionary. Make a copy of existing packages # dictionary to avoid modifying it during intermediate steps. - with self._lock: + with self._package_lock: result_dict = ( - existing_packages_dict.copy() - if existing_packages_dict is not None - else {} + existing_packages_dict if existing_packages_dict is not None else {} ) - # Retrieve list of dependencies that need to be added - dependency_packages = self._get_dependency_packages( - package_dict, - validate_package, - package_table, - result_dict, - statement_params=statement_params, - ) - - # Add dependency packages - for package in dependency_packages: - name = package.name - version = package.specs[0][1] if package.specs else None - - if name in result_dict: - if version is not None: - added_package_has_version = "==" in result_dict[name] - if added_package_has_version and result_dict[name] != str(package): - raise ValueError( - f"Cannot add dependency package '{name}=={version}' " - f"because {result_dict[name]} is already added." - ) + # Retrieve list of dependencies that need to be added + dependency_packages = self._get_dependency_packages( + package_dict, + validate_package, + package_table, + result_dict, + statement_params=statement_params, + ) + + # Add dependency packages + for package in dependency_packages: + name = package.name + version = package.specs[0][1] if package.specs else None + + if name in result_dict: + if version is not None: + added_package_has_version = "==" in result_dict[name] + if added_package_has_version and result_dict[name] != str( + package + ): + raise ValueError( + f"Cannot add dependency package '{name}=={version}' " + f"because {result_dict[name]} is already added." 
+ ) + result_dict[name] = str(package) + else: result_dict[name] = str(package) - else: - result_dict[name] = str(package) - # Always include cloudpickle - extra_modules = [cloudpickle] - if include_pandas: - extra_modules.append("pandas") + # Always include cloudpickle + extra_modules = [cloudpickle] + if include_pandas: + extra_modules.append("pandas") - with self._lock: - if existing_packages_dict is not None: - existing_packages_dict.update(result_dict) - return list(result_dict.values()) + self._get_req_identifiers_list( - extra_modules, result_dict - ) + return list(result_dict.values()) + self._get_req_identifiers_list( + extra_modules, result_dict + ) def _upload_unsupported_packages( self, From a6497610be5072357c54b5c64bb8df95c65ef165 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:14:21 -0700 Subject: [PATCH 16/20] undo refactor --- src/snowflake/snowpark/_internal/udf_utils.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index bf4ce0af9b8..58d698556b3 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -982,9 +982,8 @@ def add_snowpark_package_to_sproc_packages( packages = [this_package] else: with session._package_lock: - session_packages = session._packages.copy() - if package_name not in session_packages: - packages = list(session_packages.values()) + [this_package] + if package_name not in session._packages: + packages = list(session._packages.values()) + [this_package] else: package_names = [p if isinstance(p, str) else p.__name__ for p in packages] if not any(p.startswith(package_name) for p in package_names): @@ -1076,20 +1075,19 @@ def resolve_imports_and_packages( ) ) - all_urls = [] if session is not None: import_only_stage = ( unwrap_stage_location_single_quote(stage_location) if stage_location else session.get_session_stage(statement_params=statement_params) ) - upload_and_import_stage = ( import_only_stage if is_permanent else session.get_session_stage(statement_params=statement_params) ) + if session: if imports: udf_level_imports = {} for udf_import in imports: @@ -1117,15 +1115,22 @@ def resolve_imports_and_packages( upload_and_import_stage, statement_params=statement_params, ) + else: + all_urls = [] + else: + all_urls = [] dest_prefix = get_udf_upload_prefix(udf_name) # Upload closure to stage if it is beyond inline closure size limit handler = inline_code = upload_file_stage_location = None - # As cloudpickle is being used, we cannot allow a custom runtime - custom_python_runtime_version_allowed = not isinstance(func, Callable) + custom_python_runtime_version_allowed = False if session is not None: if isinstance(func, Callable): + custom_python_runtime_version_allowed = ( + False # As cloudpickle is being used, we cannot allow a custom runtime + ) + # generate a random name for udf py file # and we compress it first then upload it udf_file_name_base = f"udf_py_{random_number()}" @@ -1170,6 +1175,7 @@ def resolve_imports_and_packages( upload_file_stage_location = None handler = _DEFAULT_HANDLER_NAME else: + custom_python_runtime_version_allowed = True udf_file_name = os.path.basename(func[0]) # for a compressed file, it might have multiple extensions # and we should remove all extensions @@ -1194,6 +1200,11 @@ def resolve_imports_and_packages( skip_upload_on_content_match=skip_upload_on_content_match, ) all_urls.append(upload_file_stage_location) 
+ else: + if isinstance(func, Callable): + custom_python_runtime_version_allowed = False + else: + custom_python_runtime_version_allowed = True # build imports and packages string all_imports = ",".join( From f03d6186f84b4f5b593d1920d820e218347e3e3a Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:15:50 -0700 Subject: [PATCH 17/20] undo refactor --- src/snowflake/snowpark/_internal/udf_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 58d698556b3..25921fff821 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1081,6 +1081,7 @@ def resolve_imports_and_packages( if stage_location else session.get_session_stage(statement_params=statement_params) ) + upload_and_import_stage = ( import_only_stage if is_permanent From 5f398d5f00edb42ca3f19d656f656fdc91a9c99c Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:33:07 -0700 Subject: [PATCH 18/20] fix test --- tests/unit/test_udf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index c23755c14a3..09e389d0c24 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -250,7 +250,7 @@ def test_add_snowpark_package_to_sproc_packages_to_session(): "random_package_one": "random_package_one", "random_package_two": "random_package_two", } - fake_session._lock = threading.RLock() + fake_session._package_lock = threading.RLock() result = add_snowpark_package_to_sproc_packages(session=fake_session, packages=None) major, minor, patch = VERSION From df3263c20c50beabe3f861c3db502ff2c8830c34 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 12 Sep 2024 15:06:48 -0700 Subject: [PATCH 19/20] add file IO tests --- tests/integ/conftest.py | 23 ++++++- .../integ/scala/test_file_operation_suite.py | 17 +---- tests/integ/test_multithreading.py | 69 ++++++++++++++++++- 3 files changed, 91 insertions(+), 18 deletions(-) diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index ec619605e66..319abb137f4 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -13,7 +13,13 @@ from snowflake.snowpark.exceptions import SnowparkSQLException from snowflake.snowpark.mock._connection import MockServerConnection from tests.parameters import CONNECTION_PARAMETERS -from tests.utils import TEST_SCHEMA, Utils, running_on_jenkins, running_on_public_ci +from tests.utils import ( + TEST_SCHEMA, + TestFiles, + Utils, + running_on_jenkins, + running_on_public_ci, +) def print_help() -> None: @@ -235,3 +241,18 @@ def temp_schema(connection, session, local_testing_mode) -> None: ) yield temp_schema_name cursor.execute(f"DROP SCHEMA IF EXISTS {temp_schema_name}") + + +@pytest.fixture(scope="module") +def temp_stage(session, resources_path, local_testing_mode): + tmp_stage_name = Utils.random_stage_name() + test_files = TestFiles(resources_path) + + if not local_testing_mode: + Utils.create_stage(session, tmp_stage_name, is_temporary=True) + Utils.upload_to_stage( + session, tmp_stage_name, test_files.test_file_parquet, compress=False + ) + yield tmp_stage_name + if not local_testing_mode: + Utils.drop_stage(session, tmp_stage_name) diff --git a/tests/integ/scala/test_file_operation_suite.py b/tests/integ/scala/test_file_operation_suite.py index 2dc424dde09..82a1722a729 100644 --- a/tests/integ/scala/test_file_operation_suite.py +++ 
b/tests/integ/scala/test_file_operation_suite.py @@ -14,7 +14,7 @@ SnowparkSQLException, SnowparkUploadFileException, ) -from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, TestFiles, Utils +from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, Utils def random_alphanumeric_name(): @@ -74,21 +74,6 @@ def path4(temp_source_directory): yield filename -@pytest.fixture(scope="module") -def temp_stage(session, resources_path, local_testing_mode): - tmp_stage_name = Utils.random_stage_name() - test_files = TestFiles(resources_path) - - if not local_testing_mode: - Utils.create_stage(session, tmp_stage_name, is_temporary=True) - Utils.upload_to_stage( - session, tmp_stage_name, test_files.test_file_parquet, compress=False - ) - yield tmp_stage_name - if not local_testing_mode: - Utils.drop_stage(session, tmp_stage_name) - - def test_put_with_one_file( session, temp_stage, path1, path2, path3, local_testing_mode ): diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 164a4b7b590..10fcc6ef70d 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -2,6 +2,9 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import hashlib +import os +import tempfile from concurrent.futures import ThreadPoolExecutor, as_completed from unittest.mock import patch @@ -9,7 +12,7 @@ from snowflake.snowpark.functions import lit from snowflake.snowpark.row import Row -from tests.utils import IS_IN_STORED_PROC, Utils +from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils def test_concurrent_select_queries(session): @@ -122,3 +125,67 @@ def test_action_ids_are_unique(session): action_ids.add(future.result()) assert len(action_ids) == 10 + + +@pytest.mark.parametrize("use_stream", [True, False]) +def test_file_io(session, resources_path, temp_stage, use_stream): + stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}" + stage_with_prefix = f"@{temp_stage}/{stage_prefix}/" + test_files = TestFiles(resources_path) + + resources_files = [ + test_files.test_file_csv, + test_files.test_file2_csv, + test_files.test_file_json, + test_files.test_file_csv_header, + test_files.test_file_csv_colon, + test_files.test_file_csv_quotes, + test_files.test_file_csv_special_format, + test_files.test_file_json_special_format, + test_files.test_file_csv_quotes_special, + test_files.test_concat_file1_csv, + test_files.test_concat_file2_csv, + ] + + def get_file_hash(fd): + return hashlib.md5(fd.read()).hexdigest() + + def put_and_get_file(upload_file_path, download_dir): + if use_stream: + with open(upload_file_path, "rb") as fd: + results = session.file.put_stream( + fd, stage_with_prefix, auto_compress=False, overwrite=False + ) + else: + results = session.file.put( + upload_file_path, + stage_with_prefix, + auto_compress=False, + overwrite=False, + ) + # assert file is uploaded successfully + assert len(results) == 1 + assert results[0].status == "UPLOADED" + + stage_file_name = f"{stage_with_prefix}{os.path.basename(upload_file_path)}" + if use_stream: + fd = session.file.get_stream(stage_file_name, download_dir) + with open(upload_file_path, "rb") as upload_fd: + assert get_file_hash(upload_fd) == get_file_hash(fd) + + else: + results = session.file.get(stage_file_name, download_dir) + # assert file is downloaded successfully + assert len(results) == 1 + assert results[0].status == "DOWNLOADED" + download_file_path = results[0].file + # assert two files are identical + with open(upload_file_path, "rb") as upload_fd, open( + 
download_file_path, "rb" + ) as download_fd: + assert get_file_hash(upload_fd) == get_file_hash(download_fd) + + with tempfile.TemporaryDirectory() as download_dir: + with ThreadPoolExecutor(max_workers=10) as executor: + for file_path in resources_files: + executor.submit(put_and_get_file, file_path, download_dir) From 79af9d78291c23ca74ace885f3cf60bc30dabf6c Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 18 Sep 2024 15:41:14 -0700 Subject: [PATCH 20/20] add suggested test --- tests/integ/test_multithreading.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 10fcc6ef70d..b704fe5e811 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -189,3 +189,9 @@ def put_and_get_file(upload_file_path, download_dir): with ThreadPoolExecutor(max_workers=10) as executor: for file_path in resources_files: executor.submit(put_and_get_file, file_path, download_dir) + + if not use_stream: + # assert all files are downloaded + assert set(os.listdir(download_dir)) == { + os.path.basename(file_path) for file_path in resources_files + }