Skip to content

Commit

Permalink
make feature flag thread safe
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jrose committed Dec 17, 2024
1 parent b32806f commit c3db223
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def convert_metadata_to_sp_type(
[
StructField(
field.name
if context._should_use_structured_type_semantics
if context._should_use_structured_type_semantics()
else quote_name(field.name, keep_case=True),
convert_metadata_to_sp_type(field, max_string_size),
nullable=field.is_nullable,
Expand Down Expand Up @@ -188,7 +188,7 @@ def convert_sf_to_sp_type(
) -> DataType:
"""Convert the Snowflake logical type to the Snowpark type."""
semi_structured_fill = (
None if context._should_use_structured_type_semantics else StringType()
None if context._should_use_structured_type_semantics() else StringType()
)
if column_type_name == "ARRAY":
return ArrayType(semi_structured_fill)
Expand Down
16 changes: 14 additions & 2 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Callable, Optional

import snowflake.snowpark
import threading

_use_scoped_temp_objects = True

Expand All @@ -21,8 +22,19 @@
_should_continue_registration: Optional[Callable[..., bool]] = None


# Global flag that determines if structured type semantics should be used
_should_use_structured_type_semantics = False
# Internal-only global flag that determines if structured type semantics should be used
_use_structured_type_semantics = False
# Guards reads and writes of the flag above. It is created eagerly at import
# time: creating it lazily inside the accessor is itself a race, because two
# threads can both observe ``None`` and end up synchronizing on different locks.
_use_structured_type_semantics_lock = threading.RLock()


def _should_use_structured_type_semantics():
    """Return the current value of the structured-type-semantics flag.

    The flag is read while holding ``_use_structured_type_semantics_lock``;
    code that changes ``_use_structured_type_semantics`` should hold the same
    lock while writing.

    Returns:
        The current value of ``_use_structured_type_semantics``.
    """
    # Only a read is performed here, so no ``global`` declaration is needed.
    with _use_structured_type_semantics_lock:
        return _use_structured_type_semantics


def get_active_session() -> "snowflake.snowpark.Session":
Expand Down
16 changes: 8 additions & 8 deletions src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def __init__(
element_type: Optional[DataType] = None,
structured: Optional[bool] = None,
) -> None:
if context._should_use_structured_type_semantics:
if context._should_use_structured_type_semantics():
self.structured = (
structured if structured is not None else element_type is not None
)
Expand All @@ -349,7 +349,7 @@ def __repr__(self) -> str:
return f"ArrayType({repr(self.element_type) if self.element_type else ''})"

def _as_nested(self) -> "ArrayType":
if not context._should_use_structured_type_semantics:
if not context._should_use_structured_type_semantics():
return self
element_type = self.element_type
if isinstance(element_type, (ArrayType, MapType, StructType)):
Expand Down Expand Up @@ -396,7 +396,7 @@ def __init__(
value_type: Optional[DataType] = None,
structured: Optional[bool] = None,
) -> None:
if context._should_use_structured_type_semantics:
if context._should_use_structured_type_semantics():
if (key_type is None and value_type is not None) or (
key_type is not None and value_type is None
):
Expand All @@ -423,7 +423,7 @@ def is_primitive(self):
return False

def _as_nested(self) -> "MapType":
if not context._should_use_structured_type_semantics:
if not context._should_use_structured_type_semantics():
return self
value_type = self.value_type
if isinstance(value_type, (ArrayType, MapType, StructType)):
Expand Down Expand Up @@ -600,7 +600,7 @@ def __init__(

@property
def name(self) -> str:
if self._is_column or not context._should_use_structured_type_semantics:
if self._is_column or not context._should_use_structured_type_semantics():
return self.column_identifier.name
else:
return self._name
Expand All @@ -615,7 +615,7 @@ def name(self, n: Union[ColumnIdentifier, str]) -> None:
self.column_identifier = ColumnIdentifier(n)

def _as_nested(self) -> "StructField":
if not context._should_use_structured_type_semantics:
if not context._should_use_structured_type_semantics():
return self
datatype = self.datatype
if isinstance(datatype, (ArrayType, MapType, StructType)):
Expand Down Expand Up @@ -677,7 +677,7 @@ def __init__(
fields: Optional[List["StructField"]] = None,
structured: Optional[bool] = False,
) -> None:
if context._should_use_structured_type_semantics:
if context._should_use_structured_type_semantics():
self.structured = (
structured if structured is not None else fields is not None
)
Expand Down Expand Up @@ -713,7 +713,7 @@ def add(
return self

def _as_nested(self) -> "StructType":
if not context._should_use_structured_type_semantics:
if not context._should_use_structured_type_semantics():
return self
return StructType(
[field._as_nested() for field in self.fields], self.structured
Expand Down
11 changes: 6 additions & 5 deletions tests/integ/scala/test_datatype_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,11 @@ def examples(structured_type_support):
def structured_type_session(session, structured_type_support):
    """Yield the session to run structured-type tests against.

    When ``structured_type_support`` is truthy, the session is wrapped with
    ``structured_types_enabled_session`` and the internal structured-type
    semantics flag in ``context`` is switched on while the session is in use;
    the previous flag value is restored afterwards. Otherwise the plain
    ``session`` is yielded unchanged.
    """
    if structured_type_support:
        with structured_types_enabled_session(session) as sess:
            # Reading the flag through the accessor also guarantees that the
            # lock object exists before it is used below.
            semantics_enabled = context._should_use_structured_type_semantics()
            # The lock is an RLock instance, not a factory: use it directly as
            # a context manager (calling it raises TypeError).
            with context._use_structured_type_semantics_lock:
                context._use_structured_type_semantics = True
            # The lock is deliberately not held across the yield, so other
            # threads can still read the flag while the tests run.
            yield sess
            with context._use_structured_type_semantics_lock:
                context._use_structured_type_semantics = semantics_enabled
    else:
        yield session

Expand Down Expand Up @@ -399,7 +400,7 @@ def test_structured_dtypes_select(
):
query, expected_dtypes, expected_schema = examples
df = _create_test_dataframe(structured_type_session, structured_type_support)
nested_field_name = "b" if context._should_use_structured_type_semantics else "B"
nested_field_name = "b" if context._should_use_structured_type_semantics() else "B"
flattened_df = df.select(
df.map["k1"].alias("value1"),
df.obj["A"].alias("a"),
Expand Down

0 comments on commit c3db223

Please sign in to comment.