From 9ccac9ea012203ca0d2504874ec836b0a31fb832 Mon Sep 17 00:00:00 2001
From: Jamison Rose
Date: Tue, 30 Apr 2024 14:52:01 -0700
Subject: [PATCH] SNOW-1000284: Add schema support for structured types (#1323)

---
 CHANGELOG.md                                  |   1 +
 recipe/meta.yaml                              |   2 +-
 setup.py                                      |   2 +-
 .../snowpark/_internal/type_utils.py          |  52 +++-
 src/snowflake/snowpark/_internal/udf_utils.py |   2 +-
 src/snowflake/snowpark/types.py               |  20 +-
 tests/integ/scala/test_datatype_suite.py      | 239 +++++++++++++++++-
 tests/integ/test_stored_procedure.py          |  63 +++++
 tests/unit/test_types.py                      |   1 +
 tests/utils.py                                |   8 +
 10 files changed, 377 insertions(+), 13 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 212d4403392..d493d319f8c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@
 ### New Features

 - Support stored procedure register with packages given as Python modules.
+- Added support for structured type schema parsing.

 ### Local Testing Updates

diff --git a/recipe/meta.yaml b/recipe/meta.yaml
index e925aaec5f4..1637d64a1d5 100644
--- a/recipe/meta.yaml
+++ b/recipe/meta.yaml
@@ -26,7 +26,7 @@ requirements:
     - python
     - cloudpickle >=1.6.0,<=2.0.0  # [py<=310]
     - cloudpickle==2.2.1  # [py==311]
-    - snowflake-connector-python
+    - snowflake-connector-python >=3.10.0,<4.0.0
    - typing-extensions >=4.1.0
    # need to pin libffi because of problems in cryptography.
    # This might no longer hold true but keep it just to avoid it from biting us again
diff --git a/setup.py b/setup.py
index c6fca6c4256..59b1a30c2e5 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@
 MODIN_DEPENDENCY_VERSION = (
     "==0.28.1"  # Snowpark pandas requires modin 0.28.1, which depends on pandas 2.2.1
 )
-CONNECTOR_DEPENDENCY_VERSION = ">=3.6.0, <4.0.0"
+CONNECTOR_DEPENDENCY_VERSION = ">=3.10.0, <4.0.0"
 INSTALL_REQ_LIST = [
     "setuptools>=40.6.0",
     "wheel",
diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py
index 026064fbc54..49d65d726f2 100644
--- a/src/snowflake/snowpark/_internal/type_utils.py
+++ b/src/snowflake/snowpark/_internal/type_utils.py
@@ -125,6 +125,37 @@ def convert_metadata_to_sp_type(
         raise ValueError(
             f"Invalid result metadata for vector type: invalid element type: {element_type_name}"
         )
+    elif column_type_name in {"ARRAY", "MAP", "OBJECT"} and getattr(
+        metadata, "fields", None
+    ):
+        # If fields is not defined or is empty, then the legacy type can be returned instead.
+        if column_type_name == "ARRAY":
+            assert (
+                len(metadata.fields) == 1
+            ), "ArrayType columns should have one metadata field."
+            return ArrayType(
+                convert_metadata_to_sp_type(metadata.fields[0]), structured=True
+            )
+        elif column_type_name == "MAP":
+            assert (
+                len(metadata.fields) == 2
+            ), "MapType columns should have two metadata fields."
+            return MapType(
+                convert_metadata_to_sp_type(metadata.fields[0]),
+                convert_metadata_to_sp_type(metadata.fields[1]),
+                structured=True,
+            )
+        else:
+            assert all(
+                getattr(field, "name", None) for field in metadata.fields
+            ), "All fields of a StructType should be named."
+ return StructType( + [ + StructField(field.name, convert_metadata_to_sp_type(field)) + for field in metadata.fields + ], + structured=True, + ) else: return convert_sf_to_sp_type( column_type_name, @@ -142,7 +173,7 @@ def convert_sf_to_sp_type( return ArrayType(StringType()) if column_type_name == "VARIANT": return VariantType() - if column_type_name == "OBJECT": + if column_type_name in {"OBJECT", "MAP"}: return MapType(StringType(), StringType()) if column_type_name == "GEOGRAPHY": return GeographyType() @@ -235,9 +266,24 @@ def convert_sp_to_sf_type(datatype: DataType) -> str: if isinstance(datatype, BinaryType): return "BINARY" if isinstance(datatype, ArrayType): - return "ARRAY" + if datatype.structured: + return f"ARRAY({convert_sp_to_sf_type(datatype.element_type)})" + else: + return "ARRAY" if isinstance(datatype, MapType): - return "OBJECT" + if datatype.structured: + return f"MAP({convert_sp_to_sf_type(datatype.key_type)}, {convert_sp_to_sf_type(datatype.value_type)})" + else: + return "OBJECT" + if isinstance(datatype, StructType): + if datatype.structured: + fields = ", ".join( + f"{field.name.upper()} {convert_sp_to_sf_type(field.datatype)}" + for field in datatype.fields + ) + return f"OBJECT({fields})" + else: + return "OBJECT" if isinstance(datatype, VariantType): return "VARIANT" if isinstance(datatype, GeographyType): diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index e460d728996..4dc04771c7f 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1163,7 +1163,7 @@ def create_python_udf_or_sp( if replace and if_not_exists: raise ValueError("options replace and if_not_exists are incompatible") - if isinstance(return_type, StructType): + if isinstance(return_type, StructType) and not return_type.structured: return_sql = f'RETURNS TABLE ({",".join(f"{field.name} {convert_sp_to_sf_type(field.datatype)}" for field in return_type.fields)})' elif installed_pandas and isinstance(return_type, PandasDataFrameType): return_sql = f'RETURNS TABLE ({",".join(f"{name} {convert_sp_to_sf_type(datatype)}" for name, datatype in zip(return_type.col_names, return_type.col_types))})' diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index 1f47dcbce7c..ec98d8da1f1 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -217,7 +217,10 @@ def __repr__(self) -> str: class ArrayType(DataType): """Array data type. This maps to the ARRAY data type in Snowflake.""" - def __init__(self, element_type: Optional[DataType] = None) -> None: + def __init__( + self, element_type: Optional[DataType] = None, structured: bool = False + ) -> None: + self.structured = structured self.element_type = element_type if element_type else StringType() def __repr__(self) -> str: @@ -228,11 +231,15 @@ def is_primitive(self): class MapType(DataType): - """Map data type. This maps to the OBJECT data type in Snowflake.""" + """Map data type. 
This maps to the OBJECT data type in Snowflake."""
+    """Map data type. This maps to the OBJECT data type in Snowflake if key and value types are not defined, otherwise to the MAP data type."""

     def __init__(
-        self, key_type: Optional[DataType] = None, value_type: Optional[DataType] = None
+        self,
+        key_type: Optional[DataType] = None,
+        value_type: Optional[DataType] = None,
+        structured: bool = False,
     ) -> None:
+        self.structured = structured
         self.key_type = key_type if key_type else StringType()
         self.value_type = value_type if value_type else StringType()

@@ -366,9 +373,12 @@ def __eq__(self, other):


 class StructType(DataType):
-    """Represents a table schema. Contains :class:`StructField` for each column."""
+    """Represents a table schema or structured column. Contains :class:`StructField` for each field."""

-    def __init__(self, fields: Optional[List["StructField"]] = None) -> None:
+    def __init__(
+        self, fields: Optional[List["StructField"]] = None, structured: bool = False
+    ) -> None:
+        self.structured = structured
         if fields is None:
             fields = []
         self.fields = fields
diff --git a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py
index 50b6c79f000..835d7d1609d 100644
--- a/tests/integ/scala/test_datatype_suite.py
+++ b/tests/integ/scala/test_datatype_suite.py
@@ -2,13 +2,16 @@
 # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
 #

+import uuid
+
 # Many of the tests have been moved to unit/scala/test_datattype_suite.py
 from decimal import Decimal

 import pytest

 from snowflake.snowpark import Row
-from snowflake.snowpark.functions import lit
+from snowflake.snowpark.exceptions import SnowparkSQLException
+from snowflake.snowpark.functions import col, lit, udf
 from snowflake.snowpark.types import (
     ArrayType,
     BinaryType,
@@ -33,7 +36,69 @@
     VariantType,
     VectorType,
 )
-from tests.utils import Utils
+from tests.utils import IS_ICEBERG_SUPPORTED, IS_STRUCTURED_TYPES_SUPPORTED, Utils
+
+# Map of structured type enabled state to test params
+STRUCTURED_TYPES_EXAMPLES = {
+    True: pytest.param(
+        """
+        select
+            object_construct('k1', 1) :: map(varchar, int) as map,
+            object_construct('A', 'foo', 'B', 0.05) :: object(A varchar, B float) as obj,
+            [1.0, 3.1, 4.5] :: array(float) as arr
+        """,
+        [
+            ("MAP", "map"),
+            ("OBJ", "struct"),
+            ("ARR", "array"),
+        ],
+        StructType(
+            [
+                StructField(
+                    "MAP",
+                    MapType(StringType(16777216), LongType(), structured=True),
+                    nullable=True,
+                ),
+                StructField(
+                    "OBJ",
+                    StructType(
+                        [
+                            StructField("A", StringType(16777216), nullable=True),
+                            StructField("B", DoubleType(), nullable=True),
+                        ],
+                        structured=True,
+                    ),
+                    nullable=True,
+                ),
+                StructField(
+                    "ARR", ArrayType(DoubleType(), structured=True), nullable=True
+                ),
+            ]
+        ),
+        id="structured-types-enabled",
+    ),
+    False: pytest.param(
+        """
+        select
+            object_construct('k1', 1) :: map(varchar, int) as map,
+            object_construct('A', 'foo', 'B', 0.05) :: object(A varchar, B float) as obj,
+            [1.0, 3.1, 4.5] :: array(float) as arr
+        """,
+        [
+            ("MAP", "map"),
+            ("OBJ", "map"),
+            ("ARR", "array"),
+        ],
+        StructType(
+            [
+                StructField("MAP", MapType(StringType(), StringType()), nullable=True),
+                StructField("OBJ", MapType(StringType(), StringType()), nullable=True),
+                StructField("ARR", ArrayType(StringType()), nullable=True),
+            ]
+        ),
+        id="legacy",
+    ),
+}


 def test_verify_datatypes_reference(session):
@@ -229,6 +294,176 @@ def test_dtypes(session):
     ]

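Editor's note: the tests below are driven by `STRUCTURED_TYPES_EXAMPLES`. As a minimal sketch of the user-visible change they encode (assumes a `session` on an account with structured types enabled; the expected values mirror the parametrized entries above):

```python
from snowflake.snowpark.types import LongType, MapType, StringType, StructField, StructType

# A cast to map(varchar, int) now parses into a typed, structured MapType...
df = session.sql("select object_construct('k1', 1) :: map(varchar, int) as map")
assert df.schema == StructType(
    [
        StructField(
            "MAP",
            MapType(StringType(16777216), LongType(), structured=True),
            nullable=True,
        )
    ]
)
assert df.dtypes == [("MAP", "map")]
# ...whereas a legacy account still reports MapType(StringType(), StringType()).
```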
+@pytest.mark.parametrize(
+    "query,expected_dtypes,expected_schema",
+    [STRUCTURED_TYPES_EXAMPLES[IS_STRUCTURED_TYPES_SUPPORTED]],
+)
+def test_structured_dtypes(session, query, expected_dtypes, expected_schema):
+    df = session.sql(query)
+    assert df.schema == expected_schema
+    assert df.dtypes == expected_dtypes
+
+
+@pytest.mark.parametrize(
+    "query,expected_dtypes,expected_schema",
+    [STRUCTURED_TYPES_EXAMPLES[IS_STRUCTURED_TYPES_SUPPORTED]],
+)
+def test_structured_dtypes_select(session, query, expected_dtypes, expected_schema):
+    df = session.sql(query)
+    flattened_df = df.select(
+        df.map["k1"].alias("value1"),
+        df.obj["A"].alias("a"),
+        col("obj")["B"].alias("b"),
+        df.arr[0].alias("value2"),
+        df.arr[1].alias("value3"),
+        col("arr")[2].alias("value4"),
+    )
+    assert flattened_df.schema == StructType(
+        [
+            StructField("VALUE1", LongType(), nullable=True),
+            StructField("A", StringType(16777216), nullable=True),
+            StructField("B", DoubleType(), nullable=True),
+            StructField("VALUE2", DoubleType(), nullable=True),
+            StructField("VALUE3", DoubleType(), nullable=True),
+            StructField("VALUE4", DoubleType(), nullable=True),
+        ]
+    )
+    assert flattened_df.dtypes == [
+        ("VALUE1", "bigint"),
+        ("A", "string(16777216)"),
+        ("B", "double"),
+        ("VALUE2", "double"),
+        ("VALUE3", "double"),
+        ("VALUE4", "double"),
+    ]
+    assert flattened_df.collect() == [
+        Row(VALUE1=1, A="foo", B=0.05, VALUE2=1.0, VALUE3=3.1, VALUE4=4.5)
+    ]
+
+
+@pytest.mark.parametrize(
+    "query,expected_dtypes,expected_schema",
+    [STRUCTURED_TYPES_EXAMPLES[IS_STRUCTURED_TYPES_SUPPORTED]],
+)
+def test_structured_dtypes_pandas(session, query, expected_dtypes, expected_schema):
+    pdf = session.sql(query).to_pandas()
+    if IS_STRUCTURED_TYPES_SUPPORTED:
+        assert (
+            pdf.to_json()
+            == '{"MAP":{"0":[["k1",1.0]]},"OBJ":{"0":{"A":"foo","B":0.05}},"ARR":{"0":[1.0,3.1,4.5]}}'
+        )
+    else:
+        assert (
+            pdf.to_json()
+            == '{"MAP":{"0":"{\\n \\"k1\\": 1\\n}"},"OBJ":{"0":"{\\n \\"A\\": \\"foo\\",\\n \\"B\\": 5.000000000000000e-02\\n}"},"ARR":{"0":"[\\n 1.000000000000000e+00,\\n 3.100000000000000e+00,\\n 4.500000000000000e+00\\n]"}}'
+        )
+
+
+@pytest.mark.skip(
+    "SNOW-1356851: Skipping until iceberg testing infrastructure is added."
+)
+@pytest.mark.skipif(
+    not (IS_STRUCTURED_TYPES_SUPPORTED and IS_ICEBERG_SUPPORTED),
+    reason="Test requires iceberg support and structured type support.",
+)
+@pytest.mark.parametrize(
+    "query,expected_dtypes,expected_schema",
+    [STRUCTURED_TYPES_EXAMPLES[IS_STRUCTURED_TYPES_SUPPORTED]],
+)
+def test_structured_dtypes_iceberg(session, query, expected_dtypes, expected_schema):
+    table_name = f"snowpark_structured_dtypes_{uuid.uuid4().hex[:5]}"
+    try:
+        session.sql(
+            f"""
+            create iceberg table if not exists {table_name} (
+              map map(varchar, int),
+              obj object(a varchar, b float),
+              arr array(float)
+            )
+            CATALOG = 'SNOWFLAKE'
+            EXTERNAL_VOLUME = 'python_connector_iceberg_exvol'
+            BASE_LOCATION = 'python_connector_merge_gate';
+            """
+        ).collect()
+        session.sql(
+            f"""
+            insert into {table_name}
+            {query}
+            """
+        ).collect()
+        df = session.table(table_name)
+        assert df.schema == expected_schema
+        assert df.dtypes == expected_dtypes
+    finally:
+        session.sql(f"drop table if exists {table_name}").collect()
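Editor's note: the UDF test that follows leans on the `convert_sp_to_sf_type` changes in `type_utils.py`, which render `structured=True` types into parameterized SQL type strings for UDF signatures. A hedged sketch of that mapping — the legacy spellings are asserted in the `tests/unit/test_types.py` hunk below; the structured spellings are inferred from the new f-strings and the converter's existing scalar names:

```python
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
from snowflake.snowpark.types import (
    ArrayType,
    DoubleType,
    LongType,
    MapType,
    StringType,
    StructField,
    StructType,
)

# Legacy (semi-structured) types keep their old DDL spellings.
assert convert_sp_to_sf_type(ArrayType()) == "ARRAY"
assert convert_sp_to_sf_type(MapType()) == "OBJECT"
assert convert_sp_to_sf_type(StructType()) == "OBJECT"

# structured=True types render their element/key/value/field types.
assert convert_sp_to_sf_type(ArrayType(DoubleType(), structured=True)) == "ARRAY(DOUBLE)"
assert (
    convert_sp_to_sf_type(MapType(StringType(), LongType(), structured=True))
    == "MAP(STRING, BIGINT)"
)
assert (
    convert_sp_to_sf_type(StructType([StructField("a", DoubleType())], structured=True))
    == "OBJECT(A DOUBLE)"  # field names are uppercased in the signature
)
```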
+@pytest.mark.skip(
+    "SNOW-1356851: Skipping until iceberg testing infrastructure is added."
+)
+@pytest.mark.skipif(
+    not (IS_STRUCTURED_TYPES_SUPPORTED and IS_ICEBERG_SUPPORTED),
+    reason="Test requires iceberg support and structured type support.",
+)
+@pytest.mark.parametrize(
+    "query,expected_dtypes,expected_schema",
+    [STRUCTURED_TYPES_EXAMPLES[IS_STRUCTURED_TYPES_SUPPORTED]],
+)
+def test_structured_dtypes_iceberg_udf(
+    session, query, expected_dtypes, expected_schema
+):
+    table_name = f"snowpark_structured_dtypes_udf_test{uuid.uuid4().hex[:5]}"
+
+    def nop(x):
+        return x
+
+    (map_type, object_type, array_type) = expected_schema
+    nop_map_udf = udf(
+        nop, return_type=map_type.datatype, input_types=[map_type.datatype]
+    )
+    nop_object_udf = udf(
+        nop, return_type=object_type.datatype, input_types=[object_type.datatype]
+    )
+    nop_array_udf = udf(
+        nop, return_type=array_type.datatype, input_types=[array_type.datatype]
+    )
+
+    try:
+        session.sql(
+            f"""
+            create iceberg table if not exists {table_name} (
+              map map(varchar, int),
+              obj object(A varchar, B float),
+              arr array(float)
+            )
+            CATALOG = 'SNOWFLAKE'
+            EXTERNAL_VOLUME = 'python_connector_iceberg_exvol'
+            BASE_LOCATION = 'python_connector_merge_gate';
+            """
+        ).collect()
+        session.sql(
+            f"""
+            insert into {table_name}
+            {query}
+            """
+        ).collect()
+
+        df = session.table(table_name)
+        working = df.select(
+            nop_object_udf(col("obj")).alias("obj"),
+            nop_array_udf(col("arr")).alias("arr"),
+        )
+        assert working.schema == StructType([object_type, array_type])
+
+        with pytest.raises(SnowparkSQLException):
+            # SNOW-XXXXXXX: Map not supported as a UDF return type.
+            df.select(
+                nop_map_udf(col("map")).alias("map"),
+            ).collect()
+    finally:
+        session.sql(f"drop table if exists {table_name}").collect()


 @pytest.mark.xfail(reason="SNOW-974852 vectors are not yet rolled out", strict=False)
 def test_dtypes_vector(session):
     schema = StructType(
diff --git a/tests/integ/test_stored_procedure.py b/tests/integ/test_stored_procedure.py
index e86298d1ab1..dbde387d86b 100644
--- a/tests/integ/test_stored_procedure.py
+++ b/tests/integ/test_stored_procedure.py
@@ -42,16 +42,21 @@
 )
 from snowflake.snowpark.row import Row
 from snowflake.snowpark.types import (
+    ArrayType,
     DateType,
     DoubleType,
     IntegerType,
+    LongType,
+    MapType,
     StringType,
     StructField,
     StructType,
+    VectorType,
 )
 from tests.utils import (
     IS_IN_STORED_PROC,
     IS_NOT_ON_GITHUB,
+    IS_STRUCTURED_TYPES_SUPPORTED,
     TempObjectType,
     TestFiles,
     Utils,
@@ -331,6 +336,64 @@ def test_call_named_stored_procedure(session, temp_schema, db_parameters):
     # restore active session


+@pytest.mark.skipif(
+    not IS_STRUCTURED_TYPES_SUPPORTED,
+    reason="Structured types not enabled in this account.",
+)
+def test_stored_procedure_with_structured_returns(session):
+    expected_dtypes = [
+        ("VEC", "vector"),
+        ("MAP", "map"),
+        ("OBJ", "struct"),
+        ("ARR", "array"),
+    ]
+    expected_schema = StructType(
+        [
+            StructField("VEC", VectorType(int, 5), nullable=True),
+            StructField(
+                "MAP",
+                MapType(StringType(16777216), LongType(), structured=True),
+                nullable=True,
+            ),
+            StructField(
+                "OBJ",
+                StructType(
+                    [
+                        StructField("A", StringType(16777216), nullable=True),
+                        StructField("B", DoubleType(), nullable=True),
+                    ],
+                    structured=True,
+                ),
+                nullable=True,
+            ),
+            StructField("ARR", ArrayType(DoubleType(), structured=True), nullable=True),
+        ]
+    )
+
+    sproc_name = Utils.random_name_for_temp_object(TempObjectType.PROCEDURE)
+
+    def test_sproc(session: Session) -> DataFrame:
+        return session.sql(
+            """
+            select
+                [1,2,3,4,5] :: vector(int, 5) as vec,
+                object_construct('k1', 1) :: map(varchar, int) as map,
+                object_construct('a', 'foo', 'b', 0.05) :: object(a varchar, b float) as obj,
+                [1.0, 3.1, 4.5] :: array(float) as arr
+            ;
+            """
+        )
+
+    session.sproc.register(
+        test_sproc,
+        name=sproc_name,
+        replace=True,
+    )
+    df = session.call(sproc_name)
+    assert df.schema == expected_schema
+    assert df.dtypes == expected_dtypes
+
+
 @pytest.mark.localtest
 @pytest.mark.parametrize("anonymous", [True, False])
 def test_call_table_sproc_triggers_action(session, anonymous):
diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py
index 0decfb2369f..7648517d8d4 100644
--- a/tests/unit/test_types.py
+++ b/tests/unit/test_types.py
@@ -784,6 +784,7 @@ def test_convert_sp_to_sf_type():
     assert convert_sp_to_sf_type(BinaryType()) == "BINARY"
     assert convert_sp_to_sf_type(ArrayType()) == "ARRAY"
     assert convert_sp_to_sf_type(MapType()) == "OBJECT"
+    assert convert_sp_to_sf_type(StructType()) == "OBJECT"
     assert convert_sp_to_sf_type(VariantType()) == "VARIANT"
     assert convert_sp_to_sf_type(GeographyType()) == "GEOGRAPHY"
     assert convert_sp_to_sf_type(GeometryType()) == "GEOMETRY"
diff --git a/tests/utils.py b/tests/utils.py
index 89ac4d90a5c..326cae21c6a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -76,6 +76,14 @@
 IS_NOT_ON_GITHUB = os.getenv("GITHUB_ACTIONS") != "true"
 # this env variable is set in regression test
 IS_IN_STORED_PROC_LOCALFS = IS_IN_STORED_PROC and os.getenv("IS_LOCAL_FS")
+# SNOW-1348805: Structured types have not been rolled out to all accounts yet.
+# Once rolled out, this should be updated to include all accounts.
+STRUCTURED_TYPE_ENVIRONMENTS = {"dev", "aws"}
+IS_STRUCTURED_TYPES_SUPPORTED = (
+    os.getenv("cloud_provider", "dev") in STRUCTURED_TYPE_ENVIRONMENTS
+)
+ICEBERG_ENVIRONMENTS = {"dev", "aws"}
+IS_ICEBERG_SUPPORTED = os.getenv("cloud_provider", "dev") in ICEBERG_ENVIRONMENTS


 class Utils:
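Editor's note: for test authors, the new flags in `tests/utils.py` gate structured-type and Iceberg coverage until those features are generally available. A usage sketch following the same decorator pattern as the tests in this patch (the test names here are hypothetical):

```python
import pytest

from tests.utils import IS_ICEBERG_SUPPORTED, IS_STRUCTURED_TYPES_SUPPORTED


@pytest.mark.skipif(
    not IS_STRUCTURED_TYPES_SUPPORTED,
    reason="Structured types not enabled in this account.",
)
def test_my_structured_feature(session):  # hypothetical test
    ...


@pytest.mark.skipif(
    not (IS_STRUCTURED_TYPES_SUPPORTED and IS_ICEBERG_SUPPORTED),
    reason="Test requires iceberg support and structured type support.",
)
def test_my_iceberg_feature(session):  # hypothetical test
    ...
```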