Skip to content

Commit

Permalink
use _package_lock to protect Session._packages
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam committed Sep 17, 2024
1 parent 65c3186 commit 1c83ef2
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ def add_snowpark_package_to_sproc_packages(
if session is None:
packages = [this_package]
else:
with session._lock:
with session._package_lock:
session_packages = session._packages.copy()
if package_name not in session_packages:
packages = list(session_packages.values()) + [this_package]
Expand Down
121 changes: 62 additions & 59 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,11 @@ def __init__(
self._conn = conn
self._thread_store = threading.local()
self._lock = threading.RLock()

# This lock is used to protect _packages. We introduce a new lock because add_packages
# launches a query to Snowflake to get all versions of packages available in Snowflake. This
# query can be slow and, if run under _lock, would keep other threads waiting for _lock from moving on.
self._package_lock = threading.RLock()
self._query_tag = None
self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
self._packages: Dict[str, str] = {}
Expand Down Expand Up @@ -1116,7 +1121,7 @@ def get_packages(self) -> Dict[str, str]:
The key of this ``dict`` is the package name and the value of this ``dict``
is the corresponding requirement specifier.
"""
with self._lock:
with self._package_lock:
return self._packages.copy()

def add_packages(
Expand Down Expand Up @@ -1208,7 +1213,7 @@ def remove_package(self, package: str) -> None:
0
"""
package_name = pkg_resources.Requirement.parse(package).key
with self._lock:
with self._package_lock:
if package_name in self._packages:
self._packages.pop(package_name)
else:
Expand All @@ -1218,7 +1223,7 @@ def clear_packages(self) -> None:
"""
Clears all third-party packages of a user-defined function (UDF).
"""
with self._lock:
with self._package_lock:
self._packages.clear()

def add_requirements(self, file_path: str) -> None:
Expand Down Expand Up @@ -1567,25 +1572,26 @@ def _resolve_packages(
if isinstance(self._conn, MockServerConnection):
# in local testing we don't resolve the packages, we just return what is added
errors = []
with self._lock:
result_dict = self._packages.copy()
for pkg_name, _, pkg_req in package_dict.values():
if pkg_name in result_dict and str(pkg_req) != result_dict[pkg_name]:
errors.append(
ValueError(
f"Cannot add package '{str(pkg_req)}' because {result_dict[pkg_name]} "
"is already added."
with self._package_lock:
result_dict = self._packages
for pkg_name, _, pkg_req in package_dict.values():
if (
pkg_name in result_dict
and str(pkg_req) != result_dict[pkg_name]
):
errors.append(
ValueError(
f"Cannot add package '{str(pkg_req)}' because {result_dict[pkg_name]} "
"is already added."
)
)
)
else:
result_dict[pkg_name] = str(pkg_req)
if len(errors) == 1:
raise errors[0]
elif len(errors) > 0:
raise RuntimeError(errors)

with self._lock:
self._packages.update(result_dict)
else:
result_dict[pkg_name] = str(pkg_req)
if len(errors) == 1:
raise errors[0]
elif len(errors) > 0:
raise RuntimeError(errors)

return list(result_dict.values())

package_table = "information_schema.packages"
Expand All @@ -1600,50 +1606,47 @@ def _resolve_packages(
# 'python-dateutil': 'python-dateutil==2.8.2'}
# Add to packages dictionary. Make a copy of existing packages
# dictionary to avoid modifying it during intermediate steps.
with self._lock:
with self._package_lock:
result_dict = (
existing_packages_dict.copy()
if existing_packages_dict is not None
else {}
existing_packages_dict if existing_packages_dict is not None else {}
)

# Retrieve list of dependencies that need to be added
dependency_packages = self._get_dependency_packages(
package_dict,
validate_package,
package_table,
result_dict,
statement_params=statement_params,
)

# Add dependency packages
for package in dependency_packages:
name = package.name
version = package.specs[0][1] if package.specs else None

if name in result_dict:
if version is not None:
added_package_has_version = "==" in result_dict[name]
if added_package_has_version and result_dict[name] != str(package):
raise ValueError(
f"Cannot add dependency package '{name}=={version}' "
f"because {result_dict[name]} is already added."
)
# Retrieve list of dependencies that need to be added
dependency_packages = self._get_dependency_packages(
package_dict,
validate_package,
package_table,
result_dict,
statement_params=statement_params,
)

# Add dependency packages
for package in dependency_packages:
name = package.name
version = package.specs[0][1] if package.specs else None

if name in result_dict:
if version is not None:
added_package_has_version = "==" in result_dict[name]
if added_package_has_version and result_dict[name] != str(
package
):
raise ValueError(
f"Cannot add dependency package '{name}=={version}' "
f"because {result_dict[name]} is already added."
)
result_dict[name] = str(package)
else:
result_dict[name] = str(package)
else:
result_dict[name] = str(package)

# Always include cloudpickle
extra_modules = [cloudpickle]
if include_pandas:
extra_modules.append("pandas")
# Always include cloudpickle
extra_modules = [cloudpickle]
if include_pandas:
extra_modules.append("pandas")

with self._lock:
if existing_packages_dict is not None:
existing_packages_dict.update(result_dict)
return list(result_dict.values()) + self._get_req_identifiers_list(
extra_modules, result_dict
)
return list(result_dict.values()) + self._get_req_identifiers_list(
extra_modules, result_dict
)

def _upload_unsupported_packages(
self,
Expand Down

0 comments on commit 1c83ef2

Please sign in to comment.