Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1776332 Add support for OBJECT #559

Merged
merged 5 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Source code is also available at:
- Fix quoting of `_` as column name
- Fix index columns not being reflected
- Fix index reflection cache not working
- Add support for structured OBJECT datatype

- v1.7.1(December 02, 2024)
- Add support for partition by to copy into <location>
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ path = "src/snowflake/sqlalchemy/version.py"
development = [
"pre-commit",
"pytest",
"setuptools",
"pytest-cov",
"pytest-timeout",
"pytest-rerunfailures",
Expand Down Expand Up @@ -74,6 +75,8 @@ exclude = ["/.github"]
packages = ["src/snowflake"]

[tool.hatch.envs.default]
path = ".venv"
type = "virtual"
extra-dependencies = ["SQLAlchemy>=1.4.19,<2.1.0"]
features = ["development", "pandas"]
python = "3.8"
Expand Down
13 changes: 12 additions & 1 deletion src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,18 @@ def visit_ARRAY(self, type_, **kw):
return "ARRAY"

def visit_OBJECT(self, type_, **kw):
    """Render the DDL text for an OBJECT column type.

    :param type_: The OBJECT type instance. It must expose
        ``is_semi_structured`` and ``items_types``, a mapping from field name
        to a tuple whose first element is the field type and whose optional
        second element is the NOT NULL flag.
    :param kw: Extra compiler keyword arguments (accepted but not used here).

    :return: ``"OBJECT"`` for a semi-structured object or one without fields,
        otherwise ``"OBJECT(<name> <type>[ NOT NULL], ...)"``.
    """
    if type_.is_semi_structured:
        return "OBJECT"

    contents = []
    for key, spec in type_.items_types.items():
        row_text = f"{key} {spec[0].compile()}"
        # The NOT NULL flag is optional: only honour it when the tuple
        # actually carries a second element.
        if len(spec) > 1 and spec[1]:
            row_text += " NOT NULL"
        contents.append(row_text)

    return f"OBJECT({', '.join(contents)})" if contents else "OBJECT"

def visit_BLOB(self, type_, **kw):
return "BINARY"
Expand Down
22 changes: 21 additions & 1 deletion src/snowflake/sqlalchemy/custom_types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
from typing import Tuple, Union

import sqlalchemy.types as sqltypes
import sqlalchemy.util as util
from sqlalchemy.types import TypeEngine

TEXT = sqltypes.VARCHAR
CHARACTER = sqltypes.CHAR
Expand Down Expand Up @@ -57,9 +59,27 @@ def __init__(
super().__init__()


class OBJECT(StructuredType):
    """Snowflake OBJECT type, either semi-structured or structured.

    Called without arguments it is a semi-structured OBJECT. Called with
    keyword arguments, each keyword is a field name and each value is either
    a type or a ``(type, not_null)`` tuple.
    """

    __visit_name__ = "OBJECT"

    def __init__(self, **items_types: Union[TypeEngine, Tuple[TypeEngine, bool]]):
        # Normalise bare types to ``(type, False)`` so every entry is a tuple.
        # A new dict is built instead of rewriting the kwargs mapping while
        # iterating over it.
        self.items_types = {
            key: value if isinstance(value, tuple) else (value, False)
            for key, value in items_types.items()
        }
        self.is_semi_structured = not self.items_types
        super().__init__()

    def __repr__(self):
        # Field names are shown unquoted, keyword-argument style.
        quote_char = "'"
        fields = [
            f"{repr(key).strip(quote_char)}={value!r}"
            for key, value in self.items_types.items()
        ]
        return "OBJECT(%s)" % ", ".join(fields)


class ARRAY(SnowflakeType):
__visit_name__ = "ARRAY"
Expand Down
37 changes: 30 additions & 7 deletions src/snowflake/sqlalchemy/parser/custom_type_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@
"DECIMAL": DECIMAL,
"DOUBLE": DOUBLE,
"FIXED": DECIMAL,
"FLOAT": FLOAT, # Snowflake FLOAT datatype doesn't has parameters
"FLOAT": FLOAT, # Snowflake FLOAT datatype doesn't have parameters
"INT": INTEGER,
"INTEGER": INTEGER,
"NUMBER": _CUSTOM_DECIMAL,
# 'OBJECT': ?
"REAL": REAL,
"BYTEINT": SMALLINT,
"SMALLINT": SMALLINT,
Expand All @@ -76,18 +75,19 @@
}


def extract_parameters(text: str) -> list:
def tokenize_parameters(text: str, character_for_strip=",") -> list:
"""
Extracts parameters from a comma-separated string, handling parentheses.

:param text: A string with comma-separated parameters, which may include parentheses.

:param character_for_strip: The delimiter character on which the text is split; occurrences inside parentheses are not treated as delimiters.

:return: A list of parameters as strings.

:example:
For input `"a, (b, c), d"`, the output is `['a', '(b, c)', 'd']`.
"""

output_parameters = []
parameter = ""
open_parenthesis = 0
Expand All @@ -98,9 +98,9 @@ def extract_parameters(text: str) -> list:
elif c == ")":
open_parenthesis -= 1

if open_parenthesis > 0 or c != ",":
if open_parenthesis > 0 or c != character_for_strip:
parameter += c
elif c == ",":
elif c == character_for_strip:
output_parameters.append(parameter.strip(" "))
parameter = ""
if parameter != "":
Expand Down Expand Up @@ -138,14 +138,17 @@ def parse_type(type_text: str) -> TypeEngine:
parse_type("VARCHAR(255)")
String(length=255)
"""

index = type_text.find("(")
type_name = type_text[:index] if index != -1 else type_text

parameters = (
extract_parameters(type_text[index + 1 : -1]) if type_name != type_text else []
tokenize_parameters(type_text[index + 1 : -1]) if type_name != type_text else []
)

col_type_class = ischema_names.get(type_name, None)
col_type_kw = {}

if col_type_class is None:
col_type_class = NullType
else:
Expand All @@ -155,13 +158,33 @@ def parse_type(type_text: str) -> TypeEngine:
col_type_kw = __parse_type_with_length_parameters(parameters)
elif issubclass(col_type_class, MAP):
col_type_kw = __parse_map_type_parameters(parameters)
elif issubclass(col_type_class, OBJECT):
col_type_kw = __parse_object_type_parameters(parameters)
if col_type_kw is None:
col_type_class = NullType
col_type_kw = {}

return col_type_class(**col_type_kw)


def __parse_object_type_parameters(parameters):
    """Build the keyword arguments for a structured OBJECT type.

    :param parameters: Field definitions as strings, each of the form
        ``"<name> <type>"`` optionally followed by ``"NOT NULL"``.

    :return: A dict mapping each field name to ``(parsed_type, not_null)``,
        or ``None`` as soon as one field type parses to ``NullType``.
        Definitions with fewer than two tokens are ignored.
    """
    object_rows = {}
    for parameter in parameters:
        # Splitting on spaces yields empty tokens when spaces are repeated;
        # drop them so "key  VARCHAR" parses the same as "key VARCHAR".
        parameter_parts = [
            part for part in tokenize_parameters(parameter, " ") if part
        ]
        if len(parameter_parts) < 2:
            # Not a "<name> <type>" pair: nothing to register.
            continue

        key = parameter_parts[0]
        value_type = parse_type(parameter_parts[1])
        if isinstance(value_type, NullType):
            return None

        # NOT NULL is recognised only as exactly two trailing tokens.
        not_null = parameter_parts[2:] == ["NOT", "NULL"]
        object_rows[key] = (value_type, not_null)
    return object_rows


def __parse_map_type_parameters(parameters):
if len(parameters) != 2:
return None
Expand Down
21 changes: 16 additions & 5 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import operator
import re
from collections import defaultdict
from functools import reduce
from typing import Any, Collection, Optional
from urllib.parse import unquote_plus

import sqlalchemy.types as sqltypes
import sqlalchemy.sql.sqltypes as sqltypes
from sqlalchemy import event as sa_vnt
from sqlalchemy import exc as sa_exc
from sqlalchemy import util as sa_util
from sqlalchemy.engine import URL, default, reflection
from sqlalchemy.schema import Table
from sqlalchemy.sql import text
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.types import FLOAT, Date, DateTime, Float, NullType, Time
from sqlalchemy.sql.sqltypes import NullType
from sqlalchemy.types import FLOAT, Date, DateTime, Float, Time

from snowflake.connector import errors as sf_errors
from snowflake.connector.connection import DEFAULT_CONFIGURATION
Expand All @@ -33,7 +33,7 @@
SnowflakeTypeCompiler,
)
from .custom_types import (
MAP,
StructuredType,
_CUSTOM_Date,
_CUSTOM_DateTime,
_CUSTOM_Float,
Expand Down Expand Up @@ -466,6 +466,14 @@ def _get_schema_columns(self, connection, schema, **kw):
connection, full_schema_name, **kw
)
schema_name = self.denormalize_name(schema)

iceberg_table_names = self.get_table_names_with_prefix(
connection,
schema=schema_name,
prefix=CustomTablePrefix.ICEBERG.name,
info_cache=kw.get("info_cache", None),
)

result = connection.execute(
text(
"""
Expand Down Expand Up @@ -526,7 +534,10 @@ def _get_schema_columns(self, connection, schema, **kw):
col_type_kw["scale"] = numeric_scale
elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)):
col_type_kw["length"] = character_maximum_length
elif issubclass(col_type, MAP):
elif (
issubclass(col_type, StructuredType)
and table_name in iceberg_table_names
):
if (schema_name, table_name) not in full_columns_descriptions:
full_columns_descriptions[(schema_name, table_name)] = (
self.table_columns_as_dict(
Expand Down
111 changes: 109 additions & 2 deletions tests/__snapshots__/test_structured_datatypes.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@
# name: test_compile_table_with_double_map
'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, MAP(DECIMAL, VARCHAR)), \tPRIMARY KEY ("Id"))'
# ---
# name: test_compile_table_with_structured_data_type[structured_type0]
'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tPRIMARY KEY ("Id"))'
# ---
# name: test_compile_table_with_structured_data_type[structured_type1]
'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))'
# ---
# name: test_compile_table_with_structured_data_type[structured_type2]
'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))'
# ---
# name: test_insert_map
list([
(1, '{\n "100": "item1",\n "200": "item2"\n}'),
Expand All @@ -16,6 +25,43 @@
Invalid expression [CAST(OBJECT_CONSTRUCT('100', 'item1', '200', 'item2') AS MAP(NUMBER(10,0), VARCHAR(16777216)))] in VALUES clause
'''
# ---
# name: test_insert_structured_object
list([
(1, '{\n "key1": "item1",\n "key2": 15\n}'),
])
# ---
# name: test_insert_structured_object_orm
'''
002014 (22000): SQL compilation error:
Invalid expression [CAST(OBJECT_CONSTRUCT('key1', 1, 'key2', 'item1') AS OBJECT(key1 NUMBER(10,0), key2 VARCHAR(16777216)))] in VALUES clause
'''
# ---
# name: test_inspect_structured_data_types[structured_type0-MAP]
list([
dict({
'autoincrement': True,
'comment': None,
'default': None,
'identity': dict({
'increment': 1,
'start': 1,
}),
'name': 'id',
'nullable': False,
'primary_key': True,
'type': _CUSTOM_DECIMAL(precision=10, scale=0),
}),
dict({
'autoincrement': False,
'comment': None,
'default': None,
'name': 'structured_type_col',
'nullable': True,
'primary_key': False,
'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216)),
}),
])
# ---
# name: test_inspect_structured_data_types[structured_type0]
list([
dict({
Expand All @@ -42,6 +88,32 @@
}),
])
# ---
# name: test_inspect_structured_data_types[structured_type1-MAP]
list([
dict({
'autoincrement': True,
'comment': None,
'default': None,
'identity': dict({
'increment': 1,
'start': 1,
}),
'name': 'id',
'nullable': False,
'primary_key': True,
'type': _CUSTOM_DECIMAL(precision=10, scale=0),
}),
dict({
'autoincrement': False,
'comment': None,
'default': None,
'name': 'structured_type_col',
'nullable': True,
'primary_key': False,
'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216))),
}),
])
# ---
# name: test_inspect_structured_data_types[structured_type1]
list([
dict({
Expand All @@ -68,11 +140,40 @@
}),
])
# ---
# name: test_inspect_structured_data_types[structured_type2-OBJECT]
list([
dict({
'autoincrement': True,
'comment': None,
'default': None,
'identity': dict({
'increment': 1,
'start': 1,
}),
'name': 'id',
'nullable': False,
'primary_key': True,
'type': _CUSTOM_DECIMAL(precision=10, scale=0),
}),
dict({
'autoincrement': False,
'comment': None,
'default': None,
'name': 'structured_type_col',
'nullable': True,
'primary_key': False,
'type': OBJECT(key1=(VARCHAR(length=16777216), False), key2=(_CUSTOM_DECIMAL(precision=10, scale=0), False)),
}),
])
# ---
# name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))]
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
# ---
# name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), VARCHAR)]
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), VARCHAR(16777216)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), VARCHAR(16777216)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
# ---
# name: test_reflect_structured_data_types[OBJECT(key1 VARCHAR, key2 NUMBER(10, 0))]
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
# ---
# name: test_select_map_orm
list([
Expand All @@ -88,3 +189,9 @@
list([
])
# ---
# name: test_select_structured_object_orm
list([
(1, '{\n "key1": "value2",\n "key2": 2\n}'),
(2, '{\n "key1": "value1",\n "key2": 1\n}'),
])
# ---
Loading
Loading