Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1418523: make udf and sproc registration thread safe #2289

Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
56fb566
init
sfc-gh-aalam Sep 11, 2024
66003d1
make udf/sproc related files thread-safe
sfc-gh-aalam Sep 11, 2024
0e58205
Merge branch 'main' into aalam-SNOW-1418523-make-udf-sproc-thread-safe
sfc-gh-aalam Sep 11, 2024
e75dde1
init
sfc-gh-aalam Sep 11, 2024
68a8c1c
make query listener thread-safe
sfc-gh-aalam Sep 11, 2024
31a5734
Fix query_tag and last_action_id
sfc-gh-aalam Sep 11, 2024
b4dadda
core updates done
sfc-gh-aalam Sep 11, 2024
b8c6496
Add tests
sfc-gh-aalam Sep 12, 2024
f39837e
Fix local tests
sfc-gh-aalam Sep 12, 2024
31a196f
Merge branch 'main' into aalam-SNOW-1418523-make-analyzer-server_conn…
sfc-gh-aalam Sep 12, 2024
723bdf7
Merge branch 'aalam-SNOW-1418523-make-internal-session-variables-thre…
sfc-gh-aalam Sep 12, 2024
37c0419
add file IO tests
sfc-gh-aalam Sep 12, 2024
8a2d433
Merge branch 'aalam-SNOW-1418523-concurrent-file-operations' into aal…
sfc-gh-aalam Sep 12, 2024
a083989
make session._runtime_version_from_requirement safe
sfc-gh-aalam Sep 13, 2024
947d384
add sp/udf concurrent tests
sfc-gh-aalam Sep 13, 2024
fd51720
fix broken test
sfc-gh-aalam Sep 13, 2024
3077853
add udtf/udaf tests
sfc-gh-aalam Sep 13, 2024
65c3186
fix broken test
sfc-gh-aalam Sep 13, 2024
1c83ef2
use _package_lock to protect Session._packages
sfc-gh-aalam Sep 17, 2024
a649761
undo refactor
sfc-gh-aalam Sep 17, 2024
f03d618
undo refactor
sfc-gh-aalam Sep 17, 2024
5f398d5
fix test
sfc-gh-aalam Sep 17, 2024
4eef3e9
Merge branch 'aalam-SNOW-1418523-make-internal-session-variables-thre…
sfc-gh-aalam Sep 17, 2024
df3263c
add file IO tests
sfc-gh-aalam Sep 12, 2024
6769c54
merge with base
sfc-gh-aalam Sep 17, 2024
79af9d7
add suggested test
sfc-gh-aalam Sep 18, 2024
ef71c05
merge with base
sfc-gh-aalam Sep 25, 2024
e553db4
merge with base
sfc-gh-aalam Sep 25, 2024
ee6ca50
Merge branch 'aalam-SNOW-1418523-make-internal-session-variables-thre…
sfc-gh-aalam Sep 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import inspect
import os
import sys
import threading
import time
from logging import getLogger
from typing import (
Expand Down Expand Up @@ -154,6 +155,8 @@ def __init__(
options: Dict[str, Union[int, str]],
conn: Optional[SnowflakeConnection] = None,
) -> None:
self._lock = threading.RLock()
self._thread_stored = threading.local()
self._lower_case_parameters = {k.lower(): v for k, v in options.items()}
self._add_application_parameters()
self._conn = conn if conn else connect(**self._lower_case_parameters)
Expand All @@ -170,7 +173,6 @@ def __init__(

if "password" in self._lower_case_parameters:
self._lower_case_parameters["password"] = None
self._cursor = self._conn.cursor()
self._telemetry_client = TelemetryClient(self._conn)
self._query_listener: Set[QueryHistory] = set()
# The session in this case refers to a Snowflake session, not a
Expand All @@ -183,6 +185,12 @@ def __init__(
"_skip_upload_on_content_match" in signature.parameters
)

@property
def _cursor(self) -> SnowflakeCursor:
if not hasattr(self._thread_stored, "cursor"):
self._thread_stored.cursor = self._conn.cursor()
return self._thread_stored.cursor

def _add_application_parameters(self) -> None:
if PARAM_APPLICATION not in self._lower_case_parameters:
# Mirrored from snowflake-connector-python/src/snowflake/connector/connection.py#L295
Expand Down Expand Up @@ -210,10 +218,12 @@ def _add_application_parameters(self) -> None:
] = get_version()

def add_query_listener(self, listener: QueryHistory) -> None:
self._query_listener.add(listener)
with self._lock:
self._query_listener.add(listener)

def remove_query_listener(self, listener: QueryHistory) -> None:
self._query_listener.remove(listener)
with self._lock:
self._query_listener.remove(listener)

def close(self) -> None:
if self._conn:
Expand Down Expand Up @@ -360,8 +370,9 @@ def upload_stream(
raise ex

def notify_query_listeners(self, query_record: QueryRecord) -> None:
for listener in self._query_listener:
listener._add_query(query_record)
with self._lock:
for listener in self._query_listener:
listener._add_query(query_record)

def execute_and_notify_query_listener(
self, query: str, **kwargs: Any
Expand Down
36 changes: 14 additions & 22 deletions src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,11 @@ def add_snowpark_package_to_sproc_packages(
if packages is None:
if session is None:
packages = [this_package]
elif package_name not in session._packages:
packages = list(session._packages.values()) + [this_package]
else:
with session._lock:
session_packages = session._packages.copy()
if package_name not in session_packages:
packages = list(session_packages.values()) + [this_package]
else:
package_names = [p if isinstance(p, str) else p.__name__ for p in packages]
if not any(p.startswith(package_name) for p in package_names):
Expand Down Expand Up @@ -1073,6 +1076,7 @@ def resolve_imports_and_packages(
)
)

all_urls = []
if session is not None:
import_only_stage = (
unwrap_stage_location_single_quote(stage_location)
Expand All @@ -1086,7 +1090,6 @@ def resolve_imports_and_packages(
else session.get_session_stage(statement_params=statement_params)
)

if session:
if imports:
udf_level_imports = {}
for udf_import in imports:
Expand Down Expand Up @@ -1114,22 +1117,15 @@ def resolve_imports_and_packages(
upload_and_import_stage,
statement_params=statement_params,
)
else:
all_urls = []
else:
all_urls = []

dest_prefix = get_udf_upload_prefix(udf_name)

# Upload closure to stage if it is beyond inline closure size limit
handler = inline_code = upload_file_stage_location = None
custom_python_runtime_version_allowed = False
# As cloudpickle is being used, we cannot allow a custom runtime
custom_python_runtime_version_allowed = not isinstance(func, Callable)
if session is not None:
if isinstance(func, Callable):
custom_python_runtime_version_allowed = (
False # As cloudpickle is being used, we cannot allow a custom runtime
)

# generate a random name for udf py file
# and we compress it first then upload it
udf_file_name_base = f"udf_py_{random_number()}"
Expand Down Expand Up @@ -1174,7 +1170,6 @@ def resolve_imports_and_packages(
upload_file_stage_location = None
handler = _DEFAULT_HANDLER_NAME
else:
custom_python_runtime_version_allowed = True
udf_file_name = os.path.basename(func[0])
# for a compressed file, it might have multiple extensions
# and we should remove all extensions
Expand All @@ -1199,11 +1194,6 @@ def resolve_imports_and_packages(
skip_upload_on_content_match=skip_upload_on_content_match,
)
all_urls.append(upload_file_stage_location)
else:
if isinstance(func, Callable):
custom_python_runtime_version_allowed = False
else:
custom_python_runtime_version_allowed = True

# build imports and packages string
all_imports = ",".join(
Expand Down Expand Up @@ -1246,11 +1236,13 @@ def create_python_udf_or_sp(
statement_params: Optional[Dict[str, str]] = None,
comment: Optional[str] = None,
native_app_params: Optional[Dict[str, Any]] = None,
runtime_version: Optional[str] = None,
) -> None:
if session is not None and session._runtime_version_from_requirement:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change ensures that the same value for runtime_version is read from the session throughout the udf/sp registration process.

runtime_version = session._runtime_version_from_requirement
else:
runtime_version = f"{sys.version_info[0]}.{sys.version_info[1]}"
runtime_version = (
f"{sys.version_info[0]}.{sys.version_info[1]}"
if not runtime_version
else runtime_version
)

if replace and if_not_exists:
raise ValueError("options replace and if_not_exists are incompatible")
Expand Down
5 changes: 3 additions & 2 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2027,12 +2027,13 @@ def _union_by_name_internal(
]

names = right_project_list + not_found_attrs
if self._session.sql_simplifier_enabled and other._select_statement:
sql_simplifier_enabled = self._session.sql_simplifier_enabled
if sql_simplifier_enabled and other._select_statement:
right_child = self._with_plan(other._select_statement.select(names))
else:
right_child = self._with_plan(Project(names, other._plan))

if self._session.sql_simplifier_enabled:
if sql_simplifier_enabled:
df = self._with_plan(
self._select_statement.set_operator(
right_child._select_statement
Expand Down
Loading
Loading