From f11f6ee57951ab655a3a97205821d7d2380b6f2b Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 28 Aug 2024 21:26:22 +0000 Subject: [PATCH] SNOW-1418500: Add side effect to _resolve_packages (#2174) --- src/snowflake/snowpark/_internal/udf_utils.py | 4 ++-- src/snowflake/snowpark/session.py | 21 ++++++++++--------- tests/integ/test_packaging.py | 2 +- tests/unit/test_session.py | 14 +++++++++---- 4 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 5a92dcb95cd..b79fcdcf9c9 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1062,7 +1062,7 @@ def resolve_imports_and_packages( packages, include_pandas=is_pandas_udf, statement_params=statement_params, - )[0] + ) if packages is not None else session._resolve_packages( [], @@ -1070,7 +1070,7 @@ def resolve_imports_and_packages( validate_package=False, include_pandas=is_pandas_udf, statement_params=statement_params, - )[0] + ) ) if session is not None: diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 5414d9a089d..b718364dc83 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -1117,11 +1117,10 @@ def add_packages( to ensure the consistent experience of a UDF between your local environment and the Snowflake server. """ - _, resolved_result_dict = self._resolve_packages( + self._resolve_packages( parse_positional_args_to_list(*packages), self._packages, ) - self._packages.update(resolved_result_dict) def remove_package(self, package: str) -> None: """ @@ -1482,12 +1481,13 @@ def _resolve_packages( validate_package: bool = True, include_pandas: bool = False, statement_params: Optional[Dict[str, str]] = None, - ) -> Tuple[List[str], Dict[str, str]]: + ) -> List[str]: """ Given a list of packages to add, this method will 1. Check if the packages are supported by Snowflake 2. Check if the package version if provided is supported by Snowflake 3. Check if the package is already added + 4. Update existing packages dictionary with the new packages (*this is required for python sp to work*) When auto package upload is enabled, this method will also try to upload the packages unavailable in Snowflake to the stage. @@ -1496,7 +1496,6 @@ def _resolve_packages( Returns: List[str]: List of package specifiers - Dict[str, str]: Dictionary of package name -> package specifier """ # Extract package names, whether they are local, and their associated Requirement objects package_dict = self._parse_packages(packages) @@ -1518,7 +1517,9 @@ def _resolve_packages( raise errors[0] elif len(errors) > 0: raise RuntimeError(errors) - return list(result_dict.values()), result_dict + + self._packages.update(result_dict) + return list(result_dict.values()) package_table = "information_schema.packages" if not self.get_current_database(): @@ -1531,7 +1532,7 @@ def _resolve_packages( # 'scikit-learn': 'scikit-learn==1.2.2', # 'python-dateutil': 'python-dateutil==2.8.2'} # Add to packages dictionary. Make a copy of existing packages - # dictionary to avoid modifying it. + # dictionary to avoid modifying it during intermediate steps. result_dict = ( existing_packages_dict.copy() if existing_packages_dict is not None else {} ) @@ -1567,10 +1568,10 @@ def _resolve_packages( if include_pandas: extra_modules.append("pandas") - return ( - list(result_dict.values()) - + self._get_req_identifiers_list(extra_modules, result_dict), - result_dict, + if existing_packages_dict is not None: + existing_packages_dict.update(result_dict) + return list(result_dict.values()) + self._get_req_identifiers_list( + extra_modules, result_dict ) def _upload_unsupported_packages( diff --git a/tests/integ/test_packaging.py b/tests/integ/test_packaging.py index 3deac4e80f3..eaf99534e2b 100644 --- a/tests/integ/test_packaging.py +++ b/tests/integ/test_packaging.py @@ -262,7 +262,7 @@ def is_yaml_available() -> bool: # add module objects # but we can't register a udf with these versions # because the server might not have them - resolved_packages, _ = session._resolve_packages( + resolved_packages = session._resolve_packages( [numpy, pandas, dateutil], validate_package=False ) assert f"numpy=={numpy.__version__}" in resolved_packages diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 86c8d54f7bb..262c9e82c44 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -245,7 +245,9 @@ def run_query(sql: str): ) -def test_resolve_packages_no_side_effect(): +def test_resolve_packages_side_effect(): + """Python stored procedure depends on this behavior to add packages to the session.""" + def mock_get_information_schema_packages(table_name: str): result = MagicMock() result.filter().group_by().agg()._internal_collect_with_tag.return_value = [ @@ -261,15 +263,19 @@ def mock_get_information_schema_packages(table_name: str): existing_packages = {} - resolved_packages, _ = session._resolve_packages( + resolved_packages = session._resolve_packages( ["random_package_name"], existing_packages_dict=existing_packages, validate_package=True, include_pandas=False, ) - assert len(resolved_packages) == 2 # random_package_name and cloudpickle - assert len(existing_packages) == 0 + assert ( + len(resolved_packages) == 2 + ), resolved_packages # random_package_name and cloudpickle + assert ( + len(existing_packages) == 1 + ), existing_packages # {"random_package_name": "random_package_name"} @pytest.mark.skipif(not is_pandas_available, reason="requires pandas for write_pandas")