Skip to content

Commit

Permalink
Move flag to context
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jrose committed Dec 16, 2024
1 parent 2e0dce9 commit ed232de
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 21 deletions.
6 changes: 4 additions & 2 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
get_origin,
)

import snowflake.snowpark.context as context
import snowflake.snowpark.types # type: ignore
from snowflake.connector.constants import FIELD_ID_TO_NAME
from snowflake.connector.cursor import ResultMetadata
Expand Down Expand Up @@ -70,7 +71,6 @@
_FractionalType,
_IntegralType,
_NumericType,
STRUCTURED_TYPES_ENABLED,
)

# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
Expand Down Expand Up @@ -184,7 +184,9 @@ def convert_sf_to_sp_type(
max_string_size: int,
) -> DataType:
"""Convert the Snowflake logical type to the Snowpark type."""
semi_structured_fill = None if STRUCTURED_TYPES_ENABLED else StringType()
semi_structured_fill = (
None if context._should_use_structured_type_semanticselse else StringType()
)
if column_type_name == "ARRAY":
return ArrayType(semi_structured_fill)
if column_type_name == "VARIANT":
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
_should_continue_registration: Optional[Callable[..., bool]] = None


# Global flag that determines if structured type semantics should be used
_should_use_structured_type_semantics = False


def get_active_session() -> "snowflake.snowpark.Session":
"""Returns the current active Snowpark session.
Expand Down
10 changes: 4 additions & 6 deletions src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from enum import Enum
from typing import Generic, List, Optional, Type, TypeVar, Union, Dict, Any

import snowflake.snowpark.context as context
import snowflake.snowpark._internal.analyzer.expression as expression
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto

Expand All @@ -31,9 +32,6 @@
from collections.abc import Iterable


STRUCTURED_TYPES_ENABLED = False


class DataType:
"""The base class of Snowpark data types."""

Expand Down Expand Up @@ -338,7 +336,7 @@ def __init__(
element_type: Optional[DataType] = None,
structured: Optional[bool] = None,
) -> None:
if STRUCTURED_TYPES_ENABLED:
if context._should_use_structured_type_semantics:
self.structured = (
structured if structured is not None else element_type is not None
)
Expand Down Expand Up @@ -390,7 +388,7 @@ def __init__(
value_type: Optional[DataType] = None,
structured: Optional[bool] = None,
) -> None:
if STRUCTURED_TYPES_ENABLED:
if context._should_use_structured_type_semantics:
if (key_type is None and value_type is not None) or (
key_type is not None and value_type is None
):
Expand Down Expand Up @@ -646,7 +644,7 @@ def __init__(
fields: Optional[List["StructField"]] = None,
structured: Optional[bool] = False,
) -> None:
if STRUCTURED_TYPES_ENABLED:
if context._should_use_structured_type_semantics:
self.structured = (
structured if structured is not None else fields is not None
)
Expand Down
6 changes: 3 additions & 3 deletions src/snowflake/snowpark/udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""User-defined aggregate functions (UDAFs) in Snowpark. Refer to :class:`~snowflake.snowpark.udaf.UDAFRegistration` for details and sample code."""

import sys
import warnings
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

Expand Down Expand Up @@ -38,6 +37,7 @@
TempObjectType,
parse_positional_args_to_list,
publicapi,
warning,
)
from snowflake.snowpark.column import Column
from snowflake.snowpark.types import DataType, MapType
Expand Down Expand Up @@ -713,9 +713,9 @@ def _do_register_udaf(

if isinstance(return_type, MapType):
if return_type.structured:
warnings.warn(
warning(
"_do_register_udaf",
"Snowflake does not support structured maps as return type for UDAFs. Downcasting to semi-structured object.",
stacklevel=3,
)
return_type = MapType()

Expand Down
10 changes: 0 additions & 10 deletions tests/integ/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from snowflake.snowpark import Session
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.mock._connection import MockServerConnection
from snowflake.snowpark.types import STRUCTURED_TYPES_ENABLED
from tests.ast.ast_test_utils import (
close_full_ast_validation_mode,
setup_full_ast_validation_mode,
Expand Down Expand Up @@ -245,15 +244,6 @@ def session(
session._cte_optimization_enabled = cte_optimization_enabled
session.ast_enabled = ast_enabled

if STRUCTURED_TYPES_ENABLED:
queries = [
"alter session set ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE=true",
"alter session set IGNORE_CLIENT_VESRION_IN_STRUCTURED_TYPES_RESPONSE=true",
"alter session set FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT=true",
]
for q in queries:
session.sql(q).collect()

if os.getenv("GITHUB_ACTIONS") == "true" and not local_testing_mode:
set_up_external_access_integration_resources(
session, rule1, rule2, key1, key2, integration1, integration2
Expand Down

0 comments on commit ed232de

Please sign in to comment.