diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 67b50ab0..205685f1 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,7 +9,10 @@ Source code is also available at: # Release Notes -- 1.6.2 +- (Unreleased) + + - Add support for dynamic tables and required options + - Fixed SAWarning when registering functions with existing name in default namespace - Fixed SAWarning when registering functions with existing name in default namespace - v1.6.1(July 9, 2024) diff --git a/pyproject.toml b/pyproject.toml index 99aacbee..4fe06a9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ development = [ "pytz", "numpy", "mock", + "syrupy==4.6.1", ] pandas = ["snowflake-connector-python[pandas]"] diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 9df6aaa2..30cd140c 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -61,6 +61,8 @@ VARBINARY, VARIANT, ) +from .sql.custom_schema import DynamicTable +from .sql.custom_schema.options import AsQuery, TargetLag, TimeUnit, Warehouse from .util import _url as URL base.dialect = dialect = snowdialect.dialect @@ -113,4 +115,9 @@ "ExternalStage", "CreateStage", "CreateFileFormat", + "DynamicTable", + "AsQuery", + "TargetLag", + "TimeUnit", + "Warehouse", ) diff --git a/src/snowflake/sqlalchemy/_constants.py b/src/snowflake/sqlalchemy/_constants.py index 46af4454..839745ee 100644 --- a/src/snowflake/sqlalchemy/_constants.py +++ b/src/snowflake/sqlalchemy/_constants.py @@ -10,3 +10,4 @@ APPLICATION_NAME = "SnowflakeSQLAlchemy" SNOWFLAKE_SQLALCHEMY_VERSION = VERSION +DIALECT_NAME = "snowflake" diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 3e504f7b..56631728 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -18,9 +18,16 @@ from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.selectable import Lateral, SelectState -from .compat import IS_VERSION_20, args_reducer, string_types -from .custom_commands import AWSBucket, AzureContainer, ExternalStage +from snowflake.sqlalchemy._constants import DIALECT_NAME +from snowflake.sqlalchemy.compat import IS_VERSION_20, args_reducer, string_types +from snowflake.sqlalchemy.custom_commands import ( + AWSBucket, + AzureContainer, + ExternalStage, +) + from .functions import flatten +from .sql.custom_schema.options.table_option_base import TableOptionBase from .util import ( _find_left_clause_to_join_from, _set_connection_interpolate_empty_sequences, @@ -878,7 +885,7 @@ def get_column_specification(self, column, **kwargs): return " ".join(colspec) - def post_create_table(self, table): + def handle_cluster_by(self, table): """ Handles snowflake-specific ``CREATE TABLE ... CLUSTER BY`` syntax. @@ -908,7 +915,7 @@ def post_create_table(self, table): """ text = "" - info = table.dialect_options["snowflake"] + info = table.dialect_options[DIALECT_NAME] cluster = info.get("clusterby") if cluster: text += " CLUSTER BY ({})".format( @@ -916,6 +923,21 @@ def post_create_table(self, table): ) return text + def post_create_table(self, table): + text = self.handle_cluster_by(table) + options = [ + option + for _, option in table.dialect_options[DIALECT_NAME].items() + if isinstance(option, TableOptionBase) + ] + options.sort( + key=lambda x: (x.__priority__.value, x.__option_name__), reverse=True + ) + for option in options: + text += "\t" + option.render_option(self) + + return text + def visit_create_stage(self, create_stage, **kw): """ This visitor will create the SQL representation for a CREATE STAGE command. diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 04305a00..b0472eb6 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -42,6 +42,7 @@ from snowflake.connector.constants import UTF8 from snowflake.sqlalchemy.compat import returns_unicode +from ._constants import DIALECT_NAME from .base import ( SnowflakeCompiler, SnowflakeDDLCompiler, @@ -119,7 +120,7 @@ class SnowflakeDialect(default.DefaultDialect): - name = "snowflake" + name = DIALECT_NAME driver = "snowflake" max_identifier_length = 255 cte_follows_insert = True diff --git a/src/snowflake/sqlalchemy/sql/__init__.py b/src/snowflake/sqlalchemy/sql/__init__.py new file mode 100644 index 00000000..ef416f64 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py new file mode 100644 index 00000000..4bbac246 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py @@ -0,0 +1,6 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from .dynamic_table import DynamicTable + +__all__ = ["DynamicTable"] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py new file mode 100644 index 00000000..0c04f33f --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import typing +from typing import Any + +from sqlalchemy.exc import ArgumentError +from sqlalchemy.sql.schema import MetaData, SchemaItem, Table + +from ..._constants import DIALECT_NAME +from ...compat import IS_VERSION_20 +from ...custom_commands import NoneType +from .options.table_option import TableOption + + +class CustomTableBase(Table): + __table_prefix__ = "" + _support_primary_and_foreign_keys = True + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if self.__table_prefix__ != "": + prefixes = kw.get("prefixes", []) + [self.__table_prefix__] + kw.update(prefixes=prefixes) + if not IS_VERSION_20 and hasattr(super(), "_init"): + super()._init(name, metadata, *args, **kw) + else: + super().__init__(name, metadata, *args, **kw) + + if not kw.get("autoload_with", False): + self._validate_table() + + def _validate_table(self): + if not self._support_primary_and_foreign_keys and ( + self.primary_key or self.foreign_keys + ): + raise ArgumentError( + f"Primary key and foreign keys are not supported in {self.__table_prefix__} TABLE." + ) + + return True + + def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]: + if option_name in self.dialect_options[DIALECT_NAME]: + return self.dialect_options[DIALECT_NAME][option_name] + return NoneType diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py new file mode 100644 index 00000000..7d0a02e6 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py @@ -0,0 +1,86 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import typing +from typing import Any + +from sqlalchemy.exc import ArgumentError +from sqlalchemy.sql.schema import MetaData, SchemaItem + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .options.target_lag import TargetLag +from .options.warehouse import Warehouse +from .table_from_query import TableFromQueryBase + + +class DynamicTable(TableFromQueryBase): + """ + A class representing a dynamic table with configurable options and settings. + + The `DynamicTable` class allows for the creation and querying of tables with + specific options, such as `Warehouse` and `TargetLag`. + + While it does not support reflection at this time, it provides a flexible + interface for creating dynamic tables and management. + + """ + + __table_prefix__ = "DYNAMIC" + + _support_primary_and_foreign_keys = False + + @property + def warehouse(self) -> typing.Optional[Warehouse]: + return self._get_dialect_option(Warehouse.__option_name__) + + @property + def target_lag(self) -> typing.Optional[TargetLag]: + return self._get_dialect_option(TargetLag.__option_name__) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + super().__init__(name, metadata, *args, **kw) + + def _init( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + super().__init__(name, metadata, *args, **kw) + + def _validate_table(self): + missing_attributes = [] + if self.target_lag is NoneType: + missing_attributes.append("TargetLag") + if self.warehouse is NoneType: + missing_attributes.append("Warehouse") + if self.as_query is NoneType: + missing_attributes.append("AsQuery") + if missing_attributes: + raise ArgumentError( + "DYNAMIC TABLE must have the following arguments: %s" + % ", ".join(missing_attributes) + ) + super()._validate_table() + + def __repr__(self) -> str: + return "DynamicTable(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + [repr(self.target_lag)] + + [repr(self.warehouse)] + + [repr(self.as_query)] + + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] + ) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py new file mode 100644 index 00000000..052e2d96 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py @@ -0,0 +1,9 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from .as_query import AsQuery +from .target_lag import TargetLag, TimeUnit +from .warehouse import Warehouse + +__all__ = ["Warehouse", "AsQuery", "TargetLag", "TimeUnit"] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py new file mode 100644 index 00000000..68076af9 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py @@ -0,0 +1,62 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Union + +from sqlalchemy.sql import Selectable + +from .table_option import TableOption +from .table_option_base import Priority + + +class AsQuery(TableOption): + """Class to represent an AS clause in tables. + This configuration option is used to specify the query from which the table is created. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas + + + AsQuery example usage using an input string: + DynamicTable( + "sometable", metadata, + Column("name", String(50)), + Column("address", String(100)), + AsQuery('select name, address from existing_table where name = "test"') + ) + + AsQuery example usage using a selectable statement: + DynamicTable( + "sometable", + Base.metadata, + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery(select(test_table_1).where(test_table_1.c.id == 23)) + ) + + """ + + __option_name__ = "as_query" + __priority__ = Priority.LOWEST + + def __init__(self, query: Union[str, Selectable]) -> None: + r"""Construct an as_query object. + + :param \*expressions: + AS + + """ + self.query = query + + @staticmethod + def template() -> str: + return "AS %s" + + def get_expression(self): + if isinstance(self.query, Selectable): + return self.query.compile(compile_kwargs={"literal_binds": True}) + return self.query + + def render_option(self, compiler) -> str: + return AsQuery.template() % (self.get_expression()) + + def __repr__(self) -> str: + return "Query(%s)" % self.get_expression() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py new file mode 100644 index 00000000..7ac27575 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py @@ -0,0 +1,26 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Any + +from sqlalchemy import exc +from sqlalchemy.sql.base import SchemaEventTarget +from sqlalchemy.sql.schema import SchemaItem, Table + +from snowflake.sqlalchemy._constants import DIALECT_NAME + +from .table_option_base import TableOptionBase + + +class TableOption(TableOptionBase, SchemaItem): + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + if self.__option_name__ == "default": + raise exc.SQLAlchemyError(f"{self.__class__.__name__} does not has a name") + if not isinstance(parent, Table): + raise exc.SQLAlchemyError( + f"{self.__class__.__name__} option can only be applied to Table" + ) + parent.dialect_options[DIALECT_NAME][self.__option_name__] = self + + def _set_table_option_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + pass diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py new file mode 100644 index 00000000..54008ec8 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py @@ -0,0 +1,30 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from enum import Enum + + +class Priority(Enum): + LOWEST = 0 + VERY_LOW = 1 + LOW = 2 + MEDIUM = 4 + HIGH = 6 + VERY_HIGH = 7 + HIGHEST = 8 + + +class TableOptionBase: + __option_name__ = "default" + __visit_name__ = __option_name__ + __priority__ = Priority.MEDIUM + + @staticmethod + def template() -> str: + raise NotImplementedError + + def get_expression(self): + raise NotImplementedError + + def render_option(self, compiler) -> str: + raise NotImplementedError diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py new file mode 100644 index 00000000..4331a4cb --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py @@ -0,0 +1,60 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# from enum import Enum +from enum import Enum +from typing import Optional + +from .table_option import TableOption +from .table_option_base import Priority + + +class TimeUnit(Enum): + SECONDS = "seconds" + MINUTES = "minutes" + HOURS = "hour" + DAYS = "days" + + +class TargetLag(TableOption): + """Class to represent the target lag clause. + This configuration option is used to specify the target lag time for the dynamic table. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table + + + Target lag example usage: + DynamicTable("sometable", metadata, + Column("name", String(50)), + Column("address", String(100)), + TargetLag(20, TimeUnit.MINUTES), + ) + """ + + __option_name__ = "target_lag" + __priority__ = Priority.HIGH + + def __init__( + self, + time: Optional[int] = 0, + unit: Optional[TimeUnit] = TimeUnit.MINUTES, + down_stream: Optional[bool] = False, + ) -> None: + self.time = time + self.unit = unit + self.down_stream = down_stream + + @staticmethod + def template() -> str: + return "TARGET_LAG = %s" + + def get_expression(self): + return ( + f"'{str(self.time)} {str(self.unit.value)}'" + if not self.down_stream + else "DOWNSTREAM" + ) + + def render_option(self, compiler) -> str: + return TargetLag.template() % (self.get_expression()) + + def __repr__(self) -> str: + return "TargetLag(%s)" % self.get_expression() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py new file mode 100644 index 00000000..a5b8cce0 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional + +from .table_option import TableOption +from .table_option_base import Priority + + +class Warehouse(TableOption): + """Class to represent the warehouse clause. + This configuration option is used to specify the warehouse for the dynamic table. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table + + + Warehouse example usage: + DynamicTable("sometable", metadata, + Column("name", String(50)), + Column("address", String(100)), + Warehouse('my_warehouse_name') + ) + """ + + __option_name__ = "warehouse" + __priority__ = Priority.HIGH + + def __init__( + self, + name: Optional[str], + ) -> None: + r"""Construct a Warehouse object. + + :param \*expressions: + Dynamic table warehouse option. + WAREHOUSE = + + """ + self.name = name + + @staticmethod + def template() -> str: + return "WAREHOUSE = %s" + + def get_expression(self): + return self.name + + def render_option(self, compiler) -> str: + return Warehouse.template() % (self.get_expression()) + + def __repr__(self) -> str: + return "Warehouse(%s)" % self.get_expression() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py new file mode 100644 index 00000000..60e8995f --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py @@ -0,0 +1,60 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import typing +from typing import Any, Optional + +from sqlalchemy.sql import Selectable +from sqlalchemy.sql.schema import Column, MetaData, SchemaItem +from sqlalchemy.util import NoneType + +from .custom_table_base import CustomTableBase +from .options.as_query import AsQuery + + +class TableFromQueryBase(CustomTableBase): + + @property + def as_query(self): + return self._get_dialect_option(AsQuery.__option_name__) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + items = [item for item in args] + as_query: AsQuery = self.__get_as_query_from_items(items) + if ( + as_query is not NoneType + and isinstance(as_query.query, Selectable) + and not self.__has_defined_columns(items) + ): + columns = self.__create_columns_from_selectable(as_query.query) + args = items + columns + super().__init__(name, metadata, *args, **kw) + + def __get_as_query_from_items( + self, items: typing.List[SchemaItem] + ) -> Optional[AsQuery]: + for item in items: + if isinstance(item, AsQuery): + return item + return NoneType + + def __has_defined_columns(self, items: typing.List[SchemaItem]) -> bool: + for item in items: + if isinstance(item, Column): + return True + + def __create_columns_from_selectable( + self, selectable: Selectable + ) -> Optional[typing.List[Column]]: + if not isinstance(selectable, Selectable): + return + columns: typing.List[Column] = [] + for _, c in selectable.exported_columns.items(): + columns += [Column(c.name, c.type)] + return columns diff --git a/tests/__snapshots__/test_compile_dynamic_table.ambr b/tests/__snapshots__/test_compile_dynamic_table.ambr new file mode 100644 index 00000000..81c7f90f --- /dev/null +++ b/tests/__snapshots__/test_compile_dynamic_table.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm + "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm_with_str_keys + "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_selectable + "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" +# --- diff --git a/tests/__snapshots__/test_reflect_dynamic_table.ambr b/tests/__snapshots__/test_reflect_dynamic_table.ambr new file mode 100644 index 00000000..d4cc22b5 --- /dev/null +++ b/tests/__snapshots__/test_reflect_dynamic_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- diff --git a/tests/test_compile_dynamic_table.py b/tests/test_compile_dynamic_table.py new file mode 100644 index 00000000..73ce54aa --- /dev/null +++ b/tests/test_compile_dynamic_table.py @@ -0,0 +1,177 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + Integer, + MetaData, + String, + Table, + exc, + select, +) +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import GEOMETRY, DynamicTable +from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( + TargetLag, + TimeUnit, +) +from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse + + +def test_compile_dynamic_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_without_required_args(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="DYNAMIC TABLE must have the following arguments: TargetLag, " + "Warehouse, AsQuery", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + ) + + +def test_compile_dynamic_table_with_primary_key(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + +def test_compile_dynamic_table_with_foreign_key(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer), + Column("geom", GEOMETRY), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ForeignKeyConstraint(["id"], ["table.id"]), + ) + + +def test_compile_dynamic_table_orm(sql_compiler, snapshot): + Base = declarative_base() + metadata = MetaData() + table_name = "test_dynamic_table_orm" + test_dynamic_table_orm = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("name", String), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + class TestDynamicTableOrm(Base): + __table__ = test_dynamic_table_orm + __mapper_args__ = { + "primary_key": [test_dynamic_table_orm.c.id, test_dynamic_table_orm.c.name] + } + + def __repr__(self): + return f"" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, snapshot): + Base = declarative_base() + + class TestDynamicTableOrm(Base): + __tablename__ = "test_dynamic_table_orm_2" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return DynamicTable(name, metadata, *arg, **kw) + + __table_args__ = ( + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + id = Column(Integer) + name = Column(String) + + __mapper_args__ = {"primary_key": [id, name]} + + def __repr__(self): + return f"" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_selectable(sql_compiler, snapshot): + Base = declarative_base() + + test_table_1 = Table( + "test_table_1", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + dynamic_test_table = DynamicTable( + "dynamic_test_table_1", + Base.metadata, + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery(select(test_table_1).where(test_table_1.c.id == 23)), + ) + + value = CreateTable(dynamic_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/test_create_dynamic_table.py b/tests/test_create_dynamic_table.py new file mode 100644 index 00000000..4e6c48ca --- /dev/null +++ b/tests/test_create_dynamic_table.py @@ -0,0 +1,93 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable +from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( + TargetLag, + TimeUnit, +) +from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse + + +def test_create_dynamic_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + dynamic_test_table_1 = DynamicTable( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + TargetLag(1, TimeUnit.HOURS), + Warehouse(warehouse), + AsQuery("SELECT id, name from test_table_1;"), + ) + + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_create_dynamic_table_without_dynamictable_class( + engine_testaccount, db_parameters +): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + dynamic_test_table_1 = Table( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + TargetLag(1, TimeUnit.HOURS), + Warehouse(warehouse), + AsQuery("SELECT id, name from test_table_1;"), + prefixes=["DYNAMIC"], + ) + + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_reflect_dynamic_table.py b/tests/test_reflect_dynamic_table.py new file mode 100644 index 00000000..8a4a8445 --- /dev/null +++ b/tests/test_reflect_dynamic_table.py @@ -0,0 +1,88 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable +from snowflake.sqlalchemy.custom_commands import NoneType + + +def test_simple_reflection_dynamic_table_as_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = Table( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_simple_reflection_without_options_loading(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = DynamicTable( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + # TODO: Add support for loading options when table is reflected + assert dynamic_test_table.warehouse is NoneType + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount)