Skip to content

Commit

Permalink
Add support for Structured ARRAY
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jvasquezrojas committed Dec 16, 2024
1 parent af5457a commit 5ed8555
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 32 deletions.
73 changes: 73 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,79 @@ data_object = json.loads(row[1])
data_array = json.loads(row[2])
```

### Structured Data Types Support

This module defines custom SQLAlchemy types for Snowflake structured data, specifically for **Iceberg tables**.
The types —**MAP**, **OBJECT**, and **ARRAY**— allow you to store complex data structures in your SQLAlchemy models.
For detailed information, refer to the Snowflake [Structured data types](https://docs.snowflake.com/en/sql-reference/data-types-structured) documentation.

---

#### MAP

The `MAP` type represents a collection of key-value pairs, where each key and value can have different types.

- **Key Type**: The type of the keys (e.g., `TEXT`, `NUMBER`).
- **Value Type**: The type of the values (e.g., `TEXT`, `NUMBER`).
- **Not Null**: Whether `NULL` values are allowed (default is `False`).

*Example Usage*

```python
IcebergTable(
table_name,
metadata,
Column("id", Integer, primary_key=True),
Column("map_col", MAP(NUMBER(10, 0), TEXT(16777216))),
external_volume="external_volume",
base_location="base_location",
)
```

#### OBJECT

The `OBJECT` type represents a semi-structured object with named fields. Each field can have a specific type, and you can also specify whether each field is nullable.

- **Items Types**: A dictionary of field names and their types. The type can optionally include a nullable flag (`True` for not nullable, `False` for nullable, default is `False`).

*Example Usage*

```python
IcebergTable(
table_name,
metadata,
Column("id", Integer, primary_key=True),
Column(
"object_col",
OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)),
OBJECT(key1=TEXT(16777216), key2=NUMBER(10, 0)), # Without nullable flag
),
external_volume="external_volume",
base_location="base_location",
)
```

#### ARRAY

The `ARRAY` type represents an ordered list of values, where each element has the same type. The type of the elements is defined when creating the array.

- **Value Type**: The type of the elements in the array (e.g., `TEXT`, `NUMBER`).
- **Not Null**: Whether `NULL` values are allowed (default is `False`).

*Example Usage*

```python
IcebergTable(
table_name,
metadata,
Column("id", Integer, primary_key=True),
Column("array_col", ARRAY(TEXT(16777216))),
external_volume="external_volume",
base_location="base_location",
)
```


### CLUSTER BY Support

Snowflake SQLAchemy supports the `CLUSTER BY` parameter for tables. For information about the parameter, see :doc:`/sql-reference/sql/create-table`.
Expand Down
5 changes: 4 additions & 1 deletion src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,10 @@ def visit_MAP(self, type_, **kw):
)

def visit_ARRAY(self, type_, **kw):
return "ARRAY"
if type_.is_semi_structured:
return "ARRAY"
not_null = f" {NOT_NULL}" if type_.not_null else ""
return f"ARRAY({type_.value_type.compile()}{not_null})"

def visit_OBJECT(self, type_, **kw):
if type_.is_semi_structured:
Expand Down
16 changes: 13 additions & 3 deletions src/snowflake/sqlalchemy/custom_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
from typing import Tuple, Union
from typing import Optional, Tuple, Union

import sqlalchemy.types as sqltypes
import sqlalchemy.util as util
Expand Down Expand Up @@ -40,7 +40,8 @@ class VARIANT(SnowflakeType):


class StructuredType(SnowflakeType):
def __init__(self):
def __init__(self, is_semi_structured: bool = False):
self.is_semi_structured = is_semi_structured
super().__init__()


Expand Down Expand Up @@ -81,9 +82,18 @@ def __repr__(self):
)


class ARRAY(SnowflakeType):
class ARRAY(StructuredType):
__visit_name__ = "ARRAY"

def __init__(
self,
value_type: Optional[sqltypes.TypeEngine] = None,
not_null: bool = False,
):
self.value_type = value_type
self.not_null = not_null
super().__init__(is_semi_structured=value_type is None)


class TIMESTAMP_TZ(SnowflakeType):
__visit_name__ = "TIMESTAMP_TZ"
Expand Down
52 changes: 34 additions & 18 deletions src/snowflake/sqlalchemy/parser/custom_type_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
"GEOMETRY": GEOMETRY,
}

NOT_NULL_STR = "NOT NULL"


def tokenize_parameters(text: str, character_for_strip=",") -> list:
"""
Expand Down Expand Up @@ -160,6 +162,8 @@ def parse_type(type_text: str) -> TypeEngine:
col_type_kw = __parse_map_type_parameters(parameters)
elif issubclass(col_type_class, OBJECT):
col_type_kw = __parse_object_type_parameters(parameters)
elif issubclass(col_type_class, ARRAY):
col_type_kw = __parse_nullable_parameter(parameters)
if col_type_kw is None:
col_type_class = NullType
col_type_kw = {}
Expand All @@ -169,6 +173,7 @@ def parse_type(type_text: str) -> TypeEngine:

def __parse_object_type_parameters(parameters):
object_rows = {}
not_null_parts = NOT_NULL_STR.split(" ")
for parameter in parameters:
parameter_parts = tokenize_parameters(parameter, " ")
if len(parameter_parts) >= 2:
Expand All @@ -178,40 +183,51 @@ def __parse_object_type_parameters(parameters):
return None
not_null = (
len(parameter_parts) == 4
and parameter_parts[2] == "NOT"
and parameter_parts[3] == "NULL"
and parameter_parts[2] == not_null_parts[0]
and parameter_parts[3] == not_null_parts[1]
)
object_rows[key] = (value_type, not_null)
return object_rows


def __parse_map_type_parameters(parameters):
if len(parameters) != 2:
def __parse_nullable_parameter(parameters):
if len(parameters) < 1:
return {}
elif len(parameters) > 1:
return None

key_type_str = parameters[0]
value_type_str = parameters[1]
not_null_str = "NOT NULL"
not_null = False
parameter_str = parameters[0]
is_not_null = False
if (
len(value_type_str) >= len(not_null_str)
and value_type_str[-len(not_null_str) :] == not_null_str
len(parameter_str) >= len(NOT_NULL_STR)
and parameter_str[-len(NOT_NULL_STR) :] == NOT_NULL_STR
):
not_null = True
value_type_str = value_type_str[: -len(not_null_str) - 1]
is_not_null = True
parameter_str = parameter_str[: -len(NOT_NULL_STR) - 1]

key_type: TypeEngine = parse_type(key_type_str)
value_type: TypeEngine = parse_type(value_type_str)
if isinstance(key_type, NullType) or isinstance(value_type, NullType):
value_type: TypeEngine = parse_type(parameter_str)
if isinstance(value_type, NullType):
return None

return {
"key_type": key_type,
"value_type": value_type,
"not_null": not_null,
"not_null": is_not_null,
}


def __parse_map_type_parameters(parameters):
if len(parameters) != 2:
return None

key_type_str = parameters[0]
value_type_str = parameters[1]
key_type: TypeEngine = parse_type(key_type_str)
value_type = __parse_nullable_parameter([value_type_str])
if isinstance(value_type, NullType) or isinstance(key_type, NullType):
return None

return {"key_type": key_type, **value_type}


def __parse_type_with_length_parameters(parameters):
return (
{"length": int(parameters[0])}
Expand Down
49 changes: 49 additions & 0 deletions tests/__snapshots__/test_structured_datatypes.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@
# name: test_compile_table_with_structured_data_type[structured_type1]
'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))'
# ---
# name: test_compile_table_with_structured_data_type[structured_type2]
'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname ARRAY(MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tPRIMARY KEY ("Id"))'
# ---
# name: test_insert_array
list([
(1, '[\n "item1",\n "item2"\n]'),
])
# ---
# name: test_insert_array_orm
'''
002014 (22000): SQL compilation error:
Invalid expression [CAST(ARRAY_CONSTRUCT('item1', 'item2') AS ARRAY(VARCHAR(16777216)))] in VALUES clause
'''
# ---
# name: test_compile_table_with_structured_data_type[structured_type2]
'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))'
# ---
Expand Down Expand Up @@ -166,6 +180,35 @@
}),
])
# ---
# name: test_inspect_structured_data_types[structured_type3-ARRAY]
list([
dict({
'autoincrement': True,
'comment': None,
'default': None,
'identity': dict({
'increment': 1,
'start': 1,
}),
'name': 'id',
'nullable': False,
'primary_key': True,
'type': _CUSTOM_DECIMAL(precision=10, scale=0),
}),
dict({
'autoincrement': False,
'comment': None,
'default': None,
'name': 'structured_type_col',
'nullable': True,
'primary_key': False,
'type': ARRAY(value_type=VARCHAR(length=16777216)),
}),
])
# ---
# name: test_reflect_structured_data_types[ARRAY(MAP(NUMBER(10, 0), VARCHAR))]
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col ARRAY(MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
# ---
# name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))]
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
# ---
Expand All @@ -175,6 +218,12 @@
# name: test_reflect_structured_data_types[OBJECT(key1 VARCHAR, key2 NUMBER(10, 0))]
"CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'"
# ---
# name: test_select_array_orm
list([
(1, '[\n "item3",\n "item4"\n]'),
(2, '[\n "item1",\n "item2"\n]'),
])
# ---
# name: test_select_map_orm
list([
(1, '{\n "100": "item1",\n "200": "item2"\n}'),
Expand Down
Loading

0 comments on commit 5ed8555

Please sign in to comment.