Skip to content

Commit

Permalink
SNOW-1418523 make analyzer server connection thread safe (#2282)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam authored Sep 25, 2024
1 parent 96949be commit 5f140ab
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 38 deletions.
34 changes: 23 additions & 11 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import inspect
import os
import sys
import threading
import time
from logging import getLogger
from typing import (
Expand Down Expand Up @@ -154,6 +155,8 @@ def __init__(
options: Dict[str, Union[int, str]],
conn: Optional[SnowflakeConnection] = None,
) -> None:
self._lock = threading.RLock()
self._thread_store = threading.local()
self._lower_case_parameters = {k.lower(): v for k, v in options.items()}
self._add_application_parameters()
self._conn = conn if conn else connect(**self._lower_case_parameters)
Expand All @@ -170,7 +173,6 @@ def __init__(

if "password" in self._lower_case_parameters:
self._lower_case_parameters["password"] = None
self._cursor = self._conn.cursor()
self._telemetry_client = TelemetryClient(self._conn)
self._query_listener: Set[QueryHistory] = set()
# The session in this case refers to a Snowflake session, not a
Expand All @@ -183,6 +185,12 @@ def __init__(
"_skip_upload_on_content_match" in signature.parameters
)

@property
def _cursor(self) -> SnowflakeCursor:
    """Return this thread's cursor, creating it lazily on first access.

    Cursors are stored in ``self._thread_store`` (a ``threading.local``),
    so each thread works with its own cursor object and never shares one
    with another thread.
    """
    cursor = getattr(self._thread_store, "cursor", None)
    if cursor is None:
        cursor = self._conn.cursor()
        self._thread_store.cursor = cursor
    return cursor

def _add_application_parameters(self) -> None:
if PARAM_APPLICATION not in self._lower_case_parameters:
# Mirrored from snowflake-connector-python/src/snowflake/connector/connection.py#L295
Expand Down Expand Up @@ -210,10 +218,12 @@ def _add_application_parameters(self) -> None:
] = get_version()

def add_query_listener(self, listener: QueryHistory) -> None:
    """Register a query listener to be notified of executed queries.

    The mutation of the listener set is guarded by the connection lock so
    listeners can be added safely from multiple threads.
    """
    with self._lock:
        self._query_listener.add(listener)

def remove_query_listener(self, listener: QueryHistory) -> None:
    """Deregister a previously added query listener.

    Raises ``KeyError`` (from ``set.remove``) if the listener was never
    registered. Guarded by the connection lock for thread-safe mutation
    of the listener set.
    """
    with self._lock:
        self._query_listener.remove(listener)

def close(self) -> None:
if self._conn:
Expand Down Expand Up @@ -252,12 +262,13 @@ def _run_new_describe(
) -> Union[List[ResultMetadata], List["ResultMetadataV2"]]:
result_metadata = run_new_describe(cursor, query)

for listener in filter(
lambda listener: hasattr(listener, "include_describe")
and listener.include_describe,
self._query_listener,
):
listener._add_query(QueryRecord(cursor.sfqid, query, True))
with self._lock:
for listener in filter(
lambda listener: hasattr(listener, "include_describe")
and listener.include_describe,
self._query_listener,
):
listener._add_query(QueryRecord(cursor.sfqid, query, True))

return result_metadata

Expand Down Expand Up @@ -374,8 +385,9 @@ def upload_stream(
raise ex

def notify_query_listeners(self, query_record: QueryRecord) -> None:
    """Deliver *query_record* to every registered query listener.

    Iteration happens under the connection lock so the listener set
    cannot be mutated by another thread mid-notification.
    """
    with self._lock:
        for listener in self._query_listener:
            listener._add_query(query_record)

def execute_and_notify_query_listener(
self, query: str, **kwargs: Any
Expand Down
5 changes: 3 additions & 2 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2027,12 +2027,13 @@ def _union_by_name_internal(
]

names = right_project_list + not_found_attrs
if self._session.sql_simplifier_enabled and other._select_statement:
sql_simplifier_enabled = self._session.sql_simplifier_enabled
if sql_simplifier_enabled and other._select_statement:
right_child = self._with_plan(other._select_statement.select(names))
else:
right_child = self._with_plan(Project(names, other._plan))

if self._session.sql_simplifier_enabled:
if sql_simplifier_enabled:
df = self._with_plan(
self._select_statement.set_operator(
right_child._select_statement
Expand Down
62 changes: 37 additions & 25 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,6 @@ def __init__(
)
self._file = FileOperation(self)
self._lineage = Lineage(self)
self._analyzer = (
Analyzer(self) if isinstance(conn, ServerConnection) else MockAnalyzer(self)
)
self._sql_simplifier_enabled: bool = (
self._conn._get_client_side_session_parameter(
_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING, True
Expand Down Expand Up @@ -623,8 +620,19 @@ def __str__(self):
)

def _generate_new_action_id(self) -> int:
    """Atomically increment and return the next action id.

    The lock makes the increment-and-read a single critical section, so
    concurrent callers can never observe the same id twice.
    """
    with self._lock:
        self._last_action_id += 1
        return self._last_action_id

@property
def _analyzer(self) -> Analyzer:
    """Return this thread's analyzer, building one lazily per thread.

    A real ``Analyzer`` is created for server connections; a
    ``MockAnalyzer`` otherwise (local testing). Instances live in
    ``self._thread_store`` so threads never share analyzer state.
    """
    store = self._thread_store
    if not hasattr(store, "analyzer"):
        if isinstance(self._conn, ServerConnection):
            store.analyzer = Analyzer(self)
        else:
            store.analyzer = MockAnalyzer(self)
    return store.analyzer

def close(self) -> None:
"""Close this session."""
Expand Down Expand Up @@ -856,7 +864,8 @@ def cancel_all(self) -> None:
This does not affect any action methods called in the future.
"""
_logger.info("Canceling all running queries")
self._last_canceled_id = self._last_action_id
with self._lock:
self._last_canceled_id = self._last_action_id
if not isinstance(self._conn, MockServerConnection):
self._conn.run_query(
f"select system$cancel_all_queries({self._session_id})"
Expand Down Expand Up @@ -1958,11 +1967,12 @@ def query_tag(self) -> Optional[str]:

@query_tag.setter
def query_tag(self, tag: str) -> None:
    """Set (or unset, when *tag* is falsy) the session-level query tag.

    The ALTER SESSION round-trip and the cached-attribute update run as
    one critical section, so a concurrent setter cannot interleave and
    leave ``self._query_tag`` out of sync with the server-side tag.
    """
    with self._lock:
        if tag:
            self._conn.run_query(f"alter session set query_tag = {str_to_sql(tag)}")
        else:
            self._conn.run_query("alter session unset query_tag")
        self._query_tag = tag

def _get_remote_query_tag(self) -> None:
"""
Expand Down Expand Up @@ -2355,18 +2365,19 @@ def get_session_stage(
Therefore, if you switch database or schema during the session, the stage will not be re-created
in the new database or schema, and still references the stage in the old database or schema.
"""
if not self._session_stage:
full_qualified_stage_name = self.get_fully_qualified_name_if_possible(
random_name_for_temp_object(TempObjectType.STAGE)
)
self._run_query(
f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \
stage if not exists {full_qualified_stage_name}",
is_ddl_on_temp_object=True,
statement_params=statement_params,
)
# set the value after running the query to ensure atomicity
self._session_stage = full_qualified_stage_name
with self._lock:
if not self._session_stage:
full_qualified_stage_name = self.get_fully_qualified_name_if_possible(
random_name_for_temp_object(TempObjectType.STAGE)
)
self._run_query(
f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \
stage if not exists {full_qualified_stage_name}",
is_ddl_on_temp_object=True,
statement_params=statement_params,
)
# set the value after running the query to ensure atomicity
self._session_stage = full_qualified_stage_name
return f"{STAGE_PREFIX}{self._session_stage}"

def _write_modin_pandas_helper(
Expand Down Expand Up @@ -3065,8 +3076,9 @@ def get_fully_qualified_name_if_possible(self, name: str) -> str:
"""
Returns the fully qualified object name if current database/schema exists, otherwise returns the object name
"""
database = self.get_current_database()
schema = self.get_current_schema()
with self._lock:
database = self.get_current_database()
schema = self.get_current_schema()
if database and schema:
return f"{database}.{schema}.{name}"

Expand Down
124 changes: 124 additions & 0 deletions tests/integ/test_multithreading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from concurrent.futures import ThreadPoolExecutor, as_completed
from unittest.mock import patch

import pytest

from snowflake.snowpark.functions import lit
from snowflake.snowpark.row import Row
from tests.utils import IS_IN_STORED_PROC, Utils


def test_concurrent_select_queries(session):
    """Each worker thread issues its own SELECT and validates its own result."""

    def run_select(session_, thread_id):
        result = session_.sql(f"SELECT {thread_id} as A").collect()
        assert result[0][0] == thread_id

    with ThreadPoolExecutor(max_workers=10) as pool:
        for worker_id in range(10):
            pool.submit(run_select, session, worker_id)


def test_concurrent_dataframe_operations(session):
    """Build per-thread derived DataFrames concurrently, then union and verify.

    Fix: ``table_name`` is assigned *before* the ``try`` block. In the
    original it was assigned inside, so a failure before the assignment
    made the ``finally`` clause raise NameError and mask the real error.
    """
    table_name = Utils.random_table_name()
    try:
        data = [(i, 11 * i) for i in range(10)]
        df = session.create_dataframe(data, ["A", "B"])
        df.write.save_as_table(table_name, table_type="temporary")

        def run_dataframe_operation(session_, thread_id):
            # Each thread derives its own single-row frame from the shared table.
            df = session_.table(table_name)
            df = df.filter(df.a == lit(thread_id))
            df = df.with_column("C", df.b + 100 * df.a)
            df = df.rename(df.a, "D").limit(1)
            return df

        dfs = []
        with ThreadPoolExecutor(max_workers=10) as executor:
            df_futures = [
                executor.submit(run_dataframe_operation, session, i) for i in range(10)
            ]

            for future in as_completed(df_futures):
                dfs.append(future.result())

        main_df = dfs[0]
        for df in dfs[1:]:
            main_df = main_df.union(df)

        Utils.check_answer(
            main_df, [Row(D=i, B=11 * i, C=11 * i + 100 * i) for i in range(10)]
        )

    finally:
        Utils.drop_table(session, table_name)


@pytest.mark.xfail(
    "config.getoption('local_testing_mode', default=False)",
    reason="SQL query and query listeners are not supported",
    run=False,
)
def test_query_listener(session):
    """Queries issued from many threads are all captured by one history listener."""

    def run_select(session_, thread_id):
        session_.sql(f"SELECT {thread_id} as A").collect()

    with session.query_history() as history:
        with ThreadPoolExecutor(max_workers=10) as pool:
            for worker_id in range(10):
                pool.submit(run_select, session, worker_id)

    queries_sent = [record.sql_text for record in history.queries]
    assert len(queries_sent) == 10
    for worker_id in range(10):
        assert f"SELECT {worker_id} as A" in queries_sent


@pytest.mark.xfail(
    "config.getoption('local_testing_mode', default=False)",
    reason="Query tag is a SQL feature",
    run=False,
)
@pytest.mark.skipif(
    IS_IN_STORED_PROC, reason="show parameters is not supported in stored procedure"
)
def test_query_tagging(session):
    """Concurrent query_tag writers leave client and server state consistent."""

    def set_query_tag(session_, thread_id):
        session_.query_tag = f"tag_{thread_id}"

    with ThreadPoolExecutor(max_workers=10) as pool:
        for worker_id in range(10):
            pool.submit(set_query_tag, session, worker_id)

    # Whichever thread wrote last, the cached tag must match the server's.
    actual_query_tag = session.sql("SHOW PARAMETERS LIKE 'QUERY_TAG'").collect()[0][1]
    assert actual_query_tag == session.query_tag


@pytest.mark.xfail(
    "config.getoption('local_testing_mode', default=False)",
    reason="SQL query is not supported",
    run=False,
)
def test_session_stage_created_once(session):
    """Ten concurrent get_session_stage calls trigger exactly one CREATE STAGE."""
    with patch.object(
        session._conn, "run_query", wraps=session._conn.run_query
    ) as patched_run_query:
        with ThreadPoolExecutor(max_workers=10) as pool:
            for _ in range(10):
                pool.submit(session.get_session_stage)

    assert patched_run_query.call_count == 1


def test_action_ids_are_unique(session):
    """Action ids generated concurrently from ten threads never collide."""
    with ThreadPoolExecutor(max_workers=10) as pool:
        futures = [pool.submit(session._generate_new_action_id) for _ in range(10)]
        action_ids = {future.result() for future in as_completed(futures)}

    assert len(action_ids) == 10

0 comments on commit 5f140ab

Please sign in to comment.