Skip to content

Commit

Permalink
SNOW-1776332 Add support for OBJECT
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-gvenegascastro committed Dec 9, 2024
1 parent 695c0a9 commit a4dc82d
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 9 deletions.
10 changes: 9 additions & 1 deletion src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,9 +1098,17 @@ def visit_MAP(self, type_, **kw):
def visit_ARRAY(self, type_, **kw):
    """Return the DDL keyword used for ARRAY columns."""
    return "ARRAY"

def visit_OBJECT(self, type_, **kw):
def visit_OBJECT(self, type_, **kw):  # TODO?
    """Return the DDL keyword used for OBJECT columns."""
    return "OBJECT"

def visit_OBJECT_STRUCTURED(self, type_, **kw):
    """Render the DDL for a structured OBJECT type.

    ``type_.items_types`` maps each field name to either a
    ``(value_type, not_null)`` tuple or a bare value type, which is
    treated as nullable. An empty mapping renders as plain ``OBJECT``.

    :param type_: the structured object type being compiled.
    :return: ``"OBJECT"`` or ``"OBJECT(name TYPE [NOT NULL], ...)"``.
    """
    fields = []
    for key, spec in type_.items_types.items():
        if isinstance(spec, (tuple, list)):
            value_type = spec[0]
            not_null = len(spec) > 1 and bool(spec[1])
        else:
            # The constructor's annotation also allows a bare type with no
            # nullability flag; treat it as nullable instead of failing on
            # ``spec[0]``.
            value_type, not_null = spec, False

        field = f"{key} {value_type.compile()}"
        if not_null:
            # Appended only when set, so nullable fields do not end in a
            # stray trailing space.
            field += " NOT NULL"
        fields.append(field)

    return f"OBJECT({', '.join(fields)})" if fields else "OBJECT"

def visit_BLOB(self, type_, **kw):
    """Return the DDL keyword used for BLOB columns, which is BINARY."""
    return "BINARY"

Expand Down
11 changes: 11 additions & 0 deletions src/snowflake/sqlalchemy/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from typing import Tuple, Union

import sqlalchemy.types as sqltypes
import sqlalchemy.util as util
from sqlalchemy.types import TypeEngine

TEXT = sqltypes.VARCHAR
CHARACTER = sqltypes.CHAR
Expand Down Expand Up @@ -57,6 +60,14 @@ def __init__(
super().__init__()


class OBJECT_STRUCTURED(StructuredType):
    """Structured OBJECT type whose fields are given as keyword arguments.

    Each keyword is a field name. Its value is annotated as either a type
    or a ``(type, bool)`` tuple; the type compiler reads the tuple as
    (value type, NOT NULL flag).
    """

    __visit_name__ = "OBJECT_STRUCTURED"

    def __init__(self, **items_types: Union[TypeEngine, Tuple[TypeEngine, bool]]):
        # The field mapping is stored exactly as passed; no normalisation
        # or validation happens here.
        self.items_types = items_types
        super().__init__()


class OBJECT(SnowflakeType):
__visit_name__ = "OBJECT"

Expand Down
31 changes: 30 additions & 1 deletion src/snowflake/sqlalchemy/parser/custom_type_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
import re

import sqlalchemy.types as sqltypes
from sqlalchemy.sql.type_api import TypeEngine
Expand Down Expand Up @@ -48,7 +49,7 @@
"DECIMAL": DECIMAL,
"DOUBLE": DOUBLE,
"FIXED": DECIMAL,
"FLOAT": FLOAT, # Snowflake FLOAT datatype doesn't has parameters
"FLOAT": FLOAT, # Snowflake FLOAT datatype doesn't have parameters
"INT": INTEGER,
"INTEGER": INTEGER,
"NUMBER": _CUSTOM_DECIMAL,
Expand Down Expand Up @@ -107,6 +108,14 @@ def extract_parameters(text: str) -> list:
return output_parameters


def split_ignore_parentheses(text):
    """Split ``text`` on whitespace that is not inside parentheses.

    A parenthesised group stays attached to the token it follows, and
    nested groups are kept whole, so ``"a DECIMAL(10, 0) NOT NULL"``
    becomes ``["a", "DECIMAL(10, 0)", "NOT", "NULL"]``.

    :param text: the text to tokenize, e.g. one field of an OBJECT type.
    :return: list of non-empty tokens in order of appearance.
    """
    tokens = []
    current = []
    depth = 0
    for char in text:
        if char == "(":
            depth += 1
        elif char == ")":
            # Clamp at zero so a stray ")" cannot glue later tokens together.
            depth = max(depth - 1, 0)

        if depth == 0 and char.isspace():
            # Whitespace at the top level ends the current token.
            if current:
                tokens.append("".join(current))
                current = []
        else:
            current.append(char)

    if current:
        tokens.append("".join(current))
    return tokens


def parse_type(type_text: str) -> TypeEngine:
"""
Parses a type definition string and returns the corresponding SQLAlchemy type.
Expand All @@ -130,6 +139,7 @@ def parse_type(type_text: str) -> TypeEngine:

col_type_class = ischema_names.get(type_name, None)
col_type_kw = {}

if col_type_class is None:
col_type_class = NullType
else:
Expand All @@ -139,13 +149,32 @@ def parse_type(type_text: str) -> TypeEngine:
col_type_kw = __parse_type_with_length_parameters(parameters)
elif issubclass(col_type_class, MAP):
col_type_kw = __parse_map_type_parameters(parameters)
elif issubclass(col_type_class, OBJECT):
col_type_kw = __parse_object_type_parameters(parameters)
if col_type_kw is None:
col_type_class = NullType
col_type_kw = {}

return col_type_class(**col_type_kw)


def __parse_object_type_parameters(parameters):
    """Build the constructor keyword arguments for a structured OBJECT type.

    Example of object type: OBJECT(key1 VARCHAR, key2 NUMBER NOT NULL)
    Parameters: ["key1 VARCHAR", "key2 NUMBER NOT NULL"]

    :param parameters: one string per field, ``"<key> <type> [NOT NULL]"``.
    :return: ``{"items_types": {key: (value_type, not_null)}}``, or ``None``
        when any field type parses to ``NullType``.
    """
    object_rows = {}
    for parameter in parameters:
        # The key is the first whitespace-delimited word; everything after
        # it is the type, optionally followed by NOT NULL.
        parts = parameter.strip().split(None, 1)
        if len(parts) < 2:
            # No type follows the key; such entries are skipped.
            continue
        key, type_text = parts

        # Detect and strip a trailing NOT NULL so that the complete type
        # text, including any "(...)" parameters, reaches parse_type.
        not_null_match = re.search(r"\s+NOT\s+NULL\s*$", type_text, re.IGNORECASE)
        if not_null_match:
            type_text = type_text[: not_null_match.start()]

        value_type = parse_type(type_text)
        if isinstance(value_type, NullType):
            return None
        object_rows[key] = (value_type, not_null_match is not None)
    return {"items_types": object_rows}


def __parse_map_type_parameters(parameters):
if len(parameters) != 2:
return None
Expand Down
64 changes: 57 additions & 7 deletions tests/test_structured_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sqlalchemy.sql.ddl import CreateTable

from snowflake.sqlalchemy import NUMBER, IcebergTable, SnowflakeTable
from snowflake.sqlalchemy.custom_types import MAP, TEXT
from snowflake.sqlalchemy.custom_types import MAP, OBJECT_STRUCTURED, TEXT
from snowflake.sqlalchemy.exc import StructuredTypeNotSupportedInTableColumnsError


Expand All @@ -42,12 +42,15 @@ def test_create_table_structured_datatypes(
engine_testaccount, external_volume, base_location
):
metadata = MetaData()
table_name = "test_map0"
table_name = "test_structured0"
test_map = IcebergTable(
table_name,
metadata,
Column("id", Integer, primary_key=True),
Column("map_id", MAP(NUMBER(10, 0), TEXT())),
Column(
"object_col", OBJECT_STRUCTURED(key1=(TEXT(), False), key2=(NUMBER(), True))
),
external_volume=external_volume,
base_location=base_location,
)
Expand Down Expand Up @@ -91,12 +94,54 @@ def test_insert_map(engine_testaccount, external_volume, base_location, snapshot
test_map.drop(engine_testaccount)


@pytest.mark.requires_external_volume
def test_insert_structured_object(
    engine_testaccount, external_volume, base_location, snapshot
):
    """Insert one row with a structured OBJECT value and snapshot the result."""
    metadata = MetaData()
    table_name = "test_insert_structured_object"
    test_map = IcebergTable(
        table_name,
        metadata,
        Column("id", Integer, primary_key=True),
        Column(
            "object_col", OBJECT_STRUCTURED(key1=(TEXT(), False), key2=(NUMBER(), True))
        ),
        external_volume=external_volume,
        base_location=base_location,
    )
    metadata.create_all(engine_testaccount)

    try:
        with engine_testaccount.connect() as conn:
            slt = select(
                1,
                cast(
                    # The literal carries the field values only; the
                    # (type, not_null) pairs belong to the type spec below.
                    text("{'key1':'item1', 'key2': 15}"),
                    OBJECT_STRUCTURED(key1=(TEXT(), False), key2=(NUMBER(), True)),
                ),
            )
            # Target the columns this table defines ("object_col", not the
            # "map_id" column of the MAP test).
            ins = test_map.insert().from_select(["id", "object_col"], slt)
            conn.execute(ins)

            results = conn.execute(test_map.select())
            data = results.fetchmany()
            results.close()
            snapshot.assert_match(data)
    finally:
        # Drop the table even when the insert or the assertion fails.
        test_map.drop(engine_testaccount)


@pytest.mark.requires_external_volume
@pytest.mark.parametrize(
"structured_type",
[
MAP(NUMBER(10, 0), TEXT()),
MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT())),
(MAP(NUMBER(10, 0), TEXT()), MAP),
(MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT())), MAP),
(
OBJECT_STRUCTURED(key1=(TEXT(), False), key2=(NUMBER(), True)),
OBJECT_STRUCTURED,
),
],
)
def test_inspect_structured_data_types(
Expand All @@ -108,7 +153,7 @@ def test_inspect_structured_data_types(
table_name,
metadata,
Column("id", Integer, primary_key=True),
Column("map_id", structured_type),
Column("structured_type_col", structured_type[0]),
external_volume=external_volume,
base_location=base_location,
)
Expand All @@ -119,7 +164,7 @@ def test_inspect_structured_data_types(
columns = inspecter.get_columns(table_name)

assert isinstance(columns[0]["type"], NUMBER)
assert isinstance(columns[1]["type"], MAP)
assert isinstance(columns[1]["type"], structured_type[1])
assert columns == snapshot

finally:
Expand All @@ -132,6 +177,7 @@ def test_inspect_structured_data_types(
[
"MAP(NUMBER(10, 0), VARCHAR)",
"MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))",
"OBJECT(key1 VARCHAR, key2 NUMBER NOT NULL)",
],
)
def test_reflect_structured_data_types(
Expand All @@ -147,7 +193,7 @@ def test_reflect_structured_data_types(
create_table_sql = f"""
CREATE OR REPLACE ICEBERG TABLE {table_name} (
id number(38,0) primary key,
map_id {structured_type})
structured_type_col {structured_type})
CATALOG = 'SNOWFLAKE'
EXTERNAL_VOLUME = '{external_volume}'
BASE_LOCATION = '{base_location}';
Expand Down Expand Up @@ -224,6 +270,10 @@ def test_snowflake_tables_with_structured_types(sql_compiler):
metadata,
Column("Id", Integer, primary_key=True),
Column("name", MAP(NUMBER(10, 0), TEXT())),
Column(
"object_col",
OBJECT_STRUCTURED(key1=(TEXT(), False), key2=(NUMBER(), True)),
),
)
assert programming_error is not None

Expand Down
4 changes: 4 additions & 0 deletions tests/test_unit_structured_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def test_extract_parameters():
),
("MAP(DECIMAL(10, 0), VARIANT)", "MAP(DECIMAL(10, 0), VARIANT)"),
("OBJECT", "OBJECT"),
(
"OBJECT(a DECIMAL(10, 0) NOT NULL, b DECIMAL(10, 0), c VARCHAR NOT NULL)",
"OBJECT(a DECIMAL(10, 0) NOT NULL, b DECIMAL(10, 0), c VARCHAR NOT NULL)",
),
("ARRAY", "ARRAY"),
("GEOGRAPHY", "GEOGRAPHY"),
("GEOMETRY", "GEOMETRY"),
Expand Down

0 comments on commit a4dc82d

Please sign in to comment.