From 5ed85555b84d9d96723bad165b23f970a8010e81 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Mon, 16 Dec 2024 06:50:11 -0600 Subject: [PATCH] Add support for Structured ARRAY --- README.md | 73 +++++++++ src/snowflake/sqlalchemy/base.py | 5 +- src/snowflake/sqlalchemy/custom_types.py | 16 +- .../sqlalchemy/parser/custom_type_parser.py | 52 ++++--- .../test_structured_datatypes.ambr | 49 ++++++ tests/test_structured_datatypes.py | 141 ++++++++++++++++-- tests/test_unit_structured_types.py | 4 + 7 files changed, 308 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 3985a9d6..34d86376 100644 --- a/README.md +++ b/README.md @@ -367,6 +367,79 @@ data_object = json.loads(row[1]) data_array = json.loads(row[2]) ``` +### Structured Data Types Support + +This module defines custom SQLAlchemy types for Snowflake structured data, specifically for **Iceberg tables**. +The types —**MAP**, **OBJECT**, and **ARRAY**— allow you to store complex data structures in your SQLAlchemy models. +For detailed information, refer to the Snowflake [Structured data types](https://docs.snowflake.com/en/sql-reference/data-types-structured) documentation. + +--- + +#### MAP + +The `MAP` type represents a collection of key-value pairs, where the key type and the value type are declared separately and may differ from each other. + +- **Key Type**: The type of the keys (e.g., `TEXT`, `NUMBER`). +- **Value Type**: The type of the values (e.g., `TEXT`, `NUMBER`). +- **Not Null**: Whether `NULL` values are disallowed (default is `False`, meaning `NULL` values are allowed). + +*Example Usage* + +```python +IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("map_col", MAP(NUMBER(10, 0), TEXT(16777216))), + external_volume="external_volume", + base_location="base_location", +) +``` + +#### OBJECT + +The `OBJECT` type represents a semi-structured object with named fields. Each field can have a specific type, and you can also specify whether each field is nullable. 
+ +- **Item Types**: A dictionary of field names and their types. Each type can optionally be paired with a not-null flag (`True` for not nullable, `False` for nullable; default is `False`). + +*Example Usage* + +```python +IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column( + "object_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + OBJECT(key1=TEXT(16777216), key2=NUMBER(10, 0)), # Without nullable flag + ), + external_volume="external_volume", + base_location="base_location", +) +``` + +#### ARRAY + +The `ARRAY` type represents an ordered list of values, where each element has the same type. The type of the elements is defined when creating the array. + +- **Value Type**: The type of the elements in the array (e.g., `TEXT`, `NUMBER`). +- **Not Null**: Whether `NULL` values are disallowed (default is `False`, meaning `NULL` values are allowed). + +*Example Usage* + +```python +IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("array_col", ARRAY(TEXT(16777216))), + external_volume="external_volume", + base_location="base_location", +) +``` + + ### CLUSTER BY Support Snowflake SQLAchemy supports the `CLUSTER BY` parameter for tables. For information about the parameter, see :doc:`/sql-reference/sql/create-table`. 
diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 3fef7709..9ce8b83c 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -1096,7 +1096,10 @@ def visit_MAP(self, type_, **kw): ) def visit_ARRAY(self, type_, **kw): - return "ARRAY" + if type_.is_semi_structured: + return "ARRAY" + not_null = f" {NOT_NULL}" if type_.not_null else "" + return f"ARRAY({type_.value_type.compile()}{not_null})" def visit_OBJECT(self, type_, **kw): if type_.is_semi_structured: diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py index ce7ad592..11cd2eb8 100644 --- a/src/snowflake/sqlalchemy/custom_types.py +++ b/src/snowflake/sqlalchemy/custom_types.py @@ -1,7 +1,7 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from typing import Tuple, Union +from typing import Optional, Tuple, Union import sqlalchemy.types as sqltypes import sqlalchemy.util as util @@ -40,7 +40,8 @@ class VARIANT(SnowflakeType): class StructuredType(SnowflakeType): - def __init__(self): + def __init__(self, is_semi_structured: bool = False): + self.is_semi_structured = is_semi_structured super().__init__() @@ -81,9 +82,18 @@ def __repr__(self): ) -class ARRAY(SnowflakeType): +class ARRAY(StructuredType): __visit_name__ = "ARRAY" + def __init__( + self, + value_type: Optional[sqltypes.TypeEngine] = None, + not_null: bool = False, + ): + self.value_type = value_type + self.not_null = not_null + super().__init__(is_semi_structured=value_type is None) + class TIMESTAMP_TZ(SnowflakeType): __visit_name__ = "TIMESTAMP_TZ" diff --git a/src/snowflake/sqlalchemy/parser/custom_type_parser.py b/src/snowflake/sqlalchemy/parser/custom_type_parser.py index 1e99ba56..09cb6ab8 100644 --- a/src/snowflake/sqlalchemy/parser/custom_type_parser.py +++ b/src/snowflake/sqlalchemy/parser/custom_type_parser.py @@ -74,6 +74,8 @@ "GEOMETRY": GEOMETRY, } +NOT_NULL_STR = "NOT NULL" + def 
tokenize_parameters(text: str, character_for_strip=",") -> list: """ @@ -160,6 +162,8 @@ def parse_type(type_text: str) -> TypeEngine: col_type_kw = __parse_map_type_parameters(parameters) elif issubclass(col_type_class, OBJECT): col_type_kw = __parse_object_type_parameters(parameters) + elif issubclass(col_type_class, ARRAY): + col_type_kw = __parse_nullable_parameter(parameters) if col_type_kw is None: col_type_class = NullType col_type_kw = {} @@ -169,6 +173,7 @@ def parse_type(type_text: str) -> TypeEngine: def __parse_object_type_parameters(parameters): object_rows = {} + not_null_parts = NOT_NULL_STR.split(" ") for parameter in parameters: parameter_parts = tokenize_parameters(parameter, " ") if len(parameter_parts) >= 2: @@ -178,40 +183,51 @@ def __parse_object_type_parameters(parameters): return None not_null = ( len(parameter_parts) == 4 - and parameter_parts[2] == "NOT" - and parameter_parts[3] == "NULL" + and parameter_parts[2] == not_null_parts[0] + and parameter_parts[3] == not_null_parts[1] ) object_rows[key] = (value_type, not_null) return object_rows -def __parse_map_type_parameters(parameters): - if len(parameters) != 2: +def __parse_nullable_parameter(parameters): + if len(parameters) < 1: + return {} + elif len(parameters) > 1: return None - - key_type_str = parameters[0] - value_type_str = parameters[1] - not_null_str = "NOT NULL" - not_null = False + parameter_str = parameters[0] + is_not_null = False if ( - len(value_type_str) >= len(not_null_str) - and value_type_str[-len(not_null_str) :] == not_null_str + len(parameter_str) >= len(NOT_NULL_STR) + and parameter_str[-len(NOT_NULL_STR) :] == NOT_NULL_STR ): - not_null = True - value_type_str = value_type_str[: -len(not_null_str) - 1] + is_not_null = True + parameter_str = parameter_str[: -len(NOT_NULL_STR) - 1] - key_type: TypeEngine = parse_type(key_type_str) - value_type: TypeEngine = parse_type(value_type_str) - if isinstance(key_type, NullType) or isinstance(value_type, NullType): + 
value_type: TypeEngine = parse_type(parameter_str) + if isinstance(value_type, NullType): return None return { - "key_type": key_type, "value_type": value_type, - "not_null": not_null, + "not_null": is_not_null, } +def __parse_map_type_parameters(parameters): + if len(parameters) != 2: + return None + + key_type_str = parameters[0] + value_type_str = parameters[1] + key_type: TypeEngine = parse_type(key_type_str) + value_type = __parse_nullable_parameter([value_type_str]) + if isinstance(value_type, NullType) or isinstance(key_type, NullType): + return None + + return {"key_type": key_type, **value_type} + + def __parse_type_with_length_parameters(parameters): return ( {"length": int(parameters[0])} diff --git a/tests/__snapshots__/test_structured_datatypes.ambr b/tests/__snapshots__/test_structured_datatypes.ambr index 714f5d57..b7050cca 100644 --- a/tests/__snapshots__/test_structured_datatypes.ambr +++ b/tests/__snapshots__/test_structured_datatypes.ambr @@ -11,6 +11,20 @@ # name: test_compile_table_with_structured_data_type[structured_type1] 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))' # --- +# name: test_compile_table_with_structured_data_type[structured_type2] + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname ARRAY(MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tPRIMARY KEY ("Id"))' +# --- +# name: test_insert_array + list([ + (1, '[\n "item1",\n "item2"\n]'), + ]) +# --- +# name: test_insert_array_orm + ''' + 002014 (22000): SQL compilation error: + Invalid expression [CAST(ARRAY_CONSTRUCT('item1', 'item2') AS ARRAY(VARCHAR(16777216)))] in VALUES clause + ''' +# --- # name: test_compile_table_with_structured_data_type[structured_type2] 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))' # --- @@ -166,6 +180,35 @@ }), ]) # --- +# 
name: test_inspect_structured_data_types[structured_type3-ARRAY] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'structured_type_col', + 'nullable': True, + 'primary_key': False, + 'type': ARRAY(value_type=VARCHAR(length=16777216)), + }), + ]) +# --- +# name: test_reflect_structured_data_types[ARRAY(MAP(NUMBER(10, 0), VARCHAR))] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col ARRAY(MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- # name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))] "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" # --- @@ -175,6 +218,12 @@ # name: test_reflect_structured_data_types[OBJECT(key1 VARCHAR, key2 NUMBER(10, 0))] "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" # --- +# name: test_select_array_orm + list([ + (1, '[\n "item3",\n "item4"\n]'), + (2, '[\n "item1",\n "item2"\n]'), + ]) +# --- # name: test_select_map_orm list([ (1, '{\n "100": "item1",\n "200": "item2"\n}'), diff --git a/tests/test_structured_datatypes.py b/tests/test_structured_datatypes.py index d6beb3e9..ce030bd2 100644 --- a/tests/test_structured_datatypes.py +++ b/tests/test_structured_datatypes.py @@ -18,7 +18,7 @@ from sqlalchemy.sql.ddl import CreateTable from snowflake.sqlalchemy 
import NUMBER, IcebergTable, SnowflakeTable -from snowflake.sqlalchemy.custom_types import MAP, OBJECT, TEXT +from snowflake.sqlalchemy.custom_types import ARRAY, MAP, OBJECT, TEXT from snowflake.sqlalchemy.exc import StructuredTypeNotSupportedInTableColumnsError @@ -28,6 +28,7 @@ MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT(16777216))), OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), OBJECT(key1=TEXT(16777216), key2=NUMBER(10, 0)), + ARRAY(MAP(NUMBER(10, 0), TEXT(16777216))), ], ) def test_compile_table_with_structured_data_type( @@ -58,15 +59,6 @@ def test_insert_map(engine_testaccount, external_volume, base_location, snapshot external_volume=external_volume, base_location=base_location, ) - """ - Test inserting data into a table with a MAP column type. - - Args: - engine_testaccount: The SQLAlchemy engine connected to the test account. - external_volume: The external volume to use for the table. - base_location: The base location for the table. - snapshot: The snapshot object for assertion. 
- """ metadata.create_all(engine_testaccount) try: @@ -179,6 +171,128 @@ def __repr__(self): test_map.drop(engine_testaccount) +@pytest.mark.requires_external_volume +def test_select_array_orm(engine_testaccount, external_volume, base_location, snapshot): + metadata = MetaData() + table_name = "test_select_array_orm" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("array_col", ARRAY(TEXT(16777216))), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + slt1 = select( + 2, + cast( + text("['item1','item2']"), + ARRAY(TEXT(16777216)), + ), + ) + slt2 = select( + 1, + cast( + text("['item3','item4']"), + ARRAY(TEXT(16777216)), + ), + ).union_all(slt1) + ins = test_map.insert().from_select(["id", "array_col"], slt2) + conn.execute(ins) + conn.commit() + + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __table__ = test_map + + def __repr__(self): + return f"({self.id!r}, {self.array_col!r})" + + try: + data = session.query(TestIcebergTableOrm).all() + snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_array(engine_testaccount, external_volume, base_location, snapshot): + metadata = MetaData() + table_name = "test_insert_map" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("array_col", ARRAY(TEXT(16777216))), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + slt = select( + 1, + cast( + text("['item1','item2']"), + ARRAY(TEXT(16777216)), + ), + ) + ins = test_map.insert().from_select(["id", "array_col"], slt) + conn.execute(ins) + + results = conn.execute(test_map.select()) + data = 
results.fetchmany() + results.close() + snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_array_orm( + sql_compiler, external_volume, base_location, engine_testaccount, snapshot +): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __tablename__ = "test_iceberg_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": external_volume, + "base_location": base_location, + } + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + array_col = Column(ARRAY(TEXT(16777216))) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + cast_expr = cast(text("['item1','item2']"), ARRAY(TEXT(16777216))) + instance = TestIcebergTableOrm(id=0, array_col=cast_expr) + session.add(instance) + with pytest.raises(exc.ProgrammingError) as programming_error: + session.commit() + # TODO: Support variant in insert statement + assert str(programming_error.value.orig) == snapshot + finally: + Base.metadata.drop_all(engine_testaccount) + + @pytest.mark.requires_external_volume def test_insert_structured_object( engine_testaccount, external_volume, base_location, snapshot @@ -328,6 +442,7 @@ def __repr__(self): OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), OBJECT, ), + (ARRAY(TEXT(16777216)), ARRAY), ], ) def test_inspect_structured_data_types( @@ -369,6 +484,7 @@ def test_inspect_structured_data_types( "MAP(NUMBER(10, 0), VARCHAR)", "MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))", "OBJECT(key1 VARCHAR, key2 NUMBER(10, 0))", + "ARRAY(MAP(NUMBER(10, 0), VARCHAR))", ], ) def test_reflect_structured_data_types( @@ -425,6 +541,10 @@ def test_create_table_structured_datatypes( "object_col", OBJECT(key1=(TEXT(16777216), False), 
key2=(NUMBER(10, 0), False)), ), + Column( + "array_col", + ARRAY(TEXT(16777216)), + ), external_volume=external_volume, base_location=base_location, ) @@ -443,6 +563,7 @@ def test_create_table_structured_datatypes( "object_col", OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), ), + Column("name", ARRAY(TEXT(16777216))), ], ) def test_structured_type_not_supported_in_table_columns_error( diff --git a/tests/test_unit_structured_types.py b/tests/test_unit_structured_types.py index 474ebde4..472ce2e6 100644 --- a/tests/test_unit_structured_types.py +++ b/tests/test_unit_structured_types.py @@ -69,6 +69,10 @@ def test_extract_parameters(): "OBJECT(a DECIMAL(10, 0) NOT NULL, b DECIMAL(10, 0), c VARCHAR NOT NULL)", ), ("ARRAY", "ARRAY"), + ( + "ARRAY(MAP(DECIMAL(10, 0), VARCHAR NOT NULL))", + "ARRAY(MAP(DECIMAL(10, 0), VARCHAR NOT NULL))", + ), ("GEOGRAPHY", "GEOGRAPHY"), ("GEOMETRY", "GEOMETRY"), ],