-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SNOW-1663726 make session config updates thread safe #2302
Changes from 65 commits
56fb566
66003d1
0e58205
e75dde1
68a8c1c
31a5734
b4dadda
b8c6496
f39837e
31a196f
723bdf7
37c0419
8a2d433
a083989
947d384
fd51720
3077853
65c3186
94412cf
638dd09
7ae2c33
1689ebf
5e8a2d2
1c83ef2
a649761
f03d618
5f398d5
3807087
4eef3e9
df3263c
6769c54
af86f67
a737f33
8ca2730
81417a3
e340567
30952bb
03f25b5
6deb402
8e1dfe0
10bfeb4
879940a
5aad2d9
669eb91
a85a144
a79ffb4
4420350
5f1eaa6
9d62017
b58aa8b
db37033
dddd15f
57ee9e8
809a86e
6021ab8
43986f6
0430e92
095b04e
32707f9
3bf678d
3eade1a
1850d5d
1fa6ad2
7c85432
f994842
4621836
e1c68f3
e5b48dd
496e2be
980d3b7
54a6b5d
67609e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,6 +49,13 @@ class PlanCompiler: | |
|
||
def __init__(self, plan: SnowflakePlan) -> None: | ||
self._plan = plan | ||
session = plan.session | ||
self.cte_optimization_enabled = session.cte_optimization_enabled | ||
self.large_query_breakdown_enabled = session.large_query_breakdown_enabled | ||
self.large_query_breakdown_complexity_bounds = ( | ||
session.large_query_breakdown_complexity_bounds | ||
) | ||
self.query_compilation_stage_enabled = session._query_compilation_stage_enabled | ||
|
||
def should_start_query_compilation(self) -> bool: | ||
""" | ||
|
@@ -68,15 +75,13 @@ def should_start_query_compilation(self) -> bool: | |
return ( | ||
not isinstance(current_session._conn, MockServerConnection) | ||
and (self._plan.source_plan is not None) | ||
and current_session._query_compilation_stage_enabled | ||
and ( | ||
current_session.cte_optimization_enabled | ||
or current_session.large_query_breakdown_enabled | ||
) | ||
and self.query_compilation_stage_enabled | ||
and (self.cte_optimization_enabled or self.large_query_breakdown_enabled) | ||
) | ||
|
||
def compile(self) -> Dict[PlanQueryType, List[Query]]: | ||
if self.should_start_query_compilation(): | ||
session = self._plan.session | ||
# preparation for compilation | ||
# 1. make a copy of the original plan | ||
start_time = time.time() | ||
|
@@ -93,7 +98,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: | |
# 3. apply each optimizations if needed | ||
# CTE optimization | ||
cte_start_time = time.time() | ||
if self._plan.session.cte_optimization_enabled: | ||
if self.cte_optimization_enabled: | ||
repeated_subquery_eliminator = RepeatedSubqueryElimination( | ||
logical_plans, query_generator | ||
) | ||
|
@@ -108,9 +113,12 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: | |
plot_plan_if_enabled(plan, f"cte_optimized_plan_{i}") | ||
|
||
# Large query breakdown | ||
if self._plan.session.large_query_breakdown_enabled: | ||
if self.large_query_breakdown_enabled: | ||
large_query_breakdown = LargeQueryBreakdown( | ||
self._plan.session, query_generator, logical_plans | ||
session, | ||
query_generator, | ||
logical_plans, | ||
self.large_query_breakdown_complexity_bounds, | ||
) | ||
logical_plans = large_query_breakdown.apply() | ||
|
||
|
@@ -130,11 +138,10 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: | |
cte_time = cte_end_time - cte_start_time | ||
large_query_breakdown_time = large_query_breakdown_end_time - cte_end_time | ||
total_time = time.time() - start_time | ||
session = self._plan.session | ||
summary_value = { | ||
TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled, | ||
TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled, | ||
CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds, | ||
TelemetryField.CTE_OPTIMIZATION_ENABLED.value: self.cte_optimization_enabled, | ||
TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: self.large_query_breakdown_enabled, | ||
CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: self.large_query_breakdown_complexity_bounds, | ||
CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time, | ||
CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time, | ||
CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time, | ||
|
@@ -151,8 +158,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: | |
return queries | ||
else: | ||
final_plan = self._plan | ||
if self._plan.session.cte_optimization_enabled: | ||
final_plan = final_plan.replace_repeated_subquery_with_cte() | ||
final_plan = final_plan.replace_repeated_subquery_with_cte() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is that intended? i think we still need the condition here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes. |
||
return { | ||
PlanQueryType.QUERIES: final_plan.queries, | ||
PlanQueryType.POST_ACTIONS: final_plan.post_actions, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -339,38 +339,44 @@ def __init__(self, session: "Session", conf: Dict[str, Any]) -> None: | |
"use_constant_subquery_alias": True, | ||
"flatten_select_after_filter_and_orderby": True, | ||
} # For config that's temporary/to be removed soon | ||
self._lock = self._session._lock | ||
for key, val in conf.items(): | ||
if self.is_mutable(key): | ||
self.set(key, val) | ||
|
||
def get(self, key: str, default=None) -> Any: | ||
if hasattr(Session, key): | ||
return getattr(self._session, key) | ||
if hasattr(self._session._conn._conn, key): | ||
return getattr(self._session._conn._conn, key) | ||
return self._conf.get(key, default) | ||
with self._lock: | ||
if hasattr(Session, key): | ||
return getattr(self._session, key) | ||
if hasattr(self._session._conn._conn, key): | ||
return getattr(self._session._conn._conn, key) | ||
return self._conf.get(key, default) | ||
|
||
def is_mutable(self, key: str) -> bool: | ||
if hasattr(Session, key) and isinstance(getattr(Session, key), property): | ||
return getattr(Session, key).fset is not None | ||
if hasattr(SnowflakeConnection, key) and isinstance( | ||
getattr(SnowflakeConnection, key), property | ||
): | ||
return getattr(SnowflakeConnection, key).fset is not None | ||
return key in self._conf | ||
with self._lock: | ||
if hasattr(Session, key) and isinstance( | ||
getattr(Session, key), property | ||
): | ||
return getattr(Session, key).fset is not None | ||
if hasattr(SnowflakeConnection, key) and isinstance( | ||
getattr(SnowflakeConnection, key), property | ||
): | ||
return getattr(SnowflakeConnection, key).fset is not None | ||
return key in self._conf | ||
|
||
def set(self, key: str, value: Any) -> None: | ||
if self.is_mutable(key): | ||
if hasattr(Session, key): | ||
setattr(self._session, key, value) | ||
if hasattr(SnowflakeConnection, key): | ||
setattr(self._session._conn._conn, key, value) | ||
if key in self._conf: | ||
self._conf[key] = value | ||
else: | ||
raise AttributeError( | ||
f'Configuration "{key}" does not exist or is not mutable in runtime' | ||
) | ||
with self._lock: | ||
if self.is_mutable(key): | ||
if hasattr(Session, key): | ||
setattr(self._session, key, value) | ||
if hasattr(SnowflakeConnection, key): | ||
setattr(self._session._conn._conn, key, value) | ||
if key in self._conf: | ||
self._conf[key] = value | ||
else: | ||
raise AttributeError( | ||
f'Configuration "{key}" does not exist or is not mutable in runtime' | ||
) | ||
|
||
class SessionBuilder: | ||
""" | ||
|
@@ -538,11 +544,6 @@ def __init__( | |
self._udtf_registration = UDTFRegistration(self) | ||
self._udaf_registration = UDAFRegistration(self) | ||
|
||
self._plan_builder = ( | ||
SnowflakePlanBuilder(self) | ||
if isinstance(self._conn, ServerConnection) | ||
else MockSnowflakePlanBuilder(self) | ||
) | ||
self._last_action_id = 0 | ||
self._last_canceled_id = 0 | ||
self._use_scoped_temp_objects: bool = ( | ||
|
@@ -637,6 +638,16 @@ def _analyzer(self) -> Analyzer: | |
) | ||
return self._thread_store.analyzer | ||
|
||
@property | ||
def _plan_builder(self): | ||
if not hasattr(self._thread_store, "plan_builder"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh, I didn't notice we have a plan builder here; I thought the plan builder was just part of the analyzer. What is the plan builder used for here, and what is the reason that we moved it to a property? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and the plan builder seems thread local to me There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This plan builder is different from analyzer's plan builder. Only used by
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added it as part of old changes. reverted back to original now since plan builder should be thread-safe with restricted config updates |
||
self._thread_store.plan_builder = ( | ||
SnowflakePlanBuilder(self) | ||
if isinstance(self._conn, ServerConnection) | ||
else MockSnowflakePlanBuilder(self) | ||
) | ||
return self._thread_store.plan_builder | ||
|
||
def close(self) -> None: | ||
"""Close this session.""" | ||
if is_in_stored_procedure(): | ||
|
@@ -770,36 +781,59 @@ def custom_package_usage_config(self) -> Dict: | |
|
||
@sql_simplifier_enabled.setter | ||
def sql_simplifier_enabled(self, value: bool) -> None: | ||
self._conn._telemetry_client.send_sql_simplifier_telemetry( | ||
self._session_id, value | ||
) | ||
try: | ||
self._conn._cursor.execute( | ||
f"alter session set {_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING} = {value}" | ||
if threading.active_count() > 1: | ||
_logger.warning( | ||
"Setting sql_simplifier_enabled is not currently thread-safe. " | ||
"Ignoring the update" | ||
) | ||
except Exception: | ||
pass | ||
self._sql_simplifier_enabled = value | ||
return | ||
|
||
with self._lock: | ||
self._conn._telemetry_client.send_sql_simplifier_telemetry( | ||
self._session_id, value | ||
) | ||
try: | ||
self._conn._cursor.execute( | ||
f"alter session set {_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING} = {value}" | ||
) | ||
except Exception: | ||
pass | ||
self._sql_simplifier_enabled = value | ||
|
||
@cte_optimization_enabled.setter | ||
@experimental_parameter(version="1.15.0") | ||
def cte_optimization_enabled(self, value: bool) -> None: | ||
if value: | ||
self._conn._telemetry_client.send_cte_optimization_telemetry( | ||
self._session_id | ||
if threading.active_count() > 1: | ||
# TODO (SNOW-1541096): Remove the limitation once old cte implementation is removed. | ||
_logger.warning( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think we could simply disallow updating of any config value when there are multiple active threads, to be consistent everywhere There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For example, it actually uses large_query_breakdown when deciding whether the complexity can be merged or not. |
||
"Setting cte_optimization_enabled is not currently thread-safe. Ignoring the update" | ||
) | ||
self._cte_optimization_enabled = value | ||
return | ||
|
||
with self._lock: | ||
if value: | ||
self._conn._telemetry_client.send_cte_optimization_telemetry( | ||
self._session_id | ||
) | ||
self._cte_optimization_enabled = value | ||
|
||
@eliminate_numeric_sql_value_cast_enabled.setter | ||
@experimental_parameter(version="1.20.0") | ||
def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: | ||
"""Set the value for eliminate_numeric_sql_value_cast_enabled""" | ||
if threading.active_count() > 1: | ||
_logger.warning( | ||
"Setting eliminate_numeric_sql_value_cast_enabled is not currently thread-safe. " | ||
"Ignoring the update" | ||
) | ||
return | ||
|
||
if value in [True, False]: | ||
self._conn._telemetry_client.send_eliminate_numeric_sql_value_cast_telemetry( | ||
self._session_id, value | ||
) | ||
self._eliminate_numeric_sql_value_cast_enabled = value | ||
with self._lock: | ||
self._conn._telemetry_client.send_eliminate_numeric_sql_value_cast_telemetry( | ||
self._session_id, value | ||
) | ||
self._eliminate_numeric_sql_value_cast_enabled = value | ||
else: | ||
raise ValueError( | ||
"value for eliminate_numeric_sql_value_cast_enabled must be True or False!" | ||
|
@@ -809,6 +843,13 @@ def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: | |
@experimental_parameter(version="1.21.0") | ||
def auto_clean_up_temp_table_enabled(self, value: bool) -> None: | ||
"""Set the value for auto_clean_up_temp_table_enabled""" | ||
if threading.active_count() > 1: | ||
_logger.warning( | ||
"Setting auto_clean_up_temp_table_enabled is not currently thread-safe. " | ||
"Ignoring the update" | ||
) | ||
return | ||
|
||
if value in [True, False]: | ||
self._conn._telemetry_client.send_auto_clean_up_temp_table_telemetry( | ||
self._session_id, value | ||
|
@@ -827,12 +868,18 @@ def large_query_breakdown_enabled(self, value: bool) -> None: | |
materialize the partitions, and then combine them to execute the query to improve | ||
overall performance. | ||
""" | ||
if threading.active_count() > 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can probably have a small utility function and call that for all configs with the config name as a parameter |
||
_logger.warning( | ||
"Setting large_query_breakdown_enabled is not currently thread-safe. Ignoring the update" | ||
) | ||
return | ||
|
||
if value in [True, False]: | ||
self._conn._telemetry_client.send_large_query_breakdown_telemetry( | ||
self._session_id, value | ||
) | ||
self._large_query_breakdown_enabled = value | ||
with self._lock: | ||
self._conn._telemetry_client.send_large_query_breakdown_telemetry( | ||
self._session_id, value | ||
) | ||
self._large_query_breakdown_enabled = value | ||
else: | ||
raise ValueError( | ||
"value for large_query_breakdown_enabled must be True or False!" | ||
|
@@ -841,6 +888,12 @@ def large_query_breakdown_enabled(self, value: bool) -> None: | |
@large_query_breakdown_complexity_bounds.setter | ||
def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None: | ||
"""Set the lower and upper bounds for the complexity score used in large query breakdown optimization.""" | ||
if threading.active_count() > 1: | ||
_logger.warning( | ||
"Setting large_query_breakdown_complexity_bounds is not currently thread-safe. " | ||
"Ignoring the update" | ||
) | ||
return | ||
|
||
if len(value) != 2: | ||
raise ValueError( | ||
|
@@ -850,16 +903,20 @@ def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> Non | |
raise ValueError( | ||
f"Expecting a tuple of lower and upper bound with the lower bound less than the upper bound. Got (lower, upper) = ({value[0], value[1]})" | ||
) | ||
self._conn._telemetry_client.send_large_query_breakdown_update_complexity_bounds( | ||
self._session_id, value[0], value[1] | ||
) | ||
with self._lock: | ||
self._conn._telemetry_client.send_large_query_breakdown_update_complexity_bounds( | ||
self._session_id, value[0], value[1] | ||
) | ||
|
||
self._large_query_breakdown_complexity_bounds = value | ||
self._large_query_breakdown_complexity_bounds = value | ||
|
||
@custom_package_usage_config.setter | ||
@experimental_parameter(version="1.6.0") | ||
def custom_package_usage_config(self, config: Dict) -> None: | ||
self._custom_package_usage_config = {k.lower(): v for k, v in config.items()} | ||
with self._lock: | ||
self._custom_package_usage_config = { | ||
k.lower(): v for k, v in config.items() | ||
} | ||
|
||
def cancel_all(self) -> None: | ||
""" | ||
|
@@ -1458,7 +1515,8 @@ def _get_dependency_packages( | |
statement_params=statement_params, | ||
) | ||
|
||
custom_package_usage_config = self._custom_package_usage_config.copy() | ||
with self._lock: | ||
custom_package_usage_config = self._custom_package_usage_config.copy() | ||
|
||
unsupported_packages: List[str] = [] | ||
for package, package_info in package_dict.items(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we are restricting the change to config runtime, this is not necessary anymore, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we are only putting a restriction on
cte_optimization_enabled
. Since for the others we still need to take a snapshot, I'm taking a snapshot for the cte param for completeness. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since now we are putting protection on all config variables, i think we don't need to do that anymore