Skip to content

Commit

Permalink
Merge branch 'main' into NO-SNOW-update-readme-with-filing-cases
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jvasquezrojas authored Nov 21, 2024
2 parents 33424dd + 65754a4 commit 644bdb0
Show file tree
Hide file tree
Showing 24 changed files with 976 additions and 248 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1 +1 @@
* @snowflakedb/snowcli
* @snowflakedb/ORM
4 changes: 4 additions & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ Source code is also available at:
# Release Notes

- (Unreleased)
- Add support for partition by to copy into <location>

- v1.7.0(November 22, 2024)

- Add support for dynamic tables and required options
- Add support for hybrid tables
Expand All @@ -18,6 +21,7 @@ Source code is also available at:
- Add support for refresh_mode option in DynamicTable
- Add support for iceberg table with Snowflake Catalog
- Fix cluster by option to support explicit expressions
- Add support for MAP datatype

- v1.6.1(July 9, 2024)

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ line-length = 88
line-length = 88

[tool.pytest.ini_options]
addopts = "-m 'not feature_max_lob_size and not aws'"
addopts = "-m 'not feature_max_lob_size and not aws and not requires_external_volume'"
markers = [
# Optional dependency groups markers
"lambda: AWS lambda tests",
Expand All @@ -128,6 +128,7 @@ markers = [
# Other markers
"timeout: tests that need a timeout time",
"internal: tests that could but should only run on our internal CI",
"requires_external_volume: tests that needs a external volume to be executed",
"external: tests that could but should only run on our external CI",
"feature_max_lob_size: tests that could but should only run on our external CI",
]
2 changes: 2 additions & 0 deletions src/snowflake/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
FIXED,
GEOGRAPHY,
GEOMETRY,
MAP,
NUMBER,
OBJECT,
STRING,
Expand Down Expand Up @@ -119,6 +120,7 @@
"TINYINT",
"VARBINARY",
"VARIANT",
"MAP",
)

_custom_commands = (
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/sqlalchemy/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
APPLICATION_NAME = "SnowflakeSQLAlchemy"
SNOWFLAKE_SQLALCHEMY_VERSION = VERSION
DIALECT_NAME = "snowflake"
NOT_NULL = "NOT NULL"
33 changes: 27 additions & 6 deletions src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from sqlalchemy.schema import Sequence, Table
from sqlalchemy.sql import compiler, expression, functions
from sqlalchemy.sql.base import CompileState
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import BindParameter, quoted_name
from sqlalchemy.sql.expression import Executable
from sqlalchemy.sql.selectable import Lateral, SelectState

from snowflake.sqlalchemy._constants import DIALECT_NAME
Expand All @@ -27,6 +28,7 @@
ExternalStage,
)

from ._constants import NOT_NULL
from .exc import (
CustomOptionsAreOnlySupportedOnSnowflakeTables,
UnexpectedOptionTypeError,
Expand Down Expand Up @@ -562,9 +564,8 @@ def visit_copy_into(self, copy_into, **kw):
if isinstance(copy_into.into, Table)
else copy_into.into._compiler_dispatch(self, **kw)
)
from_ = None
if isinstance(copy_into.from_, Table):
from_ = copy_into.from_
from_ = copy_into.from_.name
# this is intended to catch AWSBucket and AzureContainer
elif (
isinstance(copy_into.from_, AWSBucket)
Expand All @@ -575,6 +576,21 @@ def visit_copy_into(self, copy_into, **kw):
# everything else (selects, etc.)
else:
from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})"

partition_by_value = None
if isinstance(copy_into.partition_by, (BindParameter, Executable)):
partition_by_value = copy_into.partition_by.compile(
compile_kwargs={"literal_binds": True}
)
elif copy_into.partition_by is not None:
partition_by_value = copy_into.partition_by

partition_by = (
f"PARTITION BY {partition_by_value}"
if partition_by_value is not None and partition_by_value != ""
else ""
)

credentials, encryption = "", ""
if isinstance(into, tuple):
into, credentials, encryption = into
Expand All @@ -585,8 +601,7 @@ def visit_copy_into(self, copy_into, **kw):
options_list.sort(key=operator.itemgetter(0))
options = (
(
" "
+ " ".join(
" ".join(
[
"{} = {}".format(
n,
Expand All @@ -607,7 +622,7 @@ def visit_copy_into(self, copy_into, **kw):
options += f" {credentials}"
if encryption:
options += f" {encryption}"
return f"COPY INTO {into} FROM {from_} {formatter}{options}"
return f"COPY INTO {into} FROM {' '.join([from_, partition_by, formatter, options])}"

def visit_copy_formatter(self, formatter, **kw):
options_list = list(formatter.options.items())
Expand Down Expand Up @@ -1071,6 +1086,12 @@ def visit_TINYINT(self, type_, **kw):
def visit_VARIANT(self, type_, **kw):
return "VARIANT"

def visit_MAP(self, type_, **kw):
    """Render a MAP type as ``MAP(<key type>, <value type>[ NOT NULL])``.

    The key and value types are rendered through their own ``compile()``
    method; the ``NOT NULL`` marker is appended right after the value type
    only when ``type_.not_null`` is truthy.
    """
    null_suffix = f" {NOT_NULL}" if type_.not_null else ""
    key_sql = type_.key_type.compile()
    value_sql = type_.value_type.compile()
    return f"MAP({key_sql}, {value_sql}{null_suffix})"

def visit_ARRAY(self, type_, **kw):
return "ARRAY"

Expand Down
9 changes: 7 additions & 2 deletions src/snowflake/sqlalchemy/custom_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,23 @@ class CopyInto(UpdateBase):
__visit_name__ = "copy_into"
_bind = None

def __init__(self, from_, into, formatter=None):
def __init__(self, from_, into, partition_by=None, formatter=None):
self.from_ = from_
self.into = into
self.formatter = formatter
self.copy_options = {}
self.partition_by = partition_by

def __repr__(self):
"""
repr for debugging / logging purposes only. For compilation logic, see
the corresponding visitor in base.py
"""
return f"COPY INTO {self.into} FROM {repr(self.from_)} {repr(self.formatter)} ({self.copy_options})"
val = f"COPY INTO {self.into} FROM {repr(self.from_)}"
if self.partition_by is not None:
val += f" PARTITION BY {self.partition_by}"

return val + f" {repr(self.formatter)} ({self.copy_options})"

def bind(self):
return None
Expand Down
20 changes: 20 additions & 0 deletions src/snowflake/sqlalchemy/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@ class VARIANT(SnowflakeType):
__visit_name__ = "VARIANT"


class StructuredType(SnowflakeType):
    """Common base class for structured Snowflake types such as ``MAP``."""

    def __init__(self) -> None:
        """Initialise the type by delegating to the parent class."""
        super().__init__()


class MAP(StructuredType):
    """Structured MAP type described by a key type and a value type."""

    __visit_name__ = "MAP"

    def __init__(
        self,
        key_type: sqltypes.TypeEngine,
        value_type: sqltypes.TypeEngine,
        not_null: bool = False,
    ) -> None:
        """Store the MAP definition.

        :param key_type: Type of the map keys.
        :param value_type: Type of the map values.
        :param not_null: When true, the compiled type carries a ``NOT NULL``
            marker after the value type.
        """
        self.key_type = key_type
        self.value_type = value_type
        self.not_null = not_null
        super().__init__()


class OBJECT(SnowflakeType):
__visit_name__ = "OBJECT"

Expand Down
8 changes: 8 additions & 0 deletions src/snowflake/sqlalchemy/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,11 @@ def __init__(self, errors):

def __str__(self):
return "".join(str(e) for e in self.errors)


class StructuredTypeNotSupportedInTableColumnsError(ArgumentError):
    """Raised for a structured-type column declared on a non-Iceberg table."""

    def __init__(self, table_type: str, table_name: str, column_name: str):
        """Build the error message.

        :param table_type: Kind of the table that declares the column.
        :param table_name: Name of the table that declares the column.
        :param column_name: Name of the offending structured-type column.
        """
        message = (
            f"Column '{column_name}' is of a structured type, which is only supported on Iceberg tables. "
            f"The table '{table_name}' is of type '{table_type}', not Iceberg."
        )
        super().__init__(message)
190 changes: 190 additions & 0 deletions src/snowflake/sqlalchemy/parser/custom_type_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.

import sqlalchemy.types as sqltypes
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.types import (
BIGINT,
BINARY,
BOOLEAN,
CHAR,
DATE,
DATETIME,
DECIMAL,
FLOAT,
INTEGER,
REAL,
SMALLINT,
TIME,
TIMESTAMP,
VARCHAR,
NullType,
)

from ..custom_types import (
_CUSTOM_DECIMAL,
ARRAY,
DOUBLE,
GEOGRAPHY,
GEOMETRY,
MAP,
OBJECT,
TIMESTAMP_LTZ,
TIMESTAMP_NTZ,
TIMESTAMP_TZ,
VARIANT,
)

# Lookup table from an upper-case type name, as it appears in a type
# definition string, to the type class that parse_type() instantiates for it.
# Several names are aliases that resolve to the same class.
ischema_names = {
    "BIGINT": BIGINT,
    "BINARY": BINARY,
    # 'BIT': BIT,
    "BOOLEAN": BOOLEAN,
    "CHAR": CHAR,
    "CHARACTER": CHAR,
    "DATE": DATE,
    "DATETIME": DATETIME,
    "DEC": DECIMAL,
    "DECIMAL": DECIMAL,
    "DOUBLE": DOUBLE,
    "FIXED": DECIMAL,
    "FLOAT": FLOAT,  # The Snowflake FLOAT data type does not take parameters
    "INT": INTEGER,
    "INTEGER": INTEGER,
    "NUMBER": _CUSTOM_DECIMAL,
    # "OBJECT" is mapped further down in this table.
    "REAL": REAL,
    "BYTEINT": SMALLINT,
    "SMALLINT": SMALLINT,
    "STRING": VARCHAR,
    "TEXT": VARCHAR,
    "TIME": TIME,
    "TIMESTAMP": TIMESTAMP,
    "TIMESTAMP_TZ": TIMESTAMP_TZ,
    "TIMESTAMP_LTZ": TIMESTAMP_LTZ,
    "TIMESTAMP_NTZ": TIMESTAMP_NTZ,
    "TINYINT": SMALLINT,
    "VARBINARY": BINARY,
    "VARCHAR": VARCHAR,
    "VARIANT": VARIANT,
    "MAP": MAP,
    "OBJECT": OBJECT,
    "ARRAY": ARRAY,
    "GEOGRAPHY": GEOGRAPHY,
    "GEOMETRY": GEOMETRY,
}


def extract_parameters(text: str) -> list:
    """
    Split a comma-separated parameter string on its top-level commas.

    Commas that appear inside parentheses do not split, so nested type
    definitions stay in one piece. Each piece has its surrounding spaces
    removed.

    :param text: Comma-separated parameters, which may contain parentheses.
    :return: The parameters as a list of strings.
    :example:
        For input `"a, (b, c), d"`, the output is `['a', '(b, c)', 'd']`.
    """
    pieces = []
    depth = 0  # number of currently open parentheses
    start = 0  # index where the piece being scanned begins
    for position, char in enumerate(text):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
        elif char == "," and depth <= 0:
            # Top-level separator: close the current piece and start a new one.
            pieces.append(text[start:position].strip(" "))
            start = position + 1

    remainder = text[start:]
    # Nothing after the last separator means there is no final piece.
    if remainder != "":
        pieces.append(remainder.strip(" "))
    return pieces


def parse_type(type_text: str) -> TypeEngine:
    """
    Turn a type definition string into a type instance.

    The text before the first "(" is the type name, which is looked up in
    `ischema_names`. Anything between that "(" and the last character of the
    string is treated as the parameter list (e.g. "VARCHAR(255)" or
    "DECIMAL(10, 2)").

    :param type_text: Type definition, with or without parameters in parentheses.
    :return: An instance of the class registered for the type name, or a
        `NullType` instance when the name is not registered or when the
        parameters of a MAP type cannot be parsed.
    :example:
        parse_type("VARCHAR(255)") builds the VARCHAR class with length=255.
    """
    paren_at = type_text.find("(")
    if paren_at == -1:
        type_name = type_text
        parameters = []
    else:
        type_name = type_text[:paren_at]
        # The final character is assumed to be the matching ")" and is dropped.
        parameters = extract_parameters(type_text[paren_at + 1 : -1])

    type_class = ischema_names.get(type_name, None)
    if type_class is None:
        return NullType()

    if issubclass(type_class, sqltypes.Numeric):
        type_kwargs = __parse_numeric_type_parameters(parameters)
    elif issubclass(type_class, (sqltypes.String, sqltypes.BINARY)):
        type_kwargs = __parse_type_with_length_parameters(parameters)
    elif issubclass(type_class, MAP):
        type_kwargs = __parse_map_type_parameters(parameters)
    else:
        type_kwargs = {}

    # A helper returns None when the parameters cannot be interpreted.
    if type_kwargs is None:
        return NullType()
    return type_class(**type_kwargs)


def __parse_map_type_parameters(parameters):
    """
    Build the constructor keyword arguments for a MAP type.

    :param parameters: The two MAP parameters as strings: the key type and
        the value type, the latter optionally ending in "NOT NULL".
    :return: A dict with "key_type", "value_type" and "not_null", or None
        when there are not exactly two parameters or either type parses to
        `NullType`.
    """
    if len(parameters) != 2:
        return None

    key_text, value_text = parameters
    suffix = "NOT NULL"
    has_not_null = value_text.endswith(suffix)
    if has_not_null:
        # Drop the suffix together with the single character preceding it.
        value_text = value_text[: -len(suffix) - 1]

    parsed_key: TypeEngine = parse_type(key_text)
    parsed_value: TypeEngine = parse_type(value_text)
    if any(isinstance(parsed, NullType) for parsed in (parsed_key, parsed_value)):
        return None

    return {
        "key_type": parsed_key,
        "value_type": parsed_value,
        "not_null": has_not_null,
    }


def __parse_type_with_length_parameters(parameters):
    """
    Build the keyword arguments for a type that takes a single length.

    :param parameters: Parameter strings extracted from the type definition.
    :return: `{"length": <int>}` when there is exactly one parameter made up
        of decimal digits, otherwise an empty dict.
    """
    if len(parameters) != 1:
        return {}

    length_text = parameters[0]
    # str.isdecimal rather than str.isdigit: isdigit also accepts characters
    # such as superscript digits, which int() cannot convert.
    if not str.isdecimal(length_text):
        return {}
    return {"length": int(length_text)}


def __parse_numeric_type_parameters(parameters):
    """
    Build the keyword arguments for a numeric type.

    :param parameters: Parameter strings extracted from the type definition.
    :return: A dict that holds "precision" when the first parameter is made
        up of decimal digits, and "scale" when there are exactly two
        parameters and the second is made up of decimal digits. Parameters
        that do not qualify are left out.
    """
    result = {}
    # str.isdecimal rather than str.isdigit: isdigit also accepts characters
    # such as superscript digits, which int() cannot convert.
    if len(parameters) >= 1 and str.isdecimal(parameters[0]):
        result["precision"] = int(parameters[0])
    if len(parameters) == 2 and str.isdecimal(parameters[1]):
        result["scale"] = int(parameters[1])
    return result
Loading

0 comments on commit 644bdb0

Please sign in to comment.