From b6114e4f055bb45b884104ca167ce92a8c1a4997 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 27 Aug 2024 13:42:31 -0700 Subject: [PATCH 1/2] Add side effect to _resolve_packages --- src/snowflake/snowpark/_internal/udf_utils.py | 4 ++-- src/snowflake/snowpark/session.py | 20 +++++++++---------- tests/integ/test_packaging.py | 2 +- tests/unit/test_session.py | 14 +++++++++---- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 5a92dcb95cd..b79fcdcf9c9 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1062,7 +1062,7 @@ def resolve_imports_and_packages( packages, include_pandas=is_pandas_udf, statement_params=statement_params, - )[0] + ) if packages is not None else session._resolve_packages( [], @@ -1070,7 +1070,7 @@ def resolve_imports_and_packages( validate_package=False, include_pandas=is_pandas_udf, statement_params=statement_params, - )[0] + ) ) if session is not None: diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 5414d9a089d..a5e82baf2c2 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -1117,11 +1117,10 @@ def add_packages( to ensure the consistent experience of a UDF between your local environment and the Snowflake server. """ - _, resolved_result_dict = self._resolve_packages( + self._resolve_packages( parse_positional_args_to_list(*packages), self._packages, ) - self._packages.update(resolved_result_dict) def remove_package(self, package: str) -> None: """ @@ -1482,12 +1481,13 @@ def _resolve_packages( validate_package: bool = True, include_pandas: bool = False, statement_params: Optional[Dict[str, str]] = None, - ) -> Tuple[List[str], Dict[str, str]]: + ) -> List[str]: """ Given a list of packages to add, this method will 1. Check if the packages are supported by Snowflake 2. Check if the package version if provided is supported by Snowflake 3. Check if the package is already added + 4. Update existing packages dictionary with the new packages (*this is required for python sp to work*) When auto package upload is enabled, this method will also try to upload the packages unavailable in Snowflake to the stage. @@ -1496,7 +1496,6 @@ def _resolve_packages( Returns: List[str]: List of package specifiers - Dict[str, str]: Dictionary of package name -> package specifier """ # Extract package names, whether they are local, and their associated Requirement objects package_dict = self._parse_packages(packages) @@ -1518,7 +1517,9 @@ def _resolve_packages( raise errors[0] elif len(errors) > 0: raise RuntimeError(errors) - return list(result_dict.values()), result_dict + + self._packages.update(result_dict) + return list(result_dict.values()) package_table = "information_schema.packages" if not self.get_current_database(): @@ -1531,7 +1532,7 @@ def _resolve_packages( # 'scikit-learn': 'scikit-learn==1.2.2', # 'python-dateutil': 'python-dateutil==2.8.2'} # Add to packages dictionary. Make a copy of existing packages - # dictionary to avoid modifying it. + # dictionary to avoid modifying it during intermediate steps. result_dict = ( existing_packages_dict.copy() if existing_packages_dict is not None else {} ) @@ -1567,10 +1568,9 @@ def _resolve_packages( if include_pandas: extra_modules.append("pandas") - return ( - list(result_dict.values()) - + self._get_req_identifiers_list(extra_modules, result_dict), - result_dict, + existing_packages_dict.update(result_dict) + return list(result_dict.values()) + self._get_req_identifiers_list( + extra_modules, result_dict ) def _upload_unsupported_packages( diff --git a/tests/integ/test_packaging.py b/tests/integ/test_packaging.py index 3deac4e80f3..eaf99534e2b 100644 --- a/tests/integ/test_packaging.py +++ b/tests/integ/test_packaging.py @@ -262,7 +262,7 @@ def is_yaml_available() -> bool: # add module objects # but we can't register a udf with these versions # because the server might not have them - resolved_packages, _ = session._resolve_packages( + resolved_packages = session._resolve_packages( [numpy, pandas, dateutil], validate_package=False ) assert f"numpy=={numpy.__version__}" in resolved_packages diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 86c8d54f7bb..262c9e82c44 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -245,7 +245,9 @@ def run_query(sql: str): ) -def test_resolve_packages_no_side_effect(): +def test_resolve_packages_side_effect(): + """Python stored procedure depends on this behavior to add packages to the session.""" + def mock_get_information_schema_packages(table_name: str): result = MagicMock() result.filter().group_by().agg()._internal_collect_with_tag.return_value = [ @@ -261,15 +263,19 @@ def mock_get_information_schema_packages(table_name: str): existing_packages = {} - resolved_packages, _ = session._resolve_packages( + resolved_packages = session._resolve_packages( ["random_package_name"], existing_packages_dict=existing_packages, validate_package=True, include_pandas=False, ) - assert len(resolved_packages) == 2 # random_package_name and cloudpickle - assert len(existing_packages) == 0 + assert ( + len(resolved_packages) == 2 + ), resolved_packages # random_package_name and cloudpickle + assert ( + len(existing_packages) == 1 + ), existing_packages # {"random_package_name": "random_package_name"} @pytest.mark.skipif(not is_pandas_available, reason="requires pandas for write_pandas") From 19030911620a8464af2cb289c9a50b32787e6eb8 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 27 Aug 2024 14:00:06 -0700 Subject: [PATCH 2/2] fix test --- src/snowflake/snowpark/session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index a5e82baf2c2..b718364dc83 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -1568,7 +1568,8 @@ def _resolve_packages( if include_pandas: extra_modules.append("pandas") - existing_packages_dict.update(result_dict) + if existing_packages_dict is not None: + existing_packages_dict.update(result_dict) return list(result_dict.values()) + self._get_req_identifiers_list( extra_modules, result_dict )