Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1418500: Add side effect to _resolve_packages #2174

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,15 +1062,15 @@ def resolve_imports_and_packages(
packages,
include_pandas=is_pandas_udf,
statement_params=statement_params,
)[0]
)
if packages is not None
else session._resolve_packages(
[],
session._packages,
validate_package=False,
include_pandas=is_pandas_udf,
statement_params=statement_params,
)[0]
)
)

if session is not None:
Expand Down
21 changes: 11 additions & 10 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,11 +1117,10 @@ def add_packages(
to ensure the consistent experience of a UDF between your local environment
and the Snowflake server.
"""
_, resolved_result_dict = self._resolve_packages(
self._resolve_packages(
parse_positional_args_to_list(*packages),
self._packages,
)
self._packages.update(resolved_result_dict)

def remove_package(self, package: str) -> None:
"""
Expand Down Expand Up @@ -1482,12 +1481,13 @@ def _resolve_packages(
validate_package: bool = True,
include_pandas: bool = False,
statement_params: Optional[Dict[str, str]] = None,
) -> Tuple[List[str], Dict[str, str]]:
) -> List[str]:
"""
Given a list of packages to add, this method will
1. Check if the packages are supported by Snowflake
2. Check if the package version, if provided, is supported by Snowflake
3. Check if the package is already added
4. Update the existing packages dictionary with the new packages (*this is required for Python stored procedures to work*)

When auto package upload is enabled, this method will also try to upload the packages
unavailable in Snowflake to the stage.
Expand All @@ -1496,7 +1496,6 @@ def _resolve_packages(

Returns:
List[str]: List of package specifiers
Dict[str, str]: Dictionary of package name -> package specifier
"""
# Extract package names, whether they are local, and their associated Requirement objects
package_dict = self._parse_packages(packages)
Expand All @@ -1518,7 +1517,9 @@ def _resolve_packages(
raise errors[0]
elif len(errors) > 0:
raise RuntimeError(errors)
return list(result_dict.values()), result_dict

self._packages.update(result_dict)
return list(result_dict.values())

package_table = "information_schema.packages"
if not self.get_current_database():
Expand All @@ -1531,7 +1532,7 @@ def _resolve_packages(
# 'scikit-learn': 'scikit-learn==1.2.2',
# 'python-dateutil': 'python-dateutil==2.8.2'}
# Add to packages dictionary. Make a copy of existing packages
# dictionary to avoid modifying it.
# dictionary to avoid modifying it during intermediate steps.
result_dict = (
existing_packages_dict.copy() if existing_packages_dict is not None else {}
)
Expand Down Expand Up @@ -1567,10 +1568,10 @@ def _resolve_packages(
if include_pandas:
extra_modules.append("pandas")

return (
list(result_dict.values())
+ self._get_req_identifiers_list(extra_modules, result_dict),
result_dict,
if existing_packages_dict is not None:
existing_packages_dict.update(result_dict)
return list(result_dict.values()) + self._get_req_identifiers_list(
extra_modules, result_dict
)

def _upload_unsupported_packages(
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def is_yaml_available() -> bool:
# add module objects
# but we can't register a udf with these versions
# because the server might not have them
resolved_packages, _ = session._resolve_packages(
resolved_packages = session._resolve_packages(
[numpy, pandas, dateutil], validate_package=False
)
assert f"numpy=={numpy.__version__}" in resolved_packages
Expand Down
14 changes: 10 additions & 4 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ def run_query(sql: str):
)


def test_resolve_packages_no_side_effect():
def test_resolve_packages_side_effect():
"""Python stored procedure depends on this behavior to add packages to the session."""

def mock_get_information_schema_packages(table_name: str):
result = MagicMock()
result.filter().group_by().agg()._internal_collect_with_tag.return_value = [
Expand All @@ -261,15 +263,19 @@ def mock_get_information_schema_packages(table_name: str):

existing_packages = {}

resolved_packages, _ = session._resolve_packages(
resolved_packages = session._resolve_packages(
["random_package_name"],
existing_packages_dict=existing_packages,
validate_package=True,
include_pandas=False,
)

assert len(resolved_packages) == 2 # random_package_name and cloudpickle
assert len(existing_packages) == 0
assert (
len(resolved_packages) == 2
), resolved_packages # random_package_name and cloudpickle
assert (
len(existing_packages) == 1
), existing_packages # {"random_package_name": "random_package_name"}


@pytest.mark.skipif(not is_pandas_available, reason="requires pandas for write_pandas")
Expand Down
Loading