Skip to content

Commit

Permalink
SNOW-1418500: Add side effect to _resolve_packages (#2174)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam authored Aug 28, 2024
1 parent e83e700 commit f11f6ee
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,15 +1062,15 @@ def resolve_imports_and_packages(
packages,
include_pandas=is_pandas_udf,
statement_params=statement_params,
)[0]
)
if packages is not None
else session._resolve_packages(
[],
session._packages,
validate_package=False,
include_pandas=is_pandas_udf,
statement_params=statement_params,
)[0]
)
)

if session is not None:
Expand Down
21 changes: 11 additions & 10 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,11 +1117,10 @@ def add_packages(
to ensure the consistent experience of a UDF between your local environment
and the Snowflake server.
"""
_, resolved_result_dict = self._resolve_packages(
self._resolve_packages(
parse_positional_args_to_list(*packages),
self._packages,
)
self._packages.update(resolved_result_dict)

def remove_package(self, package: str) -> None:
"""
Expand Down Expand Up @@ -1482,12 +1481,13 @@ def _resolve_packages(
validate_package: bool = True,
include_pandas: bool = False,
statement_params: Optional[Dict[str, str]] = None,
) -> Tuple[List[str], Dict[str, str]]:
) -> List[str]:
"""
Given a list of packages to add, this method will
1. Check if the packages are supported by Snowflake
2. Check if the package version, if provided, is supported by Snowflake
3. Check if the package is already added
4. Update the existing packages dictionary with the new packages (*this is required for Python stored procedures to work*)
When auto package upload is enabled, this method will also try to upload the packages
unavailable in Snowflake to the stage.
Expand All @@ -1496,7 +1496,6 @@ def _resolve_packages(
Returns:
List[str]: List of package specifiers
Dict[str, str]: Dictionary of package name -> package specifier
"""
# Extract package names, whether they are local, and their associated Requirement objects
package_dict = self._parse_packages(packages)
Expand All @@ -1518,7 +1517,9 @@ def _resolve_packages(
raise errors[0]
elif len(errors) > 0:
raise RuntimeError(errors)
return list(result_dict.values()), result_dict

self._packages.update(result_dict)
return list(result_dict.values())

package_table = "information_schema.packages"
if not self.get_current_database():
Expand All @@ -1531,7 +1532,7 @@ def _resolve_packages(
# 'scikit-learn': 'scikit-learn==1.2.2',
# 'python-dateutil': 'python-dateutil==2.8.2'}
# Add to packages dictionary. Make a copy of existing packages
# dictionary to avoid modifying it.
# dictionary to avoid modifying it during intermediate steps.
result_dict = (
existing_packages_dict.copy() if existing_packages_dict is not None else {}
)
Expand Down Expand Up @@ -1567,10 +1568,10 @@ def _resolve_packages(
if include_pandas:
extra_modules.append("pandas")

return (
list(result_dict.values())
+ self._get_req_identifiers_list(extra_modules, result_dict),
result_dict,
if existing_packages_dict is not None:
existing_packages_dict.update(result_dict)
return list(result_dict.values()) + self._get_req_identifiers_list(
extra_modules, result_dict
)

def _upload_unsupported_packages(
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def is_yaml_available() -> bool:
# add module objects
# but we can't register a udf with these versions
# because the server might not have them
resolved_packages, _ = session._resolve_packages(
resolved_packages = session._resolve_packages(
[numpy, pandas, dateutil], validate_package=False
)
assert f"numpy=={numpy.__version__}" in resolved_packages
Expand Down
14 changes: 10 additions & 4 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ def run_query(sql: str):
)


def test_resolve_packages_no_side_effect():
def test_resolve_packages_side_effect():
"""Python stored procedure depends on this behavior to add packages to the session."""

def mock_get_information_schema_packages(table_name: str):
result = MagicMock()
result.filter().group_by().agg()._internal_collect_with_tag.return_value = [
Expand All @@ -261,15 +263,19 @@ def mock_get_information_schema_packages(table_name: str):

existing_packages = {}

resolved_packages, _ = session._resolve_packages(
resolved_packages = session._resolve_packages(
["random_package_name"],
existing_packages_dict=existing_packages,
validate_package=True,
include_pandas=False,
)

assert len(resolved_packages) == 2 # random_package_name and cloudpickle
assert len(existing_packages) == 0
assert (
len(resolved_packages) == 2
), resolved_packages # random_package_name and cloudpickle
assert (
len(existing_packages) == 1
), existing_packages # {"random_package_name": "random_package_name"}


@pytest.mark.skipif(not is_pandas_available, reason="requires pandas for write_pandas")
Expand Down

0 comments on commit f11f6ee

Please sign in to comment.