Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

SNOW-1418523: concurrent file operations #2288

21 changes: 16 additions & 5 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import inspect
import os
import sys
import threading
import time
from logging import getLogger
from typing import (
Expand Down Expand Up @@ -154,6 +155,8 @@ def __init__(
options: Dict[str, Union[int, str]],
conn: Optional[SnowflakeConnection] = None,
) -> None:
self._lock = threading.RLock()
self._thread_stored = threading.local()
self._lower_case_parameters = {k.lower(): v for k, v in options.items()}
self._add_application_parameters()
self._conn = conn if conn else connect(**self._lower_case_parameters)
Expand All @@ -170,7 +173,6 @@ def __init__(

if "password" in self._lower_case_parameters:
self._lower_case_parameters["password"] = None
self._cursor = self._conn.cursor()
self._telemetry_client = TelemetryClient(self._conn)
self._query_listener: Set[QueryHistory] = set()
# The session in this case refers to a Snowflake session, not a
Expand All @@ -183,6 +185,12 @@ def __init__(
"_skip_upload_on_content_match" in signature.parameters
)

@property
def _cursor(self) -> SnowflakeCursor:
    """Return this thread's cursor, lazily opening one on first access.

    Each thread gets its own cursor (stored in ``self._thread_stored``)
    so concurrent operations never share cursor state.
    """
    try:
        # Fast path: a cursor was already created for this thread.
        return self._thread_stored.cursor
    except AttributeError:
        # First access from this thread: open and cache a new cursor.
        new_cursor = self._conn.cursor()
        self._thread_stored.cursor = new_cursor
        return new_cursor

def _add_application_parameters(self) -> None:
if PARAM_APPLICATION not in self._lower_case_parameters:
# Mirrored from snowflake-connector-python/src/snowflake/connector/connection.py#L295
Expand Down Expand Up @@ -210,10 +218,12 @@ def _add_application_parameters(self) -> None:
] = get_version()

def add_query_listener(self, listener: QueryHistory) -> None:
    """Register *listener* to be notified of executed queries (thread-safe)."""
    self._lock.acquire()
    try:
        self._query_listener.add(listener)
    finally:
        self._lock.release()

def remove_query_listener(self, listener: QueryHistory) -> None:
    """Unregister *listener* (thread-safe).

    Raises ``KeyError`` if the listener was never registered, matching
    ``set.remove`` semantics.
    """
    self._lock.acquire()
    try:
        self._query_listener.remove(listener)
    finally:
        self._lock.release()

def close(self) -> None:
if self._conn:
Expand Down Expand Up @@ -360,8 +370,9 @@ def upload_stream(
raise ex

def notify_query_listeners(self, query_record: QueryRecord) -> None:
    """Deliver *query_record* to every registered listener.

    The lock is held for the whole iteration so the listener set cannot
    be mutated by another thread mid-notification.
    """
    self._lock.acquire()
    try:
        for registered in self._query_listener:
            registered._add_query(query_record)
    finally:
        self._lock.release()

def execute_and_notify_query_listener(
self, query: str, **kwargs: Any
Expand Down
5 changes: 3 additions & 2 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2027,12 +2027,13 @@ def _union_by_name_internal(
]

names = right_project_list + not_found_attrs
if self._session.sql_simplifier_enabled and other._select_statement:
sql_simplifier_enabled = self._session.sql_simplifier_enabled
if sql_simplifier_enabled and other._select_statement:
right_child = self._with_plan(other._select_statement.select(names))
else:
right_child = self._with_plan(Project(names, other._plan))

if self._session.sql_simplifier_enabled:
if sql_simplifier_enabled:
df = self._with_plan(
self._select_statement.set_operator(
right_child._select_statement
Expand Down
62 changes: 37 additions & 25 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,6 @@ def __init__(
)
self._file = FileOperation(self)
self._lineage = Lineage(self)
self._analyzer = (
Analyzer(self) if isinstance(conn, ServerConnection) else MockAnalyzer(self)
)
self._sql_simplifier_enabled: bool = (
self._conn._get_client_side_session_parameter(
_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING, True
Expand Down Expand Up @@ -623,8 +620,19 @@ def __str__(self):
)

def _generate_new_action_id(self) -> int:
    """Atomically increment and return the session's action-id counter."""
    self._lock.acquire()
    try:
        next_id = self._last_action_id + 1
        self._last_action_id = next_id
        return next_id
    finally:
        self._lock.release()

@property
def _analyzer(self) -> Analyzer:
    """Return this thread's analyzer, creating it lazily on first access.

    A real ``Analyzer`` is built for server connections; otherwise a
    ``MockAnalyzer`` is used. One instance is cached per thread in
    ``self._thread_store``.
    """
    try:
        # Fast path: this thread already has an analyzer.
        return self._thread_store.analyzer
    except AttributeError:
        if isinstance(self._conn, ServerConnection):
            created = Analyzer(self)
        else:
            created = MockAnalyzer(self)
        self._thread_store.analyzer = created
        return created

def close(self) -> None:
"""Close this session."""
Expand Down Expand Up @@ -856,7 +864,8 @@ def cancel_all(self) -> None:
This does not affect any action methods called in the future.
"""
_logger.info("Canceling all running queries")
self._last_canceled_id = self._last_action_id
with self._lock:
self._last_canceled_id = self._last_action_id
if not isinstance(self._conn, MockServerConnection):
self._conn.run_query(
f"select system$cancel_all_queries({self._session_id})"
Expand Down Expand Up @@ -1958,11 +1967,12 @@ def query_tag(self) -> Optional[str]:

@query_tag.setter
def query_tag(self, tag: str) -> None:
    """Set (or, for a falsy *tag*, unset) the session-level query tag.

    The server-side ALTER SESSION and the cached ``_query_tag`` update
    happen under the lock so concurrent setters cannot interleave.
    """
    with self._lock:
        if tag:
            statement = f"alter session set query_tag = {str_to_sql(tag)}"
        else:
            statement = "alter session unset query_tag"
        self._conn.run_query(statement)
        self._query_tag = tag

def _get_remote_query_tag(self) -> None:
"""
Expand Down Expand Up @@ -2355,18 +2365,19 @@ def get_session_stage(
Therefore, if you switch database or schema during the session, the stage will not be re-created
in the new database or schema, and still references the stage in the old database or schema.
"""
if not self._session_stage:
full_qualified_stage_name = self.get_fully_qualified_name_if_possible(
random_name_for_temp_object(TempObjectType.STAGE)
)
self._run_query(
f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \
stage if not exists {full_qualified_stage_name}",
is_ddl_on_temp_object=True,
statement_params=statement_params,
)
# set the value after running the query to ensure atomicity
self._session_stage = full_qualified_stage_name
with self._lock:
if not self._session_stage:
full_qualified_stage_name = self.get_fully_qualified_name_if_possible(
random_name_for_temp_object(TempObjectType.STAGE)
)
self._run_query(
f"create {get_temp_type_for_object(self._use_scoped_temp_objects, True)} \
stage if not exists {full_qualified_stage_name}",
is_ddl_on_temp_object=True,
statement_params=statement_params,
)
# set the value after running the query to ensure atomicity
self._session_stage = full_qualified_stage_name
return f"{STAGE_PREFIX}{self._session_stage}"

def _write_modin_pandas_helper(
Expand Down Expand Up @@ -3065,8 +3076,9 @@ def get_fully_qualified_name_if_possible(self, name: str) -> str:
"""
Returns the fully qualified object name if current database/schema exists, otherwise returns the object name
"""
database = self.get_current_database()
schema = self.get_current_schema()
with self._lock:
database = self.get_current_database()
schema = self.get_current_schema()
if database and schema:
return f"{database}.{schema}.{name}"

Expand Down
23 changes: 22 additions & 1 deletion tests/integ/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.mock._connection import MockServerConnection
from tests.parameters import CONNECTION_PARAMETERS
from tests.utils import TEST_SCHEMA, Utils, running_on_jenkins, running_on_public_ci
from tests.utils import (
TEST_SCHEMA,
TestFiles,
Utils,
running_on_jenkins,
running_on_public_ci,
)


def print_help() -> None:
Expand Down Expand Up @@ -235,3 +241,18 @@ def temp_schema(connection, session, local_testing_mode) -> None:
)
yield temp_schema_name
cursor.execute(f"DROP SCHEMA IF EXISTS {temp_schema_name}")


@pytest.fixture(scope="module")
def temp_stage(session, resources_path, local_testing_mode):
    """Module-scoped temporary stage pre-loaded with the parquet test file.

    In local testing mode no server-side stage is created or dropped;
    only the random name is yielded.
    """
    stage_name = Utils.random_stage_name()
    files = TestFiles(resources_path)
    uses_server_stage = not local_testing_mode

    if uses_server_stage:
        Utils.create_stage(session, stage_name, is_temporary=True)
        Utils.upload_to_stage(
            session, stage_name, files.test_file_parquet, compress=False
        )
    yield stage_name
    if uses_server_stage:
        Utils.drop_stage(session, stage_name)
17 changes: 1 addition & 16 deletions tests/integ/scala/test_file_operation_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SnowparkSQLException,
SnowparkUploadFileException,
)
from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, TestFiles, Utils
from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, Utils


def random_alphanumeric_name():
Expand Down Expand Up @@ -74,21 +74,6 @@ def path4(temp_source_directory):
yield filename


@pytest.fixture(scope="module")
def temp_stage(session, resources_path, local_testing_mode):
    """Module-scoped temporary stage pre-loaded with the parquet test file.

    Skips stage creation/teardown entirely in local testing mode and
    only yields the generated stage name.
    """
    tmp_stage_name = Utils.random_stage_name()
    test_files = TestFiles(resources_path)

    if not local_testing_mode:
        Utils.create_stage(session, tmp_stage_name, is_temporary=True)
        Utils.upload_to_stage(
            session, tmp_stage_name, test_files.test_file_parquet, compress=False
        )
    yield tmp_stage_name
    # Teardown runs after the module's tests finish.
    if not local_testing_mode:
        Utils.drop_stage(session, tmp_stage_name)


def test_put_with_one_file(
session, temp_stage, path1, path2, path3, local_testing_mode
):
Expand Down
Loading
Loading