SNOW-1663726 make session config updates thread safe #2302

Merged
Changes from 64 commits

Commits (72)
56fb566
init
sfc-gh-aalam Sep 11, 2024
66003d1
make udf/sproc related files thread-safe
sfc-gh-aalam Sep 11, 2024
0e58205
Merge branch 'main' into aalam-SNOW-1418523-make-udf-sproc-thread-safe
sfc-gh-aalam Sep 11, 2024
e75dde1
init
sfc-gh-aalam Sep 11, 2024
68a8c1c
make query listener thread-safe
sfc-gh-aalam Sep 11, 2024
31a5734
Fix query_tag and last_action_id
sfc-gh-aalam Sep 11, 2024
b4dadda
core updates done
sfc-gh-aalam Sep 11, 2024
b8c6496
Add tests
sfc-gh-aalam Sep 12, 2024
f39837e
Fix local tests
sfc-gh-aalam Sep 12, 2024
31a196f
Merge branch 'main' into aalam-SNOW-1418523-make-analyzer-server_conn…
sfc-gh-aalam Sep 12, 2024
723bdf7
Merge branch 'aalam-SNOW-1418523-make-internal-session-variables-thre…
sfc-gh-aalam Sep 12, 2024
37c0419
add file IO tests
sfc-gh-aalam Sep 12, 2024
8a2d433
Merge branch 'aalam-SNOW-1418523-concurrent-file-operations' into aal…
sfc-gh-aalam Sep 12, 2024
a083989
make session._runtime_version_from_requirement safe
sfc-gh-aalam Sep 13, 2024
947d384
add sp/udf concurrent tests
sfc-gh-aalam Sep 13, 2024
fd51720
fix broken test
sfc-gh-aalam Sep 13, 2024
3077853
add udtf/udaf tests
sfc-gh-aalam Sep 13, 2024
65c3186
fix broken test
sfc-gh-aalam Sep 13, 2024
94412cf
sql_simplifier, cte_optimization, eliminate_numeric, query_compilatio…
sfc-gh-aalam Sep 13, 2024
638dd09
cover more configs
sfc-gh-aalam Sep 17, 2024
7ae2c33
fix SnowflakePlan copy
sfc-gh-aalam Sep 17, 2024
1689ebf
minor update
sfc-gh-aalam Sep 17, 2024
5e8a2d2
add description
sfc-gh-aalam Sep 17, 2024
1c83ef2
use _package_lock to protect Session._packages
sfc-gh-aalam Sep 17, 2024
a649761
undo refactor
sfc-gh-aalam Sep 17, 2024
f03d618
undo refactor
sfc-gh-aalam Sep 17, 2024
5f398d5
fix test
sfc-gh-aalam Sep 17, 2024
3807087
fix test
sfc-gh-aalam Sep 17, 2024
4eef3e9
Merge branch 'aalam-SNOW-1418523-make-internal-session-variables-thre…
sfc-gh-aalam Sep 17, 2024
df3263c
add file IO tests
sfc-gh-aalam Sep 12, 2024
6769c54
merge with base
sfc-gh-aalam Sep 17, 2024
af86f67
merge with base
sfc-gh-aalam Sep 17, 2024
a737f33
fix test
sfc-gh-aalam Sep 17, 2024
8ca2730
protect complexity bounds setter with lock
sfc-gh-aalam Sep 17, 2024
81417a3
add config context
sfc-gh-aalam Sep 19, 2024
e340567
add tests
sfc-gh-aalam Sep 19, 2024
30952bb
update documentation
sfc-gh-aalam Sep 20, 2024
03f25b5
use config context in plan compiler
sfc-gh-aalam Sep 20, 2024
6deb402
add comments
sfc-gh-aalam Sep 20, 2024
8e1dfe0
minor refactor
sfc-gh-aalam Sep 20, 2024
10bfeb4
fix test
sfc-gh-aalam Sep 20, 2024
879940a
update documentation
sfc-gh-aalam Sep 20, 2024
5aad2d9
simplify context config
sfc-gh-aalam Sep 25, 2024
669eb91
merge with base
sfc-gh-aalam Sep 25, 2024
a85a144
add config context to repeated subquery elimination resolution stage
sfc-gh-aalam Sep 25, 2024
a79ffb4
fix tests
sfc-gh-aalam Sep 26, 2024
4420350
refactor
sfc-gh-aalam Sep 26, 2024
5f1eaa6
remove do_analyze
sfc-gh-aalam Sep 27, 2024
9d62017
fix
sfc-gh-aalam Sep 27, 2024
b58aa8b
fix
sfc-gh-aalam Sep 27, 2024
db37033
fix
sfc-gh-aalam Sep 27, 2024
dddd15f
fix unit tests
sfc-gh-aalam Sep 27, 2024
57ee9e8
simplify
sfc-gh-aalam Sep 27, 2024
809a86e
simplify
sfc-gh-aalam Sep 27, 2024
6021ab8
simplify
sfc-gh-aalam Sep 27, 2024
43986f6
simplify
sfc-gh-aalam Sep 27, 2024
0430e92
simplify
sfc-gh-aalam Sep 27, 2024
095b04e
remove config context
sfc-gh-aalam Sep 30, 2024
32707f9
min-diff
sfc-gh-aalam Sep 30, 2024
3bf678d
min-diff
sfc-gh-aalam Sep 30, 2024
3eade1a
min-diff
sfc-gh-aalam Sep 30, 2024
1850d5d
Merge branch 'aalam-SNOW-1418523-make-internal-session-variables-thre…
sfc-gh-aalam Oct 2, 2024
1fa6ad2
add warnings
sfc-gh-aalam Oct 2, 2024
7c85432
Merge branch 'aalam-SNOW-1418523-make-internal-session-variables-thre…
sfc-gh-aalam Oct 3, 2024
f994842
address feedback
sfc-gh-aalam Oct 3, 2024
4621836
address feedback
sfc-gh-aalam Oct 3, 2024
e1c68f3
fix string
sfc-gh-aalam Oct 3, 2024
e5b48dd
ignore on multi-thread
sfc-gh-aalam Oct 3, 2024
496e2be
undo ignore
sfc-gh-aalam Oct 3, 2024
980d3b7
update warning message
sfc-gh-aalam Oct 4, 2024
54a6b5d
address comments
sfc-gh-aalam Oct 4, 2024
67609e8
address comments
sfc-gh-aalam Oct 4, 2024
@@ -113,17 +113,14 @@ def __init__(
session: Session,
query_generator: QueryGenerator,
logical_plans: List[LogicalPlan],
complexity_bounds: Tuple[int, int],
) -> None:
self.session = session
self._query_generator = query_generator
self.logical_plans = logical_plans
self._parent_map = defaultdict(set)
self.complexity_score_lower_bound = (
session.large_query_breakdown_complexity_bounds[0]
)
self.complexity_score_upper_bound = (
session.large_query_breakdown_complexity_bounds[1]
)
self.complexity_score_lower_bound = complexity_bounds[0]
self.complexity_score_upper_bound = complexity_bounds[1]

def apply(self) -> List[LogicalPlan]:
if is_active_transaction(self.session):
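
A minimal sketch of the pattern in this hunk: the complexity bounds are injected into the constructor rather than read from the shared session mid-compilation. The class and values below are illustrative, not Snowpark code.

```python
from typing import Tuple

class BreakdownRule:
    """Sketch: complexity bounds are injected explicitly instead of read from a shared session."""

    def __init__(self, complexity_bounds: Tuple[int, int]) -> None:
        self.lower, self.upper = complexity_bounds

    def should_break_down(self, score: int) -> bool:
        # The decision depends only on values captured at construction time,
        # so a concurrent change to session config cannot affect it mid-run.
        return self.lower < score < self.upper

rule = BreakdownRule((300, 1000))
print([rule.should_break_down(s) for s in (100, 500, 5000)])  # [False, True, False]
```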
34 changes: 20 additions & 14 deletions src/snowflake/snowpark/_internal/compiler/plan_compiler.py
@@ -49,6 +49,13 @@ class PlanCompiler:

def __init__(self, plan: SnowflakePlan) -> None:
self._plan = plan
session = plan.session
self.cte_optimization_enabled = session.cte_optimization_enabled
Collaborator:
If we are restricting config changes at runtime, this is not necessary anymore, right?

Contributor (Author):
We are only putting the restriction on cte_optimization_enabled. Since the other configs still need a snapshot, I'm taking a snapshot of the CTE parameter too, for completeness.

Collaborator:
Since we now put protection on all config variables, I think we don't need to do that anymore.

self.large_query_breakdown_enabled = session.large_query_breakdown_enabled
self.large_query_breakdown_complexity_bounds = (
session.large_query_breakdown_complexity_bounds
)
self.query_compilation_stage_enabled = session._query_compilation_stage_enabled

def should_start_query_compilation(self) -> bool:
"""
@@ -68,15 +75,13 @@ def should_start_query_compilation(self) -> bool:
return (
not isinstance(current_session._conn, MockServerConnection)
and (self._plan.source_plan is not None)
and current_session._query_compilation_stage_enabled
and (
current_session.cte_optimization_enabled
or current_session.large_query_breakdown_enabled
)
and self.query_compilation_stage_enabled
and (self.cte_optimization_enabled or self.large_query_breakdown_enabled)
)

def compile(self) -> Dict[PlanQueryType, List[Query]]:
if self.should_start_query_compilation():
session = self._plan.session
# preparation for compilation
# 1. make a copy of the original plan
start_time = time.time()
@@ -93,7 +98,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
# 3. apply each optimizations if needed
# CTE optimization
cte_start_time = time.time()
if self._plan.session.cte_optimization_enabled:
if self.cte_optimization_enabled:
repeated_subquery_eliminator = RepeatedSubqueryElimination(
logical_plans, query_generator
)
@@ -108,9 +113,12 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
plot_plan_if_enabled(plan, f"cte_optimized_plan_{i}")

# Large query breakdown
if self._plan.session.large_query_breakdown_enabled:
if self.large_query_breakdown_enabled:
large_query_breakdown = LargeQueryBreakdown(
self._plan.session, query_generator, logical_plans
session,
query_generator,
logical_plans,
self.large_query_breakdown_complexity_bounds,
)
logical_plans = large_query_breakdown.apply()

@@ -130,11 +138,10 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
cte_time = cte_end_time - cte_start_time
large_query_breakdown_time = large_query_breakdown_end_time - cte_end_time
total_time = time.time() - start_time
session = self._plan.session
summary_value = {
TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled,
TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled,
CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds,
TelemetryField.CTE_OPTIMIZATION_ENABLED.value: self.cte_optimization_enabled,
TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: self.large_query_breakdown_enabled,
CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: self.large_query_breakdown_complexity_bounds,
CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time,
CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time,
CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time,
@@ -151,8 +158,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
return queries
else:
final_plan = self._plan
if self._plan.session.cte_optimization_enabled:
final_plan = final_plan.replace_repeated_subquery_with_cte()
final_plan = final_plan.replace_repeated_subquery_with_cte()
Collaborator:
Is that intended? I think we still need the condition here.

Contributor (Author):
Yes. replace_repeated_subquery_with_cte() performs this check internally, so the caller no longer needs to.

return {
PlanQueryType.QUERIES: final_plan.queries,
PlanQueryType.POST_ACTIONS: final_plan.post_actions,
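
A sketch of the convention described in the thread above, where the feature-flag check lives inside the method so callers can invoke it unconditionally. This is a simplified illustration, not the actual Snowpark implementation.

```python
class ToyPlan:
    """Sketch of a method that guards on the feature flag internally."""

    def __init__(self, queries, cte_optimization_enabled):
        self.queries = queries
        self.cte_optimization_enabled = cte_optimization_enabled

    def replace_repeated_subquery_with_cte(self):
        # Callers can invoke this unconditionally: it is a no-op when the flag is off.
        if not self.cte_optimization_enabled:
            return self
        deduplicated = list(dict.fromkeys(self.queries))  # placeholder for real CTE rewriting
        return ToyPlan(deduplicated, self.cte_optimization_enabled)

plan = ToyPlan(["q1", "q1", "q2"], cte_optimization_enabled=False)
print(plan.replace_repeated_subquery_with_cte().queries)  # ['q1', 'q1', 'q2'] -- flag is off
```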
136 changes: 81 additions & 55 deletions src/snowflake/snowpark/session.py
@@ -339,38 +339,44 @@ def __init__(self, session: "Session", conf: Dict[str, Any]) -> None:
"use_constant_subquery_alias": True,
"flatten_select_after_filter_and_orderby": True,
} # For config that's temporary/to be removed soon
self._lock = self._session._lock
for key, val in conf.items():
if self.is_mutable(key):
self.set(key, val)

def get(self, key: str, default=None) -> Any:
if hasattr(Session, key):
return getattr(self._session, key)
if hasattr(self._session._conn._conn, key):
return getattr(self._session._conn._conn, key)
return self._conf.get(key, default)
with self._lock:
if hasattr(Session, key):
return getattr(self._session, key)
if hasattr(self._session._conn._conn, key):
return getattr(self._session._conn._conn, key)
return self._conf.get(key, default)

def is_mutable(self, key: str) -> bool:
if hasattr(Session, key) and isinstance(getattr(Session, key), property):
return getattr(Session, key).fset is not None
if hasattr(SnowflakeConnection, key) and isinstance(
getattr(SnowflakeConnection, key), property
):
return getattr(SnowflakeConnection, key).fset is not None
return key in self._conf
with self._lock:
if hasattr(Session, key) and isinstance(
getattr(Session, key), property
):
return getattr(Session, key).fset is not None
if hasattr(SnowflakeConnection, key) and isinstance(
getattr(SnowflakeConnection, key), property
):
return getattr(SnowflakeConnection, key).fset is not None
return key in self._conf

def set(self, key: str, value: Any) -> None:
if self.is_mutable(key):
if hasattr(Session, key):
setattr(self._session, key, value)
if hasattr(SnowflakeConnection, key):
setattr(self._session._conn._conn, key, value)
if key in self._conf:
self._conf[key] = value
else:
raise AttributeError(
f'Configuration "{key}" does not exist or is not mutable in runtime'
)
with self._lock:
if self.is_mutable(key):
if hasattr(Session, key):
setattr(self._session, key, value)
if hasattr(SnowflakeConnection, key):
setattr(self._session._conn._conn, key, value)
if key in self._conf:
self._conf[key] = value
else:
raise AttributeError(
f'Configuration "{key}" does not exist or is not mutable in runtime'
)

class SessionBuilder:
"""
@@ -538,11 +544,6 @@ def __init__(
self._udtf_registration = UDTFRegistration(self)
self._udaf_registration = UDAFRegistration(self)

self._plan_builder = (
SnowflakePlanBuilder(self)
if isinstance(self._conn, ServerConnection)
else MockSnowflakePlanBuilder(self)
)
self._last_action_id = 0
self._last_canceled_id = 0
self._use_scoped_temp_objects: bool = (
@@ -637,6 +638,16 @@ def _analyzer(self) -> Analyzer:
)
return self._thread_store.analyzer

@property
def _plan_builder(self):
if not hasattr(self._thread_store, "plan_builder"):
Collaborator:
Oh, I didn't notice we had a plan builder here; I thought the plan builder was just part of the analyzer. What is this plan builder used for, and why was it moved to a property?

Collaborator:
Also, the plan builder seems thread-local to me.

Contributor (Author):
This plan builder is different from the analyzer's plan builder. It is only used by:

  1. DataFrameReader, to create the read_file plan
  2. FileOperation, to create the file_operation_plan

Contributor (Author):
I added the property as part of older changes and have now reverted to the original, since the plan builder should be thread-safe with the restricted config updates.

self._thread_store.plan_builder = (
SnowflakePlanBuilder(self)
if isinstance(self._conn, ServerConnection)
else MockSnowflakePlanBuilder(self)
)
return self._thread_store.plan_builder

def close(self) -> None:
"""Close this session."""
if is_in_stored_procedure():
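
The `_plan_builder` property above lazily creates one builder per thread via `threading.local()`. A self-contained sketch of that pattern, with a placeholder object standing in for `SnowflakePlanBuilder`:

```python
import threading

class ToySessionWithBuilder:
    """Sketch of a lazily-created, per-thread plan builder property."""

    def __init__(self):
        self._thread_store = threading.local()

    @property
    def _plan_builder(self):
        # threading.local() gives each thread its own attribute namespace, so the
        # builder is created at most once per thread and never shared across threads.
        if not hasattr(self._thread_store, "plan_builder"):
            self._thread_store.plan_builder = object()  # stand-in for SnowflakePlanBuilder(self)
        return self._thread_store.plan_builder

session_like = ToySessionWithBuilder()
builders = []
threads = [
    threading.Thread(target=lambda: builders.append(session_like._plan_builder))
    for _ in range(3)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len({id(b) for b in builders}))  # 3: one builder per thread
```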
@@ -770,36 +781,45 @@ def custom_package_usage_config(self) -> Dict:

@sql_simplifier_enabled.setter
def sql_simplifier_enabled(self, value: bool) -> None:
self._conn._telemetry_client.send_sql_simplifier_telemetry(
self._session_id, value
)
try:
self._conn._cursor.execute(
f"alter session set {_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING} = {value}"
with self._lock:
self._conn._telemetry_client.send_sql_simplifier_telemetry(
self._session_id, value
)
except Exception:
pass
self._sql_simplifier_enabled = value
try:
self._conn._cursor.execute(
f"alter session set {_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING} = {value}"
)
except Exception:
pass
self._sql_simplifier_enabled = value

@cte_optimization_enabled.setter
@experimental_parameter(version="1.15.0")
def cte_optimization_enabled(self, value: bool) -> None:
if value:
self._conn._telemetry_client.send_cte_optimization_telemetry(
self._session_id
if threading.active_count() > 1:
# TODO (SNOW-1541096): Remove the limitation once old cte implementation is removed.
_logger.warning(
Collaborator:
I think we could simply disallow updating any config value while there are multiple active threads, to be consistent everywhere.

Collaborator:
For example, I actually use large_query_breakdown when deciding whether the complexity can be merged or not.

"Setting cte_optimization_enabled is not currently thread-safe. Ignoring the update"
)
self._cte_optimization_enabled = value
return
with self._lock:
if value:
self._conn._telemetry_client.send_cte_optimization_telemetry(
self._session_id
)
self._cte_optimization_enabled = value

@eliminate_numeric_sql_value_cast_enabled.setter
@experimental_parameter(version="1.20.0")
def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None:
"""Set the value for eliminate_numeric_sql_value_cast_enabled"""

if value in [True, False]:
self._conn._telemetry_client.send_eliminate_numeric_sql_value_cast_telemetry(
self._session_id, value
)
self._eliminate_numeric_sql_value_cast_enabled = value
with self._lock:
self._conn._telemetry_client.send_eliminate_numeric_sql_value_cast_telemetry(
self._session_id, value
)
self._eliminate_numeric_sql_value_cast_enabled = value
else:
raise ValueError(
"value for eliminate_numeric_sql_value_cast_enabled must be True or False!"
@@ -829,10 +849,11 @@ def large_query_breakdown_enabled(self, value: bool) -> None:
"""

if value in [True, False]:
self._conn._telemetry_client.send_large_query_breakdown_telemetry(
self._session_id, value
)
self._large_query_breakdown_enabled = value
with self._lock:
self._conn._telemetry_client.send_large_query_breakdown_telemetry(
self._session_id, value
)
self._large_query_breakdown_enabled = value
else:
raise ValueError(
"value for large_query_breakdown_enabled must be True or False!"
@@ -850,16 +871,20 @@ def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None:
raise ValueError(
f"Expecting a tuple of lower and upper bound with the lower bound less than the upper bound. Got (lower, upper) = ({value[0], value[1]})"
)
self._conn._telemetry_client.send_large_query_breakdown_update_complexity_bounds(
self._session_id, value[0], value[1]
)
with self._lock:
self._conn._telemetry_client.send_large_query_breakdown_update_complexity_bounds(
self._session_id, value[0], value[1]
)

self._large_query_breakdown_complexity_bounds = value
self._large_query_breakdown_complexity_bounds = value

@custom_package_usage_config.setter
@experimental_parameter(version="1.6.0")
def custom_package_usage_config(self, config: Dict) -> None:
self._custom_package_usage_config = {k.lower(): v for k, v in config.items()}
with self._lock:
self._custom_package_usage_config = {
k.lower(): v for k, v in config.items()
}

def cancel_all(self) -> None:
"""
@@ -1458,7 +1483,8 @@ def _get_dependency_packages(
statement_params=statement_params,
)

custom_package_usage_config = self._custom_package_usage_config.copy()
with self._lock:
custom_package_usage_config = self._custom_package_usage_config.copy()

unsupported_packages: List[str] = []
for package, package_info in package_dict.items():
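
The copy above is taken while holding the session lock, so the rest of the method works on a consistent snapshot. A standalone sketch of that copy-under-lock pattern (module-level names and keys are illustrative):

```python
import threading

_lock = threading.RLock()
_custom_package_usage_config = {"enabled": True, "force_push": False}

def snapshot_config():
    # Copy under the lock so a concurrent writer cannot mutate the dict while we
    # read it; afterwards we work on the private copy without holding the lock.
    with _lock:
        return dict(_custom_package_usage_config)

local_config = snapshot_config()
local_config["force_push"] = True  # affects only the local copy
print(local_config, _custom_package_usage_config)
```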
18 changes: 18 additions & 0 deletions tests/integ/test_multithreading.py
@@ -3,6 +3,7 @@
#

import hashlib
import logging
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -502,3 +503,20 @@ def finish(self):
with ThreadPoolExecutor(max_workers=10) as executor:
for i in range(10):
executor.submit(register_and_test_udaf, session, i)


def test_concurrent_update_on_cte_optimization_enabled(session, caplog):
def run_cte_optimization(session_, thread_id):
if thread_id % 2 == 0:
session_.cte_optimization_enabled = True
else:
session_.cte_optimization_enabled = False

caplog.clear()
with caplog.at_level(logging.WARNING):
with ThreadPoolExecutor(max_workers=5) as executor:
for i in range(5):
executor.submit(run_cte_optimization, session, i)
assert (
"Setting cte_optimization_enabled is not currently thread-safe" in caplog.text
)
7 changes: 6 additions & 1 deletion tests/unit/compiler/test_large_query_breakdown.py
@@ -105,7 +105,12 @@
],
)
def test_pipeline_breaker_node(mock_session, mock_analyzer, node_generator, expected):
large_query_breakdown = LargeQueryBreakdown(mock_session, mock_analyzer, [])
large_query_breakdown = LargeQueryBreakdown(
mock_session,
mock_analyzer,
[],
mock_session.large_query_breakdown_complexity_bounds,
)
node = node_generator(mock_analyzer)

assert (
1 change: 1 addition & 0 deletions tests/unit/test_server_connection.py
@@ -119,6 +119,7 @@ def test_get_result_set_exception(mock_server_connection):
fake_session._last_canceled_id = 100
fake_session._conn = mock_server_connection
fake_session._cte_optimization_enabled = False
fake_session._query_compilation_stage_enabled = False
fake_plan = SnowflakePlan(
queries=[Query("fake query 1"), Query("fake query 2")],
schema_query="fake schema query",