diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py
index 4c632e7a..57479030 100644
--- a/src/snowflake/sqlalchemy/base.py
+++ b/src/snowflake/sqlalchemy/base.py
@@ -1098,9 +1098,21 @@ def visit_MAP(self, type_, **kw):
     def visit_ARRAY(self, type_, **kw):
         return "ARRAY"
 
-    def visit_OBJECT(self, type_, **kw):
+    def visit_OBJECT(self, type_, **kw):  # TODO?
         return "OBJECT"
 
+    def visit_OBJECT_STRUCTURED(self, type_, **kw):
+        # Render every key as "<name> <type>[ NOT NULL]". A value may be a bare
+        # type or a (type, not_null) tuple, as OBJECT_STRUCTURED accepts both.
+        contents = []
+        for key, item in type_.items_types.items():
+            if not isinstance(item, tuple):
+                item = (item,)
+            # Emit the modifier only when requested: no trailing space otherwise.
+            suffix = " NOT NULL" if len(item) > 1 and item[1] else ""
+            contents.append(f"{key} {item[0].compile()}{suffix}")
+        return f"OBJECT({', '.join(contents)})" if contents else "OBJECT"
+
     def visit_BLOB(self, type_, **kw):
         return "BINARY"
 
diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py
index f2c950dd..d28b23ab 100644
--- a/src/snowflake/sqlalchemy/custom_types.py
+++ b/src/snowflake/sqlalchemy/custom_types.py
@@ -2,8 +2,11 @@
 # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #
 
+from typing import Tuple, Union
+
 import sqlalchemy.types as sqltypes
 import sqlalchemy.util as util
+from sqlalchemy.types import TypeEngine
 
 TEXT = sqltypes.VARCHAR
 CHARACTER = sqltypes.CHAR
@@ -57,6 +60,20 @@ def __init__(
         super().__init__()
 
 
+class OBJECT_STRUCTURED(StructuredType):
+    """Structured OBJECT whose keys are declared with their own types.
+
+    Each keyword argument names a key; its value is either a type or a
+    ``(type, not_null)`` tuple, e.g. ``OBJECT_STRUCTURED(key1=(TEXT(), True))``.
+    """
+
+    __visit_name__ = "OBJECT_STRUCTURED"
+
+    def __init__(self, **items_types: Union[TypeEngine, Tuple[TypeEngine, bool]]):
+        self.items_types = items_types
+        super().__init__()
+
+
 class OBJECT(SnowflakeType):
     __visit_name__ = "OBJECT"
 
diff --git a/src/snowflake/sqlalchemy/parser/custom_type_parser.py b/src/snowflake/sqlalchemy/parser/custom_type_parser.py
index cf69c594..0dce0014 100644
--- a/src/snowflake/sqlalchemy/parser/custom_type_parser.py
+++ b/src/snowflake/sqlalchemy/parser/custom_type_parser.py
@@ -1,5 +1,6 @@
 #
 # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+import re
 
 import sqlalchemy.types as sqltypes
 from sqlalchemy.sql.type_api import TypeEngine
@@ -48,7 +49,7 @@
     "DECIMAL": DECIMAL,
     "DOUBLE": DOUBLE,
     "FIXED": DECIMAL,
-    "FLOAT": FLOAT,  # Snowflake FLOAT datatype doesn't has parameters
+    "FLOAT": FLOAT,  # Snowflake FLOAT datatype doesn't have parameters
     "INT": INTEGER,
     "INTEGER": INTEGER,
     "NUMBER": _CUSTOM_DECIMAL,
@@ -107,6 +108,23 @@ def extract_parameters(text: str) -> list:
     return output_parameters
 
 
+def split_ignore_parentheses(text):
+    # Split on whitespace found outside parentheses, so a parameterized (even
+    # nested) type such as "MAP(NUMBER(10, 0), VARCHAR)" stays a single token.
+    tokens, current, depth = [], [], 0
+    for lexeme in re.findall(r"\s+|[()]|[^\s()]+", text):
+        if depth == 0 and lexeme.isspace():
+            if current:
+                tokens.append("".join(current))
+                current = []
+            continue
+        depth = max(0, depth + {"(": 1, ")": -1}.get(lexeme, 0))
+        current.append(lexeme)
+    if current:
+        tokens.append("".join(current))
+    return tokens
+
+
 def parse_type(type_text: str) -> TypeEngine:
     """
     Parses a type definition string and returns the corresponding SQLAlchemy type.
@@ -130,6 +148,7 @@ def parse_type(type_text: str) -> TypeEngine:
 
     col_type_class = ischema_names.get(type_name, None)
     col_type_kw = {}
+
     if col_type_class is None:
         col_type_class = NullType
     else:
@@ -139,6 +158,8 @@ def parse_type(type_text: str) -> TypeEngine:
             col_type_kw = __parse_type_with_length_parameters(parameters)
         elif issubclass(col_type_class, MAP):
             col_type_kw = __parse_map_type_parameters(parameters)
+        elif issubclass(col_type_class, OBJECT):
+            col_type_kw = __parse_object_type_parameters(parameters)
         if col_type_kw is None:
             col_type_class = NullType
             col_type_kw = {}
@@ -146,6 +167,26 @@ def parse_type(type_text: str) -> TypeEngine:
     return col_type_class(**col_type_kw)
 
 
+def __parse_object_type_parameters(parameters):
+    # Example of object type: OBJECT(key1 VARCHAR, key2 NUMBER NOT NULL)
+    # Parameters: [key1 VARCHAR, key2 NUMBER NOT NULL]
+    # Returns None when a value type is unknown, so parse_type yields NullType.
+
+    object_rows = {}
+    for parameter in parameters:
+        parameter_parts = split_ignore_parentheses(parameter)
+        if len(parameter_parts) < 2:
+            continue
+        key, type_text, *modifiers = parameter_parts
+        value_type = parse_type(type_text)
+        if isinstance(value_type, NullType):
+            return None
+        # "NOT NULL" arrives as two tokens after the whitespace split.
+        not_null = [modifier.upper() for modifier in modifiers] == ["NOT", "NULL"]
+        object_rows[key] = (value_type, not_null)
+    return {"items_types": object_rows}
+
+
 def __parse_map_type_parameters(parameters):
     if len(parameters) != 2:
         return None
diff --git a/tests/test_structured_datatypes.py b/tests/test_structured_datatypes.py
index 4ea0892b..5dbe7c86 100644
--- a/tests/test_structured_datatypes.py
+++ b/tests/test_structured_datatypes.py
@@ -19,7 +19,7 @@
 from sqlalchemy.sql.ddl import CreateTable
 
 from snowflake.sqlalchemy import NUMBER, IcebergTable, SnowflakeTable
-from snowflake.sqlalchemy.custom_types import MAP, TEXT
+from snowflake.sqlalchemy.custom_types import MAP, OBJECT_STRUCTURED, TEXT
 from snowflake.sqlalchemy.exc import StructuredTypeNotSupportedInTableColumnsError
 
 
@@ -42,12 +42,15 @@ def test_create_table_structured_datatypes(
     engine_testaccount, external_volume, base_location
 ):
     metadata = MetaData()
-    table_name = "test_map0"
+    table_name = "test_structured0"
     test_map = IcebergTable(
         table_name,
         metadata,
         Column("id", Integer, primary_key=True),
         Column("map_id", MAP(NUMBER(10, 0), TEXT())),
+        Column(
+            "object_col", OBJECT_STRUCTURED(key1=(TEXT(), False), key2=(NUMBER(), True))
+        ),
         external_volume=external_volume,
         base_location=base_location,
     )
@@ -91,12 +94,54 @@ def test_insert_map(engine_testaccount, external_volume, base_location, snapshot
         test_map.drop(engine_testaccount)
 
 
+@pytest.mark.requires_external_volume
+def test_insert_structured_object(
+    engine_testaccount, external_volume, base_location, snapshot
+):
+    metadata = MetaData()
+    table_name = "test_insert_structured_object"
+    test_map = IcebergTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column(
+            "object_col", OBJECT_STRUCTURED(key1=(TEXT(), False), key2=(NUMBER(), True))
+        ),
+        external_volume=external_volume,
+        base_location=base_location,
+    )
+    metadata.create_all(engine_testaccount)
+
+    try:
+        with engine_testaccount.connect() as conn:
+            slt = select(
+                1,
+                cast(
+                    text("{'key1': 'item1', 'key2': 15}"),
+                    OBJECT_STRUCTURED(key1=(TEXT(), False), key2=(NUMBER(), True)),
+                ),
+            )
+            ins = test_map.insert().from_select(["id", "object_col"], slt)
+            conn.execute(ins)
+
+            results = conn.execute(test_map.select())
+            data = results.fetchmany()
+            results.close()
+            snapshot.assert_match(data)
+    finally:
+        test_map.drop(engine_testaccount)
+
+
 @pytest.mark.requires_external_volume
 @pytest.mark.parametrize(
     "structured_type",
     [
-        MAP(NUMBER(10, 0), TEXT()),
-        MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT())),
+        (MAP(NUMBER(10, 0), TEXT()), MAP),
+        (MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT())), MAP),
+        (
+            OBJECT_STRUCTURED(key1=(TEXT(), False), key2=(NUMBER(), True)),
+            OBJECT_STRUCTURED,
+        ),
     ],
 )
 def test_inspect_structured_data_types(
@@ -108,7 +153,7 @@ def test_inspect_structured_data_types(
         table_name,
         metadata,
         Column("id", Integer, primary_key=True),
-        Column("map_id", structured_type),
+        Column("structured_type_col", structured_type[0]),
         external_volume=external_volume,
         base_location=base_location,
     )
@@ -119,7 +164,7 @@ def test_inspect_structured_data_types(
         columns = inspecter.get_columns(table_name)
 
         assert isinstance(columns[0]["type"], NUMBER)
-        assert isinstance(columns[1]["type"], MAP)
+        assert isinstance(columns[1]["type"], structured_type[1])
         assert columns == snapshot
 
     finally:
@@ -132,6 +177,7 @@ def test_inspect_structured_data_types(
     [
         "MAP(NUMBER(10, 0), VARCHAR)",
         "MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))",
+        "OBJECT(key1 VARCHAR, key2 NUMBER NOT NULL)",
     ],
 )
 def test_reflect_structured_data_types(
@@ -147,7 +193,7 @@ def test_reflect_structured_data_types(
     create_table_sql = f"""
 CREATE OR REPLACE ICEBERG TABLE {table_name} (
         id number(38,0) primary key,
-        map_id {structured_type})
+        structured_type_col {structured_type})
 CATALOG = 'SNOWFLAKE'
 EXTERNAL_VOLUME = '{external_volume}'
 BASE_LOCATION = '{base_location}';
@@ -224,6 +270,10 @@ def test_snowflake_tables_with_structured_types(sql_compiler):
             metadata,
             Column("Id", Integer, primary_key=True),
             Column("name", MAP(NUMBER(10, 0), TEXT())),
+            Column(
+                "object_col",
+                OBJECT_STRUCTURED(key1=(TEXT(), False), key2=(NUMBER(), True)),
+            ),
         )
 
     assert programming_error is not None
diff --git a/tests/test_unit_structured_types.py b/tests/test_unit_structured_types.py
index c7bcd6ef..72c09c8e 100644
--- a/tests/test_unit_structured_types.py
+++ b/tests/test_unit_structured_types.py
@@ -64,6 +64,10 @@ def test_extract_parameters():
         ),
         ("MAP(DECIMAL(10, 0), VARIANT)", "MAP(DECIMAL(10, 0), VARIANT)"),
         ("OBJECT", "OBJECT"),
+        (
+            "OBJECT(a DECIMAL(10, 0) NOT NULL, b DECIMAL(10, 0), c VARCHAR NOT NULL)",
+            "OBJECT(a DECIMAL(10, 0) NOT NULL, b DECIMAL(10, 0), c VARCHAR NOT NULL)",
+        ),
         ("ARRAY", "ARRAY"),
         ("GEOGRAPHY", "GEOGRAPHY"),
         ("GEOMETRY", "GEOMETRY"),