Skip to content

Commit

Permalink
SNOW-1776332 Add support for OBJECT (#559)
Browse files Browse the repository at this point in the history
* SNOW-1776332 Add support for OBJECT

* Updated description.md

* Add missing @pytest.mark.requires_external_volume

* Tuple validation in OBJECT class
  • Loading branch information
sfc-gh-gvenegascastro authored Dec 16, 2024
1 parent 6c43ada commit af5457a
Show file tree
Hide file tree
Showing 9 changed files with 489 additions and 119 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Source code is also available at:
- Fix quoting of `_` as column name
- Fix index columns not being reflected
- Fix index reflection cache not working
- Add support for structured OBJECT datatype

- v1.7.1(December 02, 2024)
- Add support for partition by to copy into <location>
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ path = "src/snowflake/sqlalchemy/version.py"
development = [
"pre-commit",
"pytest",
"setuptools",
"pytest-cov",
"pytest-timeout",
"pytest-rerunfailures",
Expand Down Expand Up @@ -74,6 +75,8 @@ exclude = ["/.github"]
packages = ["src/snowflake"]

[tool.hatch.envs.default]
path = ".venv"
type = "virtual"
extra-dependencies = ["SQLAlchemy>=1.4.19,<2.1.0"]
features = ["development", "pandas"]
python = "3.8"
Expand Down
13 changes: 12 additions & 1 deletion src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,18 @@ def visit_ARRAY(self, type_, **kw):
return "ARRAY"

def visit_OBJECT(self, type_, **kw):
    """Render the SQL text for an ``OBJECT`` column type.

    A type whose ``is_semi_structured`` flag is set renders as plain
    ``OBJECT``. Otherwise every entry of ``type_.items_types`` is rendered
    as ``<key> <compiled type>``, with ``NOT NULL`` appended when the
    entry carries a truthy second element, and the entries are wrapped in
    ``OBJECT(...)``.

    :param type_: object exposing ``is_semi_structured`` and
        ``items_types`` (a mapping of key to a tuple whose first element
        provides ``compile()``).
    :param kw: accepted for signature compatibility; not used here.
    :return: the SQL text of the type.
    """
    if type_.is_semi_structured:
        return "OBJECT"

    contents = []
    for key, item in type_.items_types.items():
        # NOTE(review): compile() is called without arguments, exactly as
        # before -- confirm the resulting dialect is the intended one.
        row_text = f"{key} {item[0].compile()}"
        # The nullability flag is optional: only look at it when present.
        if len(item) > 1 and item[1]:
            row_text += " NOT NULL"
        contents.append(row_text)

    # A structured type without any field still falls back to bare OBJECT.
    return f"OBJECT({', '.join(contents)})" if contents else "OBJECT"

def visit_BLOB(self, type_, **kw):
return "BINARY"
Expand Down
22 changes: 21 additions & 1 deletion src/snowflake/sqlalchemy/custom_types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
from typing import Tuple, Union

import sqlalchemy.types as sqltypes
import sqlalchemy.util as util
from sqlalchemy.types import TypeEngine

TEXT = sqltypes.VARCHAR
CHARACTER = sqltypes.CHAR
Expand Down Expand Up @@ -57,9 +59,27 @@ def __init__(
super().__init__()


class OBJECT(StructuredType):
    """Snowflake ``OBJECT`` type, optionally carrying a field schema.

    Fields are passed as keyword arguments. Each value is either a type or
    a ``(type, not_null)`` tuple; a bare type is stored as
    ``(type, False)``. When no field is given the instance is flagged as
    semi-structured through ``is_semi_structured``.
    """

    __visit_name__ = "OBJECT"

    def __init__(self, **items_types: Union[TypeEngine, Tuple[TypeEngine, bool]]):
        # Normalise every entry to a tuple so consumers can index it
        # uniformly; tuples supplied by the caller are kept untouched.
        self.items_types = {
            key: value if isinstance(value, tuple) else (value, False)
            for key, value in items_types.items()
        }
        self.is_semi_structured = len(self.items_types) == 0
        super().__init__()

    def __repr__(self):
        # Keys are shown without the surrounding single quotes that
        # repr() adds, giving e.g. ``OBJECT(key1=(VARCHAR(), False))``.
        quote_char = "'"
        fields = [
            f"{repr(key).strip(quote_char)}={repr(value)}"
            for key, value in self.items_types.items()
        ]
        return "OBJECT(%s)" % ", ".join(fields)


class ARRAY(SnowflakeType):
__visit_name__ = "ARRAY"
Expand Down
37 changes: 30 additions & 7 deletions src/snowflake/sqlalchemy/parser/custom_type_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@
"DECIMAL": DECIMAL,
"DOUBLE": DOUBLE,
"FIXED": DECIMAL,
"FLOAT": FLOAT, # Snowflake FLOAT datatype doesn't has parameters
"FLOAT": FLOAT, # Snowflake FLOAT datatype doesn't have parameters
"INT": INTEGER,
"INTEGER": INTEGER,
"NUMBER": _CUSTOM_DECIMAL,
# 'OBJECT': ?
"REAL": REAL,
"BYTEINT": SMALLINT,
"SMALLINT": SMALLINT,
Expand All @@ -76,18 +75,19 @@
}


def extract_parameters(text: str) -> list:
def tokenize_parameters(text: str, character_for_strip=",") -> list:
"""
Extracts parameters from a comma-separated string, handling parentheses.
:param text: A string with comma-separated parameters, which may include parentheses.
:param character_for_strip: A character to strip the text.
:return: A list of parameters as strings.
:example:
For input `"a, (b, c), d"`, the output is `['a', '(b, c)', 'd']`.
"""

output_parameters = []
parameter = ""
open_parenthesis = 0
Expand All @@ -98,9 +98,9 @@ def extract_parameters(text: str) -> list:
elif c == ")":
open_parenthesis -= 1

if open_parenthesis > 0 or c != ",":
if open_parenthesis > 0 or c != character_for_strip:
parameter += c
elif c == ",":
elif c == character_for_strip:
output_parameters.append(parameter.strip(" "))
parameter = ""
if parameter != "":
Expand Down Expand Up @@ -138,14 +138,17 @@ def parse_type(type_text: str) -> TypeEngine:
parse_type("VARCHAR(255)")
String(length=255)
"""

index = type_text.find("(")
type_name = type_text[:index] if index != -1 else type_text

parameters = (
extract_parameters(type_text[index + 1 : -1]) if type_name != type_text else []
tokenize_parameters(type_text[index + 1 : -1]) if type_name != type_text else []
)

col_type_class = ischema_names.get(type_name, None)
col_type_kw = {}

if col_type_class is None:
col_type_class = NullType
else:
Expand All @@ -155,13 +158,33 @@ def parse_type(type_text: str) -> TypeEngine:
col_type_kw = __parse_type_with_length_parameters(parameters)
elif issubclass(col_type_class, MAP):
col_type_kw = __parse_map_type_parameters(parameters)
elif issubclass(col_type_class, OBJECT):
col_type_kw = __parse_object_type_parameters(parameters)
if col_type_kw is None:
col_type_class = NullType
col_type_kw = {}

return col_type_class(**col_type_kw)


def __parse_object_type_parameters(parameters):
    """Build the keyword arguments for ``OBJECT`` from its field declarations.

    Each parameter is split on spaces into ``<key> <type> [NOT NULL]``.
    The type text is handed to ``parse_type``; the field is marked not
    null only when the declaration is exactly ``<key> <type> NOT NULL``.
    Declarations with fewer than two parts are ignored.

    :param parameters: list of field declaration strings.
    :return: dict mapping each key to ``(type, not_null)``, or ``None`` as
        soon as one field type resolves to ``NullType``.
    """
    object_rows = {}
    for parameter in parameters:
        # Drop empty tokens so that repeated spaces cannot shift the
        # positional key / type / NOT NULL checks below.
        parameter_parts = [
            part for part in tokenize_parameters(parameter, " ") if part
        ]
        if len(parameter_parts) < 2:
            continue

        key = parameter_parts[0]
        value_type = parse_type(parameter_parts[1])
        if isinstance(value_type, NullType):
            # An unknown field type invalidates the whole OBJECT definition.
            return None

        # Equivalent to "exactly four parts, the last two being NOT NULL".
        not_null = parameter_parts[2:] == ["NOT", "NULL"]
        object_rows[key] = (value_type, not_null)
    return object_rows


def __parse_map_type_parameters(parameters):
if len(parameters) != 2:
return None
Expand Down
21 changes: 16 additions & 5 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import operator
import re
from collections import defaultdict
from functools import reduce
from typing import Any, Collection, Optional
from urllib.parse import unquote_plus

import sqlalchemy.types as sqltypes
import sqlalchemy.sql.sqltypes as sqltypes
from sqlalchemy import event as sa_vnt
from sqlalchemy import exc as sa_exc
from sqlalchemy import util as sa_util
from sqlalchemy.engine import URL, default, reflection
from sqlalchemy.schema import Table
from sqlalchemy.sql import text
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.types import FLOAT, Date, DateTime, Float, NullType, Time
from sqlalchemy.sql.sqltypes import NullType
from sqlalchemy.types import FLOAT, Date, DateTime, Float, Time

from snowflake.connector import errors as sf_errors
from snowflake.connector.connection import DEFAULT_CONFIGURATION
Expand All @@ -33,7 +33,7 @@
SnowflakeTypeCompiler,
)
from .custom_types import (
MAP,
StructuredType,
_CUSTOM_Date,
_CUSTOM_DateTime,
_CUSTOM_Float,
Expand Down Expand Up @@ -466,6 +466,14 @@ def _get_schema_columns(self, connection, schema, **kw):
connection, full_schema_name, **kw
)
schema_name = self.denormalize_name(schema)

iceberg_table_names = self.get_table_names_with_prefix(
connection,
schema=schema_name,
prefix=CustomTablePrefix.ICEBERG.name,
info_cache=kw.get("info_cache", None),
)

result = connection.execute(
text(
"""
Expand Down Expand Up @@ -526,7 +534,10 @@ def _get_schema_columns(self, connection, schema, **kw):
col_type_kw["scale"] = numeric_scale
elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)):
col_type_kw["length"] = character_maximum_length
elif issubclass(col_type, MAP):
elif (
issubclass(col_type, StructuredType)
and table_name in iceberg_table_names
):
if (schema_name, table_name) not in full_columns_descriptions:
full_columns_descriptions[(schema_name, table_name)] = (
self.table_columns_as_dict(
Expand Down
111 changes: 109 additions & 2 deletions tests/__snapshots__/test_structured_datatypes.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@
# name: test_compile_table_with_double_map
'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, MAP(DECIMAL, VARCHAR)), \tPRIMARY KEY ("Id"))'
# ---
# name: test_compile_table_with_structured_data_type[structured_type0]
'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tPRIMARY KEY ("Id"))'
# ---
# name: test_compile_table_with_structured_data_type[structured_type1]
'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))'
# ---
# name: test_compile_table_with_structured_data_type[structured_type2]
'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))'
# ---
# name: test_insert_map
list([
(1, '{\n "100": "item1",\n "200": "item2"\n}'),
Expand All @@ -16,6 +25,43 @@
Invalid expression [CAST(OBJECT_CONSTRUCT('100', 'item1', '200', 'item2') AS MAP(NUMBER(10,0), VARCHAR(16777216)))] in VALUES clause
'''
# ---
# name: test_insert_structured_object
list([
(1, '{\n "key1": "item1",\n "key2": 15\n}'),
])
# ---
# name: test_insert_structured_object_orm
'''
002014 (22000): SQL compilation error:
Invalid expression [CAST(OBJECT_CONSTRUCT('key1', 1, 'key2', 'item1') AS OBJECT(key1 NUMBER(10,0), key2 VARCHAR(16777216)))] in VALUES clause
'''
# ---
# name: test_inspect_structured_data_types[structured_type0-MAP]
list([
dict({
'autoincrement': True,
'comment': None,
'default': None,
'identity': dict({
'increment': 1,
'start': 1,
}),
'name': 'id',
'nullable': False,
'primary_key': True,
'type': _CUSTOM_DECIMAL(precision=10, scale=0),
}),
dict({
'autoincrement': False,
'comment': None,
'default': None,
'name': 'structured_type_col',
'nullable': True,
'primary_key': False,
'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216)),
}),
])
# ---
# name: test_inspect_structured_data_types[structured_type0]
list([
dict({
Expand All @@ -42,6 +88,32 @@
}),
])
# ---
# name: test_inspect_structured_data_types[structured_type1-MAP]
list([
dict({
'autoincrement': True,
'comment': None,
'default': None,
'identity': dict({
'increment': 1,
'start': 1,
}),
'name': 'id',
'nullable': False,
'primary_key': True,
'type': _CUSTOM_DECIMAL(precision=10, scale=0),
}),
dict({
'autoincrement': False,
'comment': None,
'default': None,
'name': 'structured_type_col',
'nullable': True,
'primary_key': False,
'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216))),
}),
])
# ---
# name: test_inspect_structured_data_types[structured_type1]
list([
dict({
Expand All @@ -68,11 +140,40 @@
}),
])
# ---
# name: test_inspect_structured_data_types[structured_type2-OBJECT]
list([
dict({
'autoincrement': True,
'comment': None,
'default': None,
'identity': dict({
'increment': 1,
'start': 1,
}),
'name': 'id',
'nullable': False,
'primary_key': True,
'type': _CUSTOM_DECIMAL(precision=10, scale=0),
}),
dict({
'autoincrement': False,
'comment': None,
'default': None,
'name': 'structured_type_col',
'nullable': True,
'primary_key': False,
'type': OBJECT(key1=(VARCHAR(length=16777216), False), key2=(_CUSTOM_DECIMAL(precision=10, scale=0), False)),
}),
])
# ---
# name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))]
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
# ---
# name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), VARCHAR)]
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), VARCHAR(16777216)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), VARCHAR(16777216)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
# ---
# name: test_reflect_structured_data_types[OBJECT(key1 VARCHAR, key2 NUMBER(10, 0))]
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
# ---
# name: test_select_map_orm
list([
Expand All @@ -88,3 +189,9 @@
list([
])
# ---
# name: test_select_structured_object_orm
list([
(1, '{\n "key1": "value2",\n "key2": 2\n}'),
(2, '{\n "key1": "value1",\n "key2": 1\n}'),
])
# ---
Loading

0 comments on commit af5457a

Please sign in to comment.