diff --git a/CHANGELOG.md b/CHANGELOG.md index 7678ec90081..a3151b1144e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ - The `format` argument changed from optional to required. - The returned result changed from a date object to a date-formatted string. +### Bug Fixes + +- Fixed a bug where `session.add_packages` could not handle a requirement specifier that contains both a project name with an underscore and a version restriction. + ## 1.9.0 (2023-10-13) ### New Features diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 345ac90e133..9c9eb714ad4 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -8,6 +8,7 @@ import json import logging import os +import re import sys import tempfile from array import array @@ -1034,10 +1035,18 @@ def _resolve_packages( # get the standard package name if there is no underscore # underscores are discouraged in package names, but are still used in Anaconda channel # pkg_resources.Requirement.parse will convert all underscores to dashes + # the regexp handles the case where the requirement string contains "_" in the package name together with version restrictions + # we only extract the valid package name from the string by following: + # https://packaging.python.org/en/latest/specifications/name-normalization/ + # A valid name consists only of ASCII letters and numbers, period, underscore and hyphen. + # It must start and end with a letter or number. 
+ # however, we don't validate the pkg name as this is done by pkg_resources.Requirement.parse + # find the index of the first char which is not a valid package name character; everything before it is the package name + package_name = package_req.key + if not use_local_version and "_" in package: + reg_match = re.search(r"[^0-9a-zA-Z\-_.]", package) + package_name = package[: reg_match.start()] if reg_match else package - package_name = ( - package if not use_local_version and "_" in package else package_req.key - ) package_dict[package] = (package_name, use_local_version, package_req) package_table = "information_schema.packages" diff --git a/tests/integ/test_packaging.py b/tests/integ/test_packaging.py index e9e99f0dbcf..819baf3482c 100644 --- a/tests/integ/test_packaging.py +++ b/tests/integ/test_packaging.py @@ -283,6 +283,27 @@ def check_if_package_installed() -> bool: Utils.check_answer(session.sql(f"select {udf_name}()").collect(), [Row(True)]) +@pytest.mark.udf +def test_add_packages_with_underscore_and_versions(session): + session.add_packages(["huggingface_hub==0.15.1"]) + assert session.get_packages() == { + "huggingface_hub": "huggingface_hub==0.15.1", + } + session.clear_packages() + + session.add_packages(["huggingface_hub>0.14.1"]) + assert session.get_packages() == { + "huggingface_hub": "huggingface_hub>0.14.1", + } + session.clear_packages() + + session.add_packages(["huggingface_hub<=0.15.1"]) + assert session.get_packages() == { + "huggingface_hub": "huggingface_hub<=0.15.1", + } + session.clear_packages() + + @pytest.mark.skipif( IS_IN_STORED_PROC, reason="Need certain version of datautil/pandas/numpy" )