From 0b7982562caadeb5aedfccd974e7d30b01232126 Mon Sep 17 00:00:00 2001 From: Jamison Date: Fri, 15 Nov 2024 13:43:42 -0800 Subject: [PATCH 1/7] SNOW-1803811: Allow mixed-case field names for struct type columns --- CHANGELOG.md | 1 + .../snowpark/_internal/type_utils.py | 2 +- src/snowflake/snowpark/types.py | 5 ++++ tests/integ/scala/test_datatype_suite.py | 26 +++++++++---------- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14b9f0072aa..50a6d2415e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,7 @@ - Added `save` method to `DataFrameWriter` to work in conjunction with `format`. - Added support to read keyword arguments to `options` method for `DataFrameReader` and `DataFrameWriter`. - Relaxed the cloudpickle dependency for Python 3.11 to simplify build requirements. However, for Python 3.11, `cloudpickle==2.2.1` remains the only supported version. +- Added support for mixed case field names in struct type columns. #### Bug Fixes diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index 55fe27c9f8f..3b3f51bb56b 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -292,7 +292,7 @@ def convert_sp_to_sf_type(datatype: DataType) -> str: if isinstance(datatype, StructType): if datatype.structured: fields = ", ".join( - f"{field.name} {convert_sp_to_sf_type(field.datatype)}" + f"{field.raw_name} {convert_sp_to_sf_type(field.datatype)}" for field in datatype.fields ) return f"OBJECT({fields})" diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index 06bcc8969b5..0ee9724ee53 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -482,6 +482,7 @@ class ColumnIdentifier: """Represents a column identifier.""" def __init__(self, normalized_name: str) -> None: + self.raw_name = normalized_name self.normalized_name = quote_name(normalized_name) self._original_name = normalized_name @@ -566,6 +567,10 @@ def name(self) -> str: """Returns the column name.""" return self.column_identifier.name + @property + def raw_name(self) -> str: + return self.column_identifier.raw_name + @name.setter def name(self, n: str) -> None: self.column_identifier = ColumnIdentifier(n) diff --git a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py index 63c24b7e35c..c43ea8a488b 100644 --- a/tests/integ/scala/test_datatype_suite.py +++ b/tests/integ/scala/test_datatype_suite.py @@ -60,7 +60,7 @@ _STRUCTURE_DATAFRAME_QUERY = """ select object_construct('k1', 1) :: map(varchar, int) as map, - object_construct('A', 'foo', 'B', 0.05) :: object(A varchar, B float) as obj, + object_construct('A', 'foo', 'b', 0.05) :: object(A varchar, b float) as obj, [1.0, 3.1, 4.5] :: array(float) as arr """ @@ -71,10 +71,10 @@ def _create_test_dataframe(s): object_construct(lit("k1"), lit(1)) .cast(MapType(StringType(), IntegerType(), structured=True)) .alias("map"), - object_construct(lit("A"), lit("foo"), lit("B"), lit(0.05)) + object_construct(lit("A"), lit("foo"), lit("b"), lit(0.05)) .cast( StructType( - [StructField("A", StringType()), StructField("B", DoubleType())], + [StructField("A", StringType()), StructField("b", DoubleType())], structured=True, ) ) @@ -106,7 +106,7 @@ def _create_test_dataframe(s): StructType( [ StructField("A", StringType(16777216), nullable=True), - StructField("B", DoubleType(), nullable=True), + StructField('"b"', DoubleType(), nullable=True), ], structured=True, ), @@ -386,7 +386,7 @@ def test_structured_dtypes_select(structured_type_session, examples): flattened_df = df.select( df.map["k1"].alias("value1"), df.obj["A"].alias("a"), - col("obj")["B"].alias("b"), + col("obj")["b"].alias("b"), df.arr[0].alias("value2"), df.arr[1].alias("value3"), col("arr")[2].alias("value4"), @@ -395,7 +395,7 @@ def test_structured_dtypes_select(structured_type_session, examples): [ StructField("VALUE1", LongType(), nullable=True), StructField("A", StringType(16777216), nullable=True), - StructField("B", DoubleType(), nullable=True), + StructField("b", DoubleType(), nullable=True), StructField("VALUE2", DoubleType(), nullable=True), StructField("VALUE3", DoubleType(), nullable=True), StructField("VALUE4", DoubleType(), nullable=True), @@ -424,12 +424,12 @@ def test_structured_dtypes_pandas(structured_type_session, structured_type_suppo if structured_type_support: assert ( pdf.to_json() - == '{"MAP":{"0":[["k1",1.0]]},"OBJ":{"0":{"A":"foo","B":0.05}},"ARR":{"0":[1.0,3.1,4.5]}}' + == '{"MAP":{"0":[["k1",1.0]]},"OBJ":{"0":{"A":"foo","b":0.05}},"ARR":{"0":[1.0,3.1,4.5]}}' ) else: assert ( pdf.to_json() - == '{"MAP":{"0":"{\\n \\"k1\\": 1\\n}"},"OBJ":{"0":"{\\n \\"A\\": \\"foo\\",\\n \\"B\\": 5.000000000000000e-02\\n}"},"ARR":{"0":"[\\n 1.000000000000000e+00,\\n 3.100000000000000e+00,\\n 4.500000000000000e+00\\n]"}}' + == '{"MAP":{"0":"{\\n \\"k1\\": 1\\n}"},"OBJ":{"0":"{\\n \\"A\\": \\"foo\\",\\n \\"b\\": 5.000000000000000e-02\\n}"},"ARR":{"0":"[\\n 1.000000000000000e+00,\\n 3.100000000000000e+00,\\n 4.500000000000000e+00\\n]"}}' ) @@ -467,7 +467,7 @@ def test_structured_dtypes_iceberg( ) assert save_ddl[0][0] == ( f"create or replace ICEBERG TABLE {table_name.upper()} (\n\t" - "MAP MAP(STRING, LONG),\n\tOBJ OBJECT(A STRING, B DOUBLE),\n\tARR ARRAY(DOUBLE)\n)\n " + "MAP MAP(STRING, LONG),\n\tOBJ OBJECT(A STRING, b DOUBLE),\n\tARR ARRAY(DOUBLE)\n)\n " "EXTERNAL_VOLUME = 'PYTHON_CONNECTOR_ICEBERG_EXVOL'\n CATALOG = 'SNOWFLAKE'\n " "BASE_LOCATION = 'python_connector_merge_gate/';" ) @@ -733,8 +733,8 @@ def test_structured_dtypes_iceberg_create_from_values( _, __, expected_schema = STRUCTURED_TYPES_EXAMPLES[True] table_name = f"snowpark_structured_dtypes_{uuid.uuid4().hex[:5]}" data = [ - ({"x": 1}, {"A": "a", "B": 1}, [1, 1, 1]), - ({"x": 2}, {"A": "b", "B": 2}, [2, 2, 2]), + ({"x": 1}, {"A": "a", "b": 1}, [1, 1, 1]), + ({"x": 2}, {"A": "b", "b": 2}, [2, 2, 2]), ] try: create_df = structured_type_session.create_dataframe( @@ -945,8 +945,8 @@ def test_structured_type_print_schema( " | |-- key: StringType()\n" " | |-- value: ArrayType\n" " | | |-- element: StructType\n" - ' | | | |-- "FIELD1": StringType() (nullable = True)\n' - ' | | | |-- "FIELD2": LongType() (nullable = True)\n' + ' | | | |-- "Field1": StringType() (nullable = True)\n' + ' | | | |-- "Field2": LongType() (nullable = True)\n' ) # Test that depth works as expected From f05c058540785cf684598dbd66f0e2228b553024 Mon Sep 17 00:00:00 2001 From: Jamison Date: Tue, 3 Dec 2024 13:09:41 -0800 Subject: [PATCH 2/7] Refactor based on feedback --- .../snowpark/_internal/type_utils.py | 6 +- src/snowflake/snowpark/column.py | 6 ++ src/snowflake/snowpark/types.py | 75 +++++++++++++------ tests/integ/scala/test_datatype_suite.py | 16 ++-- 4 files changed, 69 insertions(+), 34 deletions(-) diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index 3b3f51bb56b..49dbe3bb97f 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -34,7 +34,6 @@ from snowflake.connector.constants import FIELD_ID_TO_NAME from snowflake.connector.cursor import ResultMetadata from snowflake.connector.options import installed_pandas, pandas -from snowflake.snowpark._internal.utils import quote_name from snowflake.snowpark.types import ( LTZ, NTZ, @@ -157,9 +156,10 @@ def convert_metadata_to_sp_type( return StructType( [ StructField( - quote_name(field.name, keep_case=True), + field.name, convert_metadata_to_sp_type(field, max_string_size), nullable=field.is_nullable, + is_column=False, ) for field in metadata.fields ], @@ -292,7 +292,7 @@ def convert_sp_to_sf_type(datatype: DataType) -> str: if isinstance(datatype, StructType): if datatype.structured: fields = ", ".join( - f"{field.raw_name} {convert_sp_to_sf_type(field.datatype)}" + f"{field.name} {convert_sp_to_sf_type(field.datatype)}" for field in datatype.fields ) return f"OBJECT({fields})" diff --git a/src/snowflake/snowpark/column.py b/src/snowflake/snowpark/column.py index aacb647bbad..4f730af5a06 100644 --- a/src/snowflake/snowpark/column.py +++ b/src/snowflake/snowpark/column.py @@ -90,6 +90,9 @@ StringType, TimestampTimeZone, TimestampType, + ArrayType, + MapType, + StructType, ) from snowflake.snowpark.window import Window, WindowSpec @@ -916,6 +919,9 @@ def _cast( if isinstance(to, str): to = type_string_to_type_object(to) + if isinstance(to, (ArrayType, MapType, StructType)): + to = to._as_nested() + if self._ast is None: _emit_ast = False diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index 0ee9724ee53..6cb0b90b116 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -341,6 +341,12 @@ def __init__( def __repr__(self) -> str: return f"ArrayType({repr(self.element_type) if self.element_type else ''})" + def _as_nested(self) -> "ArrayType": + element_type = self.element_type + if isinstance(element_type, (ArrayType, MapType, StructType)): + element_type = element_type._as_nested() + return ArrayType(element_type, self.structured) + def is_primitive(self): return False @@ -391,6 +397,12 @@ def __repr__(self) -> str: def is_primitive(self): return False + def _as_nested(self) -> "MapType": + value_type = self.value_type + if isinstance(value_type, (ArrayType, MapType, StructType)): + value_type = value_type._as_nested() + return MapType(self.key_type, value_type, self.structured) + @classmethod def from_json(cls, json_dict: Dict[str, Any]) -> "MapType": return MapType( @@ -482,7 +494,6 @@ class ColumnIdentifier: """Represents a column identifier.""" def __init__(self, normalized_name: str) -> None: - self.raw_name = normalized_name self.normalized_name = quote_name(normalized_name) self._original_name = normalized_name @@ -553,33 +564,41 @@ def __init__( column_identifier: Union[ColumnIdentifier, str], datatype: DataType, nullable: bool = True, + is_column: bool = True, ) -> None: - self.column_identifier = ( - ColumnIdentifier(column_identifier) - if isinstance(column_identifier, str) - else column_identifier - ) + self.name = column_identifier + self.is_column = is_column self.datatype = datatype self.nullable = nullable @property def name(self) -> str: - """Returns the column name.""" - return self.column_identifier.name - - @property - def raw_name(self) -> str: - return self.column_identifier.raw_name + return self.column_identifier.name if self.is_column else self._name @name.setter - def name(self, n: str) -> None: - self.column_identifier = ColumnIdentifier(n) + def name(self, n: Union[ColumnIdentifier, str]) -> None: + if isinstance(n, ColumnIdentifier): + self._name = n.name + self.column_identifier = n + else: + self._name = n + self.column_identifier = ColumnIdentifier(n) + + def _as_nested(self) -> "StructField": + datatype = self.datatype + if isinstance(datatype, (ArrayType, MapType, StructType)): + datatype = datatype._as_nested() + # Nested StructFields do not follow column naming conventions + return StructField(self._name, datatype, self.nullable, is_column=False) def __repr__(self) -> str: return f"StructField({self.name!r}, {repr(self.datatype)}, nullable={self.nullable})" def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + return isinstance(other, self.__class__) and ( + (self.name, self.is_column, self.datatype, self.nullable) + == (other.name, other.is_column, other.datatype, other.nullable) + ) @classmethod def from_json(cls, json_dict: Dict[str, Any]) -> "StructField": @@ -625,9 +644,9 @@ def __init__( self, fields: Optional[List["StructField"]] = None, structured=False ) -> None: self.structured = structured - if fields is None: - fields = [] - self.fields = fields + self.fields = [] + for field in fields: + self.add(field) def add( self, @@ -635,20 +654,30 @@ def add( datatype: Optional[DataType] = None, nullable: Optional[bool] = True, ) -> "StructType": - if isinstance(field, StructField): - self.fields.append(field) - elif isinstance(field, (str, ColumnIdentifier)): + if isinstance(field, (str, ColumnIdentifier)): if datatype is None: raise ValueError( "When field argument is str or ColumnIdentifier, datatype must not be None." ) - self.fields.append(StructField(field, datatype, nullable)) - else: + field = StructField(field, datatype, nullable) + elif not isinstance(field, StructField): + __import__("pdb").set_trace() raise ValueError( f"field argument must be one of str, ColumnIdentifier or StructField. Got: '{type(field)}'" ) + + # Nested data does not follow the same schema conventions as top level fields. + if isinstance(field.datatype, (ArrayType, MapType, StructType)): + field.datatype = field.datatype._as_nested() + + self.fields.append(field) return self + def _as_nested(self) -> "StructType": + return StructType( + [field._as_nested() for field in self.fields], self.structured + ) + @classmethod def _from_attributes(cls, attributes: list) -> "StructType": return cls([StructField(a.name, a.datatype, a.nullable) for a in attributes]) diff --git a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py index c43ea8a488b..22324bf7bea 100644 --- a/tests/integ/scala/test_datatype_suite.py +++ b/tests/integ/scala/test_datatype_suite.py @@ -106,7 +106,7 @@ def _create_test_dataframe(s): StructType( [ StructField("A", StringType(16777216), nullable=True), - StructField('"b"', DoubleType(), nullable=True), + StructField("b", DoubleType(), nullable=True), ], structured=True, ), @@ -524,27 +524,27 @@ def test_iceberg_nested_fields( "NESTED_DATA", StructType( [ - StructField('"camelCase"', StringType(), nullable=True), - StructField('"snake_case"', StringType(), nullable=True), - StructField('"PascalCase"', StringType(), nullable=True), + StructField("camelCase", StringType(), nullable=True), + StructField("snake_case", StringType(), nullable=True), + StructField("PascalCase", StringType(), nullable=True), StructField( - '"nested_map"', + "nested_map", MapType( StringType(), StructType( [ StructField( - '"inner_camelCase"', + "inner_camelCase", StringType(), nullable=True, ), StructField( - '"inner_snake_case"', + "inner_snake_case", StringType(), nullable=True, ), StructField( - '"inner_PascalCase"', + "inner_PascalCase", StringType(), nullable=True, ), From aa6dffd8eac99c5f919ae3d411cff788e2f9568f Mon Sep 17 00:00:00 2001 From: Jamison Date: Thu, 5 Dec 2024 15:08:28 -0800 Subject: [PATCH 3/7] test fix --- src/snowflake/snowpark/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index 6cb0b90b116..1c63ad54d35 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -645,7 +645,7 @@ def __init__( ) -> None: self.structured = structured self.fields = [] - for field in fields: + for field in self.fields: self.add(field) def add( From 3070b90bee8d9ff9c69ac03d7689127ce0b5ed26 Mon Sep 17 00:00:00 2001 From: Jamison Date: Thu, 5 Dec 2024 18:46:53 -0800 Subject: [PATCH 4/7] additional test fixes --- src/snowflake/snowpark/types.py | 3 +-- tests/integ/test_stored_procedure.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index 1c63ad54d35..f50a644043e 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -645,7 +645,7 @@ def __init__( ) -> None: self.structured = structured self.fields = [] - for field in self.fields: + for field in fields or []: self.add(field) def add( @@ -661,7 +661,6 @@ def add( ) field = StructField(field, datatype, nullable) elif not isinstance(field, StructField): - __import__("pdb").set_trace() raise ValueError( f"field argument must be one of str, ColumnIdentifier or StructField. Got: '{type(field)}'" ) diff --git a/tests/integ/test_stored_procedure.py b/tests/integ/test_stored_procedure.py index 20c63d78642..9345bca0bb8 100644 --- a/tests/integ/test_stored_procedure.py +++ b/tests/integ/test_stored_procedure.py @@ -388,8 +388,8 @@ def test_stored_procedure_with_structured_returns( "OBJ", StructType( [ - StructField('"a"', StringType(16777216), nullable=True), - StructField('"b"', DoubleType(), nullable=True), + StructField("a", StringType(16777216), nullable=True), + StructField("b", DoubleType(), nullable=True), ], structured=True, ), From d7bf6afe06e45981f08864c6738c0b51fb7b4f2f Mon Sep 17 00:00:00 2001 From: Jamison Date: Mon, 16 Dec 2024 10:36:57 -0800 Subject: [PATCH 5/7] add feature gate --- CHANGELOG.md | 2 +- .../snowpark/_internal/type_utils.py | 6 +- src/snowflake/snowpark/context.py | 4 + src/snowflake/snowpark/types.py | 14 ++- tests/integ/scala/test_datatype_suite.py | 116 ++++++++++-------- 5 files changed, 86 insertions(+), 56 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f17094e59f3..51573adafa7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ - Added support for applying Snowpark Python function `snowflake_cortex_sentiment`. - Added support for `DataFrame.map`. - Added support for `DataFrame.from_dict` and `DataFrame.from_records`. +- Added support for mixed case field names in struct type columns. #### Improvements - Improve performance of `DataFrame.map`, `Series.apply` and `Series.map` methods by mapping numpy functions to snowpark functions if possible. @@ -68,7 +69,6 @@ - Added `save` method to `DataFrameWriter` to work in conjunction with `format`. - Added support to read keyword arguments to `options` method for `DataFrameReader` and `DataFrameWriter`. - Relaxed the cloudpickle dependency for Python 3.11 to simplify build requirements. However, for Python 3.11, `cloudpickle==2.2.1` remains the only supported version. -- Added support for mixed case field names in struct type columns. #### Bug Fixes diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index 49dbe3bb97f..09a0fe3035e 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -30,10 +30,12 @@ get_origin, ) +import snowflake.snowpark.context as context import snowflake.snowpark.types # type: ignore from snowflake.connector.constants import FIELD_ID_TO_NAME from snowflake.connector.cursor import ResultMetadata from snowflake.connector.options import installed_pandas, pandas +from snowflake.snowpark._internal.utils import quote_name from snowflake.snowpark.types import ( LTZ, NTZ, @@ -156,7 +158,9 @@ def convert_metadata_to_sp_type( return StructType( [ StructField( - field.name, + field.name + if context._should_use_structured_type_semantics + else quote_name(field.name, keep_case=True), convert_metadata_to_sp_type(field, max_string_size), nullable=field.is_nullable, is_column=False, diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index c8f6888c5bd..8bc86f928a1 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -21,6 +21,10 @@ _should_continue_registration: Optional[Callable[..., bool]] = None +# Global flag that determines if structured type semantics should be used +_should_use_structured_type_semantics = False + + def get_active_session() -> "snowflake.snowpark.Session": """Returns the current active Snowpark session. diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index f50a644043e..ac21f125aa8 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -16,6 +16,7 @@ # Use correct version from here: from snowflake.snowpark._internal.utils import installed_pandas, pandas, quote_name +import snowflake.snowpark.context as context # TODO: connector installed_pandas is broken. If pyarrow is not installed, but pandas is this function returns the wrong answer. # The core issue is that in the connector detection of both pandas/arrow are mixed, which is wrong. @@ -342,6 +343,8 @@ def __repr__(self) -> str: return f"ArrayType({repr(self.element_type) if self.element_type else ''})" def _as_nested(self) -> "ArrayType": + if not context._should_use_structured_type_semantics: + return self element_type = self.element_type if isinstance(element_type, (ArrayType, MapType, StructType)): element_type = element_type._as_nested() @@ -398,6 +401,8 @@ def is_primitive(self): return False def _as_nested(self) -> "MapType": + if not context._should_use_structured_type_semantics: + return self value_type = self.value_type if isinstance(value_type, (ArrayType, MapType, StructType)): value_type = value_type._as_nested() @@ -573,7 +578,10 @@ def __init__( @property def name(self) -> str: - return self.column_identifier.name if self.is_column else self._name + if self.is_column or not context._should_use_structured_type_semantics: + return self.column_identifier.name + else: + return self._name @name.setter def name(self, n: Union[ColumnIdentifier, str]) -> None: @@ -585,6 +593,8 @@ def name(self, n: Union[ColumnIdentifier, str]) -> None: self.column_identifier = ColumnIdentifier(n) def _as_nested(self) -> "StructField": + if not context._should_use_structured_type_semantics: + return self datatype = self.datatype if isinstance(datatype, (ArrayType, MapType, StructType)): datatype = datatype._as_nested() @@ -673,6 +683,8 @@ def add( return self def _as_nested(self) -> "StructType": + if not context._should_use_structured_type_semantics: + return self return StructType( [field._as_nested() for field in self.fields], self.structured ) diff --git a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py index 22324bf7bea..f84f237baf2 100644 --- a/tests/integ/scala/test_datatype_suite.py +++ b/tests/integ/scala/test_datatype_suite.py @@ -9,6 +9,7 @@ import pytest +import snowflake.snowpark.context as context from snowflake.connector.options import installed_pandas from snowflake.snowpark import Row from snowflake.snowpark.exceptions import SnowparkSQLException @@ -86,55 +87,6 @@ def _create_test_dataframe(s): return df -STRUCTURED_TYPES_EXAMPLES = { - True: ( - _STRUCTURE_DATAFRAME_QUERY, - [ - ("MAP", "map"), - ("OBJ", "struct"), - ("ARR", "array"), - ], - StructType( - [ - StructField( - "MAP", - MapType(StringType(16777216), LongType(), structured=True), - nullable=True, - ), - StructField( - "OBJ", - StructType( - [ - StructField("A", StringType(16777216), nullable=True), - StructField("b", DoubleType(), nullable=True), - ], - structured=True, - ), - nullable=True, - ), - StructField( - "ARR", ArrayType(DoubleType(), structured=True), nullable=True - ), - ] - ), - ), - False: ( - _STRUCTURE_DATAFRAME_QUERY, - [ - ("MAP", "map"), - ("OBJ", "map"), - ("ARR", "array"), - ], - StructType( - [ - StructField("MAP", MapType(StringType(), StringType()), nullable=True), - StructField("OBJ", MapType(StringType(), StringType()), nullable=True), - StructField("ARR", ArrayType(StringType()), nullable=True), - ] - ), - ), -} - ICEBERG_CONFIG = { "catalog": "SNOWFLAKE", "external_volume": "python_connector_iceberg_exvol", @@ -142,6 +94,61 @@ def _create_test_dataframe(s): } +def _create_example(structured_types_enabled): + if structured_types_enabled: + return ( + _STRUCTURE_DATAFRAME_QUERY, + [ + ("MAP", "map"), + ("OBJ", "struct"), + ("ARR", "array"), + ], + StructType( + [ + StructField( + "MAP", + MapType(StringType(16777216), LongType(), structured=True), + nullable=True, + ), + StructField( + "OBJ", + StructType( + [ + StructField("A", StringType(16777216), nullable=True), + StructField("b", DoubleType(), nullable=True), + ], + structured=True, + ), + nullable=True, + ), + StructField( + "ARR", ArrayType(DoubleType(), structured=True), nullable=True + ), + ] + ), + ) + else: + return ( + _STRUCTURE_DATAFRAME_QUERY, + [ + ("MAP", "map"), + ("OBJ", "map"), + ("ARR", "array"), + ], + StructType( + [ + StructField( + "MAP", MapType(StringType(), StringType()), nullable=True + ), + StructField( + "OBJ", MapType(StringType(), StringType()), nullable=True + ), + StructField("ARR", ArrayType(StringType()), nullable=True), + ] + ), + ) + + @pytest.fixture(scope="module") def structured_type_support(session, local_testing_mode): yield structured_types_supported(session, local_testing_mode) @@ -149,14 +156,17 @@ def structured_type_support(session, local_testing_mode): @pytest.fixture(scope="module") def examples(structured_type_support): - yield STRUCTURED_TYPES_EXAMPLES[structured_type_support] + yield _create_example(structured_type_support) @pytest.fixture(scope="module") def structured_type_session(session, structured_type_support): if structured_type_support: with structured_types_enabled_session(session) as sess: + semantics_enabled = context._should_use_structured_type_semantics + context._should_use_structured_type_semantics = True yield sess + context._should_use_structured_type_semantics = semantics_enabled else: yield session @@ -445,7 +455,7 @@ def test_structured_dtypes_iceberg( and iceberg_supported(structured_type_session, local_testing_mode) ): pytest.skip("Test requires iceberg support and structured type support.") - query, expected_dtypes, expected_schema = STRUCTURED_TYPES_EXAMPLES[True] + query, expected_dtypes, expected_schema = _create_example(True) table_name = f"snowpark_structured_dtypes_{uuid.uuid4().hex[:5]}".upper() dynamic_table_name = f"snowpark_dynamic_iceberg_{uuid.uuid4().hex[:5]}".upper() @@ -730,7 +740,7 @@ def test_structured_dtypes_iceberg_create_from_values( ): pytest.skip("Test requires iceberg support and structured type support.") - _, __, expected_schema = STRUCTURED_TYPES_EXAMPLES[True] + _, __, expected_schema = _create_example(True) table_name = f"snowpark_structured_dtypes_{uuid.uuid4().hex[:5]}" data = [ ({"x": 1}, {"A": "a", "b": 1}, [1, 1, 1]), @@ -760,7 +770,7 @@ def test_structured_dtypes_iceberg_udf( and iceberg_supported(structured_type_session, local_testing_mode) ): pytest.skip("Test requires iceberg support and structured type support.") - query, expected_dtypes, expected_schema = STRUCTURED_TYPES_EXAMPLES[True] + query, expected_dtypes, expected_schema = _create_example(True) table_name = f"snowpark_structured_dtypes_udf_test{uuid.uuid4().hex[:5]}" From e34010accdd64b9628725634644b84b4cf636a3c Mon Sep 17 00:00:00 2001 From: Jamison Date: Mon, 16 Dec 2024 11:01:57 -0800 Subject: [PATCH 6/7] test fixes --- tests/integ/scala/test_datatype_suite.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py index f84f237baf2..3adcc2dc6b1 100644 --- a/tests/integ/scala/test_datatype_suite.py +++ b/tests/integ/scala/test_datatype_suite.py @@ -401,11 +401,12 @@ def test_structured_dtypes_select(structured_type_session, examples): df.arr[1].alias("value3"), col("arr")[2].alias("value4"), ) + nested_field_name = "b" if context._should_use_structured_type_semantics else "B" assert flattened_df.schema == StructType( [ StructField("VALUE1", LongType(), nullable=True), StructField("A", StringType(16777216), nullable=True), - StructField("b", DoubleType(), nullable=True), + StructField(nested_field_name, DoubleType(), nullable=True), StructField("VALUE2", DoubleType(), nullable=True), StructField("VALUE3", DoubleType(), nullable=True), StructField("VALUE4", DoubleType(), nullable=True), @@ -439,7 +440,7 @@ def test_structured_dtypes_pandas(structured_type_session, structured_type_suppo else: assert ( pdf.to_json() - == '{"MAP":{"0":"{\\n \\"k1\\": 1\\n}"},"OBJ":{"0":"{\\n \\"A\\": \\"foo\\",\\n \\"b\\": 5.000000000000000e-02\\n}"},"ARR":{"0":"[\\n 1.000000000000000e+00,\\n 3.100000000000000e+00,\\n 4.500000000000000e+00\\n]"}}' + == '{"MAP":{"0":"{\\n \\"k1\\": 1\\n}"},"OBJ":{"0":"{\\n \\"A\\": \\"foo\\",\\n \\"B\\": 5.000000000000000e-02\\n}"},"ARR":{"0":"[\\n 1.000000000000000e+00,\\n 3.100000000000000e+00,\\n 4.500000000000000e+00\\n]"}}' ) From 9ab8eda9b452dde23b28208a65839ac022034958 Mon Sep 17 00:00:00 2001 From: Jamison Date: Mon, 16 Dec 2024 14:36:42 -0800 Subject: [PATCH 7/7] review feedback --- .../snowpark/_internal/type_utils.py | 2 +- src/snowflake/snowpark/types.py | 12 +++---- tests/integ/scala/test_datatype_suite.py | 36 +++++++++++-------- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index 09a0fe3035e..3d1095132ab 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -163,7 +163,7 @@ def convert_metadata_to_sp_type( else quote_name(field.name, keep_case=True), convert_metadata_to_sp_type(field, max_string_size), nullable=field.is_nullable, - is_column=False, + _is_column=False, ) for field in metadata.fields ], diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index ac21f125aa8..333fc580f60 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -569,16 +569,16 @@ def __init__( column_identifier: Union[ColumnIdentifier, str], datatype: DataType, nullable: bool = True, - is_column: bool = True, + _is_column: bool = True, ) -> None: self.name = column_identifier - self.is_column = is_column + self._is_column = _is_column self.datatype = datatype self.nullable = nullable @property def name(self) -> str: - if self.is_column or not context._should_use_structured_type_semantics: + if self._is_column or not context._should_use_structured_type_semantics: return self.column_identifier.name else: return self._name @@ -599,15 +599,15 @@ def _as_nested(self) -> "StructField": if isinstance(datatype, (ArrayType, MapType, StructType)): datatype = datatype._as_nested() # Nested StructFields do not follow column naming conventions - return StructField(self._name, datatype, self.nullable, is_column=False) + return StructField(self._name, datatype, self.nullable, _is_column=False) def __repr__(self) -> str: return f"StructField({self.name!r}, {repr(self.datatype)}, nullable={self.nullable})" def __eq__(self, other): return isinstance(other, self.__class__) and ( - (self.name, self.is_column, self.datatype, self.nullable) - == (other.name, other.is_column, other.datatype, other.nullable) + (self.name, self._is_column, self.datatype, self.nullable) + == (other.name, other._is_column, other.datatype, other.nullable) ) @classmethod diff --git a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py index 3adcc2dc6b1..a1bd1d48acd 100644 --- a/tests/integ/scala/test_datatype_suite.py +++ b/tests/integ/scala/test_datatype_suite.py @@ -58,7 +58,7 @@ # make sure dataframe creation is the same as _create_test_dataframe -_STRUCTURE_DATAFRAME_QUERY = """ +_STRUCTURED_DATAFRAME_QUERY = """ select object_construct('k1', 1) :: map(varchar, int) as map, object_construct('A', 'foo', 'b', 0.05) :: object(A varchar, b float) as obj, @@ -66,16 +66,20 @@ """ -# make sure dataframe creation is the same as _STRUCTURE_DATAFRAME_QUERY -def _create_test_dataframe(s): +# make sure dataframe creation is the same as _STRUCTURED_DATAFRAME_QUERY +def _create_test_dataframe(s, structured_type_support): + nested_field_name = "b" if structured_type_support else "B" df = s.create_dataframe([1], schema=["a"]).select( object_construct(lit("k1"), lit(1)) .cast(MapType(StringType(), IntegerType(), structured=True)) .alias("map"), - object_construct(lit("A"), lit("foo"), lit("b"), lit(0.05)) + object_construct(lit("A"), lit("foo"), lit(nested_field_name), lit(0.05)) .cast( StructType( - [StructField("A", StringType()), StructField("b", DoubleType())], + [ + StructField("A", StringType()), + StructField(nested_field_name, DoubleType()), + ], structured=True, ) ) @@ -97,7 +101,7 @@ def _create_test_dataframe(s): def _create_example(structured_types_enabled): if structured_types_enabled: return ( - _STRUCTURE_DATAFRAME_QUERY, + _STRUCTURED_DATAFRAME_QUERY, [ ("MAP", "map"), ("OBJ", "struct"), @@ -129,7 +133,7 @@ def _create_example(structured_types_enabled): ) else: return ( - _STRUCTURE_DATAFRAME_QUERY, + _STRUCTURED_DATAFRAME_QUERY, [ ("MAP", "map"), ("OBJ", "map"), @@ -375,9 +379,9 @@ def test_dtypes(session): "config.getoption('local_testing_mode', default=False)", reason="FEAT: SNOW-1372813 Cast to StructType not supported", ) -def test_structured_dtypes(structured_type_session, examples): +def test_structured_dtypes(structured_type_session, examples, structured_type_support): query, expected_dtypes, expected_schema = examples - df = _create_test_dataframe(structured_type_session) + df = _create_test_dataframe(structured_type_session, structured_type_support) assert df.schema == expected_schema assert df.dtypes == expected_dtypes @@ -390,18 +394,20 @@ def test_structured_dtypes(structured_type_session, examples): "config.getoption('local_testing_mode', default=False)", reason="FEAT: SNOW-1372813 Cast to StructType not supported", ) -def test_structured_dtypes_select(structured_type_session, examples): +def test_structured_dtypes_select( + structured_type_session, examples, structured_type_support +): query, expected_dtypes, expected_schema = examples - df = _create_test_dataframe(structured_type_session) + df = _create_test_dataframe(structured_type_session, structured_type_support) + nested_field_name = "b" if context._should_use_structured_type_semantics else "B" flattened_df = df.select( df.map["k1"].alias("value1"), df.obj["A"].alias("a"), - col("obj")["b"].alias("b"), + col("obj")[nested_field_name].alias("b"), df.arr[0].alias("value2"), df.arr[1].alias("value3"), col("arr")[2].alias("value4"), ) - nested_field_name = "b" if context._should_use_structured_type_semantics else "B" assert flattened_df.schema == StructType( [ StructField("VALUE1", LongType(), nullable=True), @@ -431,7 +437,9 @@ def test_structured_dtypes_select(structured_type_session, examples): reason="FEAT: SNOW-1372813 Cast to StructType not supported", ) def test_structured_dtypes_pandas(structured_type_session, structured_type_support): - pdf = _create_test_dataframe(structured_type_session).to_pandas() + pdf = _create_test_dataframe( + structured_type_session, structured_type_support + ).to_pandas() if structured_type_support: assert ( pdf.to_json()