diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 0afd44a5..dedbab1e 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -62,7 +62,16 @@ VARIANT, ) from .sql.custom_schema import DynamicTable, HybridTable -from .sql.custom_schema.options import AsQuery, TargetLag, TimeUnit, Warehouse +from .sql.custom_schema.options import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + SnowflakeKeyword, + TableOptionKey, + TargetLagOption, + TimeUnit, +) from .util import _url as URL base.dialect = dialect = snowdialect.dialect @@ -70,6 +79,7 @@ __version__ = importlib_metadata.version("snowflake-sqlalchemy") __all__ = ( + # Custom Types "BIGINT", "BINARY", "BOOLEAN", @@ -104,6 +114,7 @@ "TINYINT", "VARBINARY", "VARIANT", + # Custom Commands "MergeInto", "CSVFormatter", "JSONFormatter", @@ -115,10 +126,17 @@ "ExternalStage", "CreateStage", "CreateFileFormat", + # Custom Tables + "HybridTable", "DynamicTable", - "AsQuery", - "TargetLag", + # Custom Schema Options + "AsQueryOption", + "TargetLagOption", + "LiteralOption", + "IdentifierOption", + "KeywordOption", + # Enums "TimeUnit", - "Warehouse", - "HybridTable", + "TableOptionKey", + "SnowflakeKeyword", ) diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 56631728..023f7afb 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -5,6 +5,7 @@ import itertools import operator import re +from typing import List from sqlalchemy import exc as sa_exc from sqlalchemy import inspect, sql @@ -26,8 +27,13 @@ ExternalStage, ) +from .exc import ( + CustomOptionsAreOnlySupportedOnSnowflakeTables, + UnexpectedOptionTypeError, +) from .functions import flatten -from .sql.custom_schema.options.table_option_base import TableOptionBase +from .sql.custom_schema.custom_table_base import CustomTableBase +from .sql.custom_schema.options.table_option import TableOption from .util import ( _find_left_clause_to_join_from, _set_connection_interpolate_empty_sequences, @@ -925,16 +931,24 @@ def handle_cluster_by(self, table): def post_create_table(self, table): text = self.handle_cluster_by(table) - options = [ - option - for _, option in table.dialect_options[DIALECT_NAME].items() - if isinstance(option, TableOptionBase) - ] - options.sort( - key=lambda x: (x.__priority__.value, x.__option_name__), reverse=True - ) - for option in options: - text += "\t" + option.render_option(self) + options = [] + invalid_options: List[str] = [] + + for key, option in table.dialect_options[DIALECT_NAME].items(): + if isinstance(option, TableOption): + options.append(option) + elif key not in ["clusterby", "*"]: + invalid_options.append(key) + + if len(invalid_options) > 0: + raise UnexpectedOptionTypeError(sorted(invalid_options)) + + if isinstance(table, CustomTableBase): + options.sort(key=lambda x: (x.priority.value, x.option_name), reverse=True) + for option in options: + text += "\t" + option.render_option(self) + elif len(options) > 0: + raise CustomOptionsAreOnlySupportedOnSnowflakeTables() return text diff --git a/src/snowflake/sqlalchemy/exc.py b/src/snowflake/sqlalchemy/exc.py new file mode 100644 index 00000000..d5f31786 --- /dev/null +++ b/src/snowflake/sqlalchemy/exc.py @@ -0,0 +1,74 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +from typing import List + +from sqlalchemy.exc import ArgumentError + + +class NoPrimaryKeyError(ArgumentError): + def __init__(self, target: str): + super().__init__(f"Table {target} required primary key.") + + +class UnsupportedPrimaryKeysAndForeignKeysError(ArgumentError): + def __init__(self, target: str): + super().__init__(f"Primary key and foreign keys are not supported in {target}.") + + +class RequiredParametersNotProvidedError(ArgumentError): + def __init__(self, target: str, parameters: List[str]): + super().__init__( + f"{target} requires the following parameters: %s." % ", ".join(parameters) + ) + + +class UnexpectedTableOptionKeyError(ArgumentError): + def __init__(self, expected: str, actual: str): + super().__init__(f"Expected table option {expected} but got {actual}.") + + +class OptionKeyNotProvidedError(ArgumentError): + def __init__(self, target: str): + super().__init__( + f"Expected option key in {target} option but got NoneType instead." + ) + + +class UnexpectedOptionParameterTypeError(ArgumentError): + def __init__(self, parameter_name: str, target: str, types: List[str]): + super().__init__( + f"Parameter {parameter_name} of {target} requires to be one" + f" of following types: {', '.join(types)}." + ) + + +class CustomOptionsAreOnlySupportedOnSnowflakeTables(ArgumentError): + def __init__(self): + super().__init__( + "Identifier, Literal, TargetLag and other custom options are only supported on Snowflake tables." + ) + + +class UnexpectedOptionTypeError(ArgumentError): + def __init__(self, options: List[str]): + super().__init__( + f"The following options are either unsupported or should be defined using a Snowflake table: {', '.join(options)}." + ) + + +class InvalidTableParameterTypeError(ArgumentError): + def __init__(self, name: str, input_type: str, expected_types: List[str]): + expected_types_str = "', '".join(expected_types) + super().__init__( + f"Invalid parameter type '{input_type}' provided for '{name}'. " + f"Expected one of the following types: '{expected_types_str}'.\n" + ) + + +class MultipleErrors(ArgumentError): + def __init__(self, errors): + self.errors = errors + + def __str__(self): + return "\n ".join(str(e) for e in self.errors) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py index b61c270d..b75dc7bf 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -2,21 +2,29 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # import typing -from typing import Any +from typing import Any, List -from sqlalchemy.exc import ArgumentError from sqlalchemy.sql.schema import MetaData, SchemaItem, Table from ..._constants import DIALECT_NAME from ...compat import IS_VERSION_20 from ...custom_commands import NoneType +from ...exc import ( + MultipleErrors, + NoPrimaryKeyError, + RequiredParametersNotProvidedError, + UnsupportedPrimaryKeysAndForeignKeysError, +) from .custom_table_prefix import CustomTablePrefix -from .options.table_option import TableOption +from .options.invalid_table_option import InvalidTableOption +from .options.table_option import TableOption, TableOptionKey class CustomTableBase(Table): __table_prefixes__: typing.List[CustomTablePrefix] = [] _support_primary_and_foreign_keys: bool = True + _enforce_primary_keys: bool = False + _required_parameters: List[TableOptionKey] = [] @property def table_prefixes(self) -> typing.List[str]: @@ -32,7 +40,9 @@ def __init__( if len(self.__table_prefixes__) > 0: prefixes = kw.get("prefixes", []) + self.table_prefixes kw.update(prefixes=prefixes) + if not IS_VERSION_20 and hasattr(super(), "_init"): + kw.pop("_no_init", True) super()._init(name, metadata, *args, **kw) else: super().__init__(name, metadata, *args, **kw) @@ -40,20 +50,64 @@ def __init__( if not kw.get("autoload_with", False): self._validate_table() + def _append_parameter_error( + self, parameter: str, expected_argument: str, current_argument: str + ) -> None: + if not hasattr(self, "_parameter_error"): + self._parameter_error = [] + self._parameter_error.append((parameter, expected_argument, current_argument)) + def _validate_table(self): + exceptions: List[Exception] = [] + + for _, option in self.dialect_options[DIALECT_NAME].items(): + if isinstance(option, InvalidTableOption): + exceptions.append(option.exception) + + if isinstance(self.key, NoneType) and self._enforce_primary_keys: + exceptions.append(NoPrimaryKeyError(self.__class__.__name__)) + missing_parameters: List[str] = [] + + for required_parameter in self._required_parameters: + if isinstance(self._get_dialect_option(required_parameter), NoneType): + missing_parameters.append(required_parameter.name.lower()) + if missing_parameters: + exceptions.append( + RequiredParametersNotProvidedError( + self.__class__.__name__, missing_parameters + ) + ) + if not self._support_primary_and_foreign_keys and ( self.primary_key or self.foreign_keys ): - raise ArgumentError( - f"Primary key and foreign keys are not supported in {' '.join(self.table_prefixes)} TABLE." + exceptions.append( + UnsupportedPrimaryKeysAndForeignKeysError(self.__class__.__name__) ) - return True + if len(exceptions) > 1: + exceptions.sort(key=lambda e: str(e)) + raise MultipleErrors(exceptions) + elif len(exceptions) == 1: + raise exceptions[0] + + def _get_dialect_option( + self, option_name: TableOptionKey + ) -> typing.Optional[TableOption]: + if option_name.value in self.dialect_options[DIALECT_NAME]: + return self.dialect_options[DIALECT_NAME][option_name.value] + return None - def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]: - if option_name in self.dialect_options[DIALECT_NAME]: - return self.dialect_options[DIALECT_NAME][option_name] - return NoneType + def _as_dialect_options( + self, table_options: List[TableOption] + ) -> typing.Dict[str, TableOption]: + result = {} + for table_option in table_options: + if isinstance(table_option, TableOption) and isinstance( + table_option.option_name, str + ): + result[DIALECT_NAME + "_" + table_option.option_name] = table_option + return result @classmethod def is_equal_type(cls, table: Table) -> bool: diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py index 1a2248fc..6db4312d 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py @@ -3,16 +3,21 @@ # import typing -from typing import Any +from typing import Any, Union -from sqlalchemy.exc import ArgumentError from sqlalchemy.sql.schema import MetaData, SchemaItem -from snowflake.sqlalchemy.custom_commands import NoneType - from .custom_table_prefix import CustomTablePrefix -from .options.target_lag import TargetLag -from .options.warehouse import Warehouse +from .options import ( + IdentifierOption, + IdentifierOptionType, + KeywordOptionType, + LiteralOption, + TableOptionKey, + TargetLagOption, + TargetLagOptionType, +) +from .options.keyword_option import KeywordOption from .table_from_query import TableFromQueryBase @@ -26,29 +31,69 @@ class DynamicTable(TableFromQueryBase): While it does not support reflection at this time, it provides a flexible interface for creating dynamic tables and management. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table + + Example using option values: + DynamicTable( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + target_lag=(1, TimeUnit.HOURS), + warehouse='warehouse_name', + refresh_mode=SnowflakeKeyword.AUTO + as_query="SELECT id, name from test_table_1;" + ) + + Example using full options: + DynamicTable( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + target_lag=TargetLag(1, TimeUnit.HOURS), + warehouse=Identifier('warehouse_name'), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO) + as_query=AsQuery("SELECT id, name from test_table_1;") + ) """ __table_prefixes__ = [CustomTablePrefix.DYNAMIC] - _support_primary_and_foreign_keys = False + _required_parameters = [ + TableOptionKey.WAREHOUSE, + TableOptionKey.AS_QUERY, + TableOptionKey.TARGET_LAG, + ] @property - def warehouse(self) -> typing.Optional[Warehouse]: - return self._get_dialect_option(Warehouse.__option_name__) + def warehouse(self) -> typing.Optional[LiteralOption]: + return self._get_dialect_option(TableOptionKey.WAREHOUSE) @property - def target_lag(self) -> typing.Optional[TargetLag]: - return self._get_dialect_option(TargetLag.__option_name__) + def target_lag(self) -> typing.Optional[TargetLagOption]: + return self._get_dialect_option(TableOptionKey.TARGET_LAG) def __init__( self, name: str, metadata: MetaData, *args: SchemaItem, + warehouse: IdentifierOptionType = None, + target_lag: Union[TargetLagOptionType, KeywordOptionType] = None, + refresh_mode: KeywordOptionType = None, **kw: Any, ) -> None: if kw.get("_no_init", True): return + + options = [ + IdentifierOption.create(TableOptionKey.WAREHOUSE, warehouse), + TargetLagOption.create(target_lag), + KeywordOption.create(TableOptionKey.REFRESH_MODE, refresh_mode), + ] + + kw.update(self._as_dialect_options(options)) super().__init__(name, metadata, *args, **kw) def _init( @@ -58,22 +103,7 @@ def _init( *args: SchemaItem, **kw: Any, ) -> None: - super().__init__(name, metadata, *args, **kw) - - def _validate_table(self): - missing_attributes = [] - if self.target_lag is NoneType: - missing_attributes.append("TargetLag") - if self.warehouse is NoneType: - missing_attributes.append("Warehouse") - if self.as_query is NoneType: - missing_attributes.append("AsQuery") - if missing_attributes: - raise ArgumentError( - "DYNAMIC TABLE must have the following arguments: %s" - % ", ".join(missing_attributes) - ) - super()._validate_table() + self.__init__(name, metadata, *args, _no_init=False, **kw) def __repr__(self) -> str: return "DynamicTable(%s)" % ", ".join( diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py index bd49a420..b7c29e78 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py @@ -4,11 +4,8 @@ from typing import Any -from sqlalchemy.exc import ArgumentError from sqlalchemy.sql.schema import MetaData, SchemaItem -from snowflake.sqlalchemy.custom_commands import NoneType - from .custom_table_base import CustomTableBase from .custom_table_prefix import CustomTablePrefix @@ -21,11 +18,20 @@ class HybridTable(CustomTableBase): While it does not support reflection at this time, it provides a flexible interface for creating dynamic tables and management. + + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-hybrid-table + + Example usage: + HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String) + ) """ __table_prefixes__ = [CustomTablePrefix.HYBRID] - - _support_primary_and_foreign_keys = True + _enforce_primary_keys: bool = True def __init__( self, @@ -45,18 +51,7 @@ def _init( *args: SchemaItem, **kw: Any, ) -> None: - super().__init__(name, metadata, *args, **kw) - - def _validate_table(self): - missing_attributes = [] - if self.key is NoneType: - missing_attributes.append("Primary Key") - if missing_attributes: - raise ArgumentError( - "HYBRID TABLE must have the following arguments: %s" - % ", ".join(missing_attributes) - ) - super()._validate_table() + self.__init__(name, metadata, *args, _no_init=False, **kw) def __repr__(self) -> str: return "HybridTable(%s)" % ", ".join( diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py index 052e2d96..11b54c1a 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py @@ -2,8 +2,29 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from .as_query import AsQuery -from .target_lag import TargetLag, TimeUnit -from .warehouse import Warehouse +from .as_query_option import AsQueryOption, AsQueryOptionType +from .identifier_option import IdentifierOption, IdentifierOptionType +from .keyword_option import KeywordOption, KeywordOptionType +from .keywords import SnowflakeKeyword +from .literal_option import LiteralOption, LiteralOptionType +from .table_option import TableOptionKey +from .target_lag_option import TargetLagOption, TargetLagOptionType, TimeUnit -__all__ = ["Warehouse", "AsQuery", "TargetLag", "TimeUnit"] +__all__ = [ + # Options + "IdentifierOption", + "LiteralOption", + "KeywordOption", + "AsQueryOption", + "TargetLagOption", + # Enums + "TimeUnit", + "SnowflakeKeyword", + "TableOptionKey", + # Types + "IdentifierOptionType", + "LiteralOptionType", + "AsQueryOptionType", + "TargetLagOptionType", + "KeywordOptionType", +] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py deleted file mode 100644 index 68076af9..00000000 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py +++ /dev/null @@ -1,62 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# -from typing import Union - -from sqlalchemy.sql import Selectable - -from .table_option import TableOption -from .table_option_base import Priority - - -class AsQuery(TableOption): - """Class to represent an AS clause in tables. - This configuration option is used to specify the query from which the table is created. - For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas - - - AsQuery example usage using an input string: - DynamicTable( - "sometable", metadata, - Column("name", String(50)), - Column("address", String(100)), - AsQuery('select name, address from existing_table where name = "test"') - ) - - AsQuery example usage using a selectable statement: - DynamicTable( - "sometable", - Base.metadata, - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery(select(test_table_1).where(test_table_1.c.id == 23)) - ) - - """ - - __option_name__ = "as_query" - __priority__ = Priority.LOWEST - - def __init__(self, query: Union[str, Selectable]) -> None: - r"""Construct an as_query object. - - :param \*expressions: - AS - - """ - self.query = query - - @staticmethod - def template() -> str: - return "AS %s" - - def get_expression(self): - if isinstance(self.query, Selectable): - return self.query.compile(compile_kwargs={"literal_binds": True}) - return self.query - - def render_option(self, compiler) -> str: - return AsQuery.template() % (self.get_expression()) - - def __repr__(self) -> str: - return "Query(%s)" % self.get_expression() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py new file mode 100644 index 00000000..70adb4a9 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py @@ -0,0 +1,64 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from sqlalchemy.sql import Selectable + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class AsQueryOption(TableOption): + """Class to represent an AS clause in tables. + This configuration option is used to specify the query from which the table is created. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas + + Example: + as_query=AsQuery('select name, address from existing_table where name = "test"') + + is equivalent to: + + as select name, address from existing_table where name = "test" + """ + + def __init__(self, query: Union[str, Selectable]) -> None: + super().__init__() + self._name: TableOptionKey = TableOptionKey.AS_QUERY + self.query = query + + @staticmethod + def create( + value: Optional[Union["AsQueryOption", str, Selectable]] + ) -> "TableOption": + if isinstance(value, NoneType) or isinstance(value, AsQueryOption): + return value + if isinstance(value, str) or isinstance(value, Selectable): + return AsQueryOption(value) + return TableOption._get_invalid_table_option( + TableOptionKey.AS_QUERY, + str(type(value).__name__), + [AsQueryOption.__name__, str.__name__, Selectable.__name__], + ) + + def template(self) -> str: + return "AS %s" + + @property + def priority(self) -> Priority: + return Priority.LOWEST + + def __get_expression(self): + if isinstance(self.query, Selectable): + return self.query.compile(compile_kwargs={"literal_binds": True}) + return self.query + + def _render(self, compiler) -> str: + return self.template() % (self.__get_expression()) + + def __repr__(self) -> str: + return "AsQuery(%s)" % self.__get_expression() + + +AsQueryOptionType = Union[AsQueryOption, str, Selectable] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py new file mode 100644 index 00000000..dad34cbe --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class IdentifierOption(TableOption): + """Class to represent an identifier option in Snowflake Tables. + + Example: + warehouse = Identifier('my_warehouse') + + is equivalent to: + + WAREHOUSE = my_warehouse + """ + + def __init__(self, value: Union[str]) -> None: + super().__init__() + self.value: str = value + + @property + def priority(self): + return Priority.HIGH + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[str, "IdentifierOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return None + + if isinstance(value, str): + value = IdentifierOption(value) + + if isinstance(value, IdentifierOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, str(type(value).__name__), [IdentifierOption.__name__, str.__name__] + ) + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def _render(self, compiler) -> str: + return self.template() % self.value + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if not isinstance(self.option_name, NoneType) + else "" + ) + return f"IdentifierOption(value='{self.value}'{option_name})" + + +IdentifierOptionType = Union[IdentifierOption, str, int] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py new file mode 100644 index 00000000..2bdc9dd3 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional + +from .table_option import TableOption, TableOptionKey + + +class InvalidTableOption(TableOption): + """Class to store errors and raise them after table initialization in order to avoid recursion error.""" + + def __init__(self, name: TableOptionKey, value: Exception) -> None: + super().__init__() + self.exception: Exception = value + self._name = name + + @staticmethod + def create(name: TableOptionKey, value: Exception) -> Optional[TableOption]: + return InvalidTableOption(name, value) + + def _render(self, compiler) -> str: + raise self.exception + + def __repr__(self) -> str: + return f"ErrorOption(value='{self.exception}')" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py new file mode 100644 index 00000000..391dc5c5 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .keywords import SnowflakeKeyword +from .table_option import Priority, TableOption, TableOptionKey + + +class KeywordOption(TableOption): + """Class to represent a keyword option in Snowflake Tables. + + Example: + target_lag = Keyword(SnowflakeKeyword.DOWNSTREAM) + + is equivalent to: + + TARGET_LAG = DOWNSTREAM + """ + + def __init__(self, value: Union[SnowflakeKeyword]) -> None: + super().__init__() + self.value: str = value.value + + @property + def priority(self): + return Priority.HIGH + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def _render(self, compiler) -> str: + return self.template() % self.value.upper() + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[SnowflakeKeyword, "KeywordOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return value + if isinstance(value, SnowflakeKeyword): + value = KeywordOption(value) + + if isinstance(value, KeywordOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, + str(type(value).__name__), + [KeywordOption.__name__, SnowflakeKeyword.__name__], + ) + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if isinstance(self.option_name, NoneType) + else "" + ) + return f"KeywordOption(value='{self.value}'{option_name})" + + +KeywordOptionType = Union[KeywordOption, SnowflakeKeyword] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py new file mode 100644 index 00000000..f4ba87ea --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py @@ -0,0 +1,14 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +from enum import Enum + + +class SnowflakeKeyword(Enum): + # TARGET_LAG + DOWNSTREAM = "DOWNSTREAM" + + # REFRESH_MODE + AUTO = "AUTO" + FULL = "FULL" + INCREMENTAL = "INCREMENTAL" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py new file mode 100644 index 00000000..de15473d --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Any, Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class LiteralOption(TableOption): + """Class to represent a literal option in Snowflake Table. + + Example: + warehouse = Literal('my_warehouse') + + is equivalent to: + + WAREHOUSE = 'my_warehouse' + """ + + def __init__(self, value: Union[int, str]) -> None: + super().__init__() + self.value: Any = value + + @property + def priority(self): + return Priority.HIGH + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[str, int, "LiteralOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return None + if isinstance(value, (str, int)): + value = LiteralOption(value) + + if isinstance(value, LiteralOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, + str(type(value).__name__), + [LiteralOption.__name__, str.__name__, int.__name__], + ) + + def template(self) -> str: + if isinstance(self.value, int): + return f"{self.option_name.upper()} = %d" + else: + return f"{self.option_name.upper()} = '%s'" + + def _render(self, compiler) -> str: + return self.template() % self.value + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if not isinstance(self.option_name, NoneType) + else "" + ) + return f"LiteralOption(value='{self.value}'{option_name})" + + +LiteralOptionType = Union[LiteralOption, str, int] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py index 7ac27575..14b91f2e 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py @@ -1,26 +1,83 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from typing import Any +from enum import Enum +from typing import List, Optional -from sqlalchemy import exc -from sqlalchemy.sql.base import SchemaEventTarget -from sqlalchemy.sql.schema import SchemaItem, Table +from snowflake.sqlalchemy import exc +from snowflake.sqlalchemy.custom_commands import NoneType -from snowflake.sqlalchemy._constants import DIALECT_NAME -from .table_option_base import TableOptionBase +class Priority(Enum): + LOWEST = 0 + VERY_LOW = 1 + LOW = 2 + MEDIUM = 4 + HIGH = 6 + VERY_HIGH = 7 + HIGHEST = 8 -class TableOption(TableOptionBase, SchemaItem): - def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: - if self.__option_name__ == "default": - raise exc.SQLAlchemyError(f"{self.__class__.__name__} does not has a name") - if not isinstance(parent, Table): - raise exc.SQLAlchemyError( - f"{self.__class__.__name__} option can only be applied to Table" - ) - parent.dialect_options[DIALECT_NAME][self.__option_name__] = self +class TableOption: - def _set_table_option_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: - pass + def __init__(self) -> None: + self._name: Optional[TableOptionKey] = None + + @property + def option_name(self) -> str: + if isinstance(self._name, NoneType): + return None + return str(self._name.value) + + def _set_option_name(self, name: Optional["TableOptionKey"]): + self._name = name + + @property + def priority(self) -> Priority: + return Priority.MEDIUM + + @staticmethod + def create(**kwargs) -> "TableOption": + raise NotImplementedError + + @staticmethod + def _get_invalid_table_option( + parameter_name: "TableOptionKey", input_type: str, expected_types: List[str] + ) -> "TableOption": + from .invalid_table_option import InvalidTableOption + + return InvalidTableOption( + parameter_name, + exc.InvalidTableParameterTypeError( + parameter_name.value, input_type, expected_types + ), + ) + + def _validate_option(self): + if isinstance(self.option_name, NoneType): + raise exc.OptionKeyNotProvidedError(self.__class__.__name__) + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def render_option(self, compiler) -> str: + self._validate_option() + return self._render(compiler) + + def _render(self, compiler) -> str: + raise NotImplementedError + + +class TableOptionKey(Enum): + AS_QUERY = "as_query" + BASE_LOCATION = "base_location" + CATALOG = "catalog" + CATALOG_SYNC = "catalog_sync" + DATA_RETENTION_TIME_IN_DAYS = "data_retention_time_in_days" + DEFAULT_DDL_COLLATION = "default_ddl_collation" + EXTERNAL_VOLUME = "external_volume" + MAX_DATA_EXTENSION_TIME_IN_DAYS = "max_data_extension_time_in_days" + REFRESH_MODE = "refresh_mode" + STORAGE_SERIALIZATION_POLICY = "storage_serialization_policy" + TARGET_LAG = "target_lag" + WAREHOUSE = "warehouse" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py deleted file mode 100644 index 54008ec8..00000000 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py +++ /dev/null @@ -1,30 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# -from enum import Enum - - -class Priority(Enum): - LOWEST = 0 - VERY_LOW = 1 - LOW = 2 - MEDIUM = 4 - HIGH = 6 - VERY_HIGH = 7 - HIGHEST = 8 - - -class TableOptionBase: - __option_name__ = "default" - __visit_name__ = __option_name__ - __priority__ = Priority.MEDIUM - - @staticmethod - def template() -> str: - raise NotImplementedError - - def get_expression(self): - raise NotImplementedError - - def render_option(self, compiler) -> str: - raise NotImplementedError diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py deleted file mode 100644 index 4331a4cb..00000000 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py +++ /dev/null @@ -1,60 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# from enum import Enum -from enum import Enum -from typing import Optional - -from .table_option import TableOption -from .table_option_base import Priority - - -class TimeUnit(Enum): - SECONDS = "seconds" - MINUTES = "minutes" - HOURS = "hour" - DAYS = "days" - - -class TargetLag(TableOption): - """Class to represent the target lag clause. - This configuration option is used to specify the target lag time for the dynamic table. - For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table - - - Target lag example usage: - DynamicTable("sometable", metadata, - Column("name", String(50)), - Column("address", String(100)), - TargetLag(20, TimeUnit.MINUTES), - ) - """ - - __option_name__ = "target_lag" - __priority__ = Priority.HIGH - - def __init__( - self, - time: Optional[int] = 0, - unit: Optional[TimeUnit] = TimeUnit.MINUTES, - down_stream: Optional[bool] = False, - ) -> None: - self.time = time - self.unit = unit - self.down_stream = down_stream - - @staticmethod - def template() -> str: - return "TARGET_LAG = %s" - - def get_expression(self): - return ( - f"'{str(self.time)} {str(self.unit.value)}'" - if not self.down_stream - else "DOWNSTREAM" - ) - - def render_option(self, compiler) -> str: - return TargetLag.template() % (self.get_expression()) - - def __repr__(self) -> str: - return "TargetLag(%s)" % self.get_expression() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py new file mode 100644 index 00000000..2088c729 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py @@ -0,0 +1,96 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# from enum import Enum +from enum import Enum +from typing import Optional, Tuple, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .keyword_option import KeywordOption, KeywordOptionType +from .keywords import SnowflakeKeyword +from .table_option import Priority, TableOption, TableOptionKey + + +class TimeUnit(Enum): + SECONDS = "seconds" + MINUTES = "minutes" + HOURS = "hours" + DAYS = "days" + + +class TargetLagOption(TableOption): + """Class to represent the target lag clause. + This configuration option is used to specify the target lag time for the dynamic table. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table + + + Example using the time and unit parameters: + + target_lag = TargetLag(10, TimeUnit.SECONDS) + + is equivalent to: + + TARGET_LAG = '10 SECONDS' + + Example using keyword parameter: + + target_lag = KeywordOption(SnowflakeKeyword.DOWNSTREAM) + + is equivalent to: + + TARGET_LAG = DOWNSTREAM + + """ + + def __init__( + self, + time: Optional[int] = 0, + unit: Optional[TimeUnit] = TimeUnit.MINUTES, + ) -> None: + super().__init__() + self.time = time + self.unit = unit + self._name: TableOptionKey = TableOptionKey.TARGET_LAG + + @staticmethod + def create( + value: Union["TargetLagOption", Tuple[int, TimeUnit], KeywordOptionType] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return value + + if isinstance(value, Tuple): + time, unit = value + value = TargetLagOption(time, unit) + + if isinstance(value, TargetLagOption): + return value + + if isinstance(value, (KeywordOption, SnowflakeKeyword)): + return KeywordOption.create(TableOptionKey.TARGET_LAG, value) + + return TableOption._get_invalid_table_option( + TableOptionKey.TARGET_LAG, + str(type(value).__name__), + [ + TargetLagOption.__name__, + f"Tuple[int, {TimeUnit.__name__}])", + SnowflakeKeyword.__name__, + ], + ) + + def __get_expression(self): + return f"'{str(self.time)} {str(self.unit.value)}'" + + @property + def priority(self) -> Priority: + return Priority.HIGH + + def _render(self, compiler) -> str: + return self.template() % (self.__get_expression()) + + def __repr__(self) -> str: + return "TargetLag(%s)" % self.__get_expression() + + +TargetLagOptionType = Union[TargetLagOption, Tuple[int, TimeUnit]] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py deleted file mode 100644 index a5b8cce0..00000000 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py +++ /dev/null @@ -1,51 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# -from typing import Optional - -from .table_option import TableOption -from .table_option_base import Priority - - -class Warehouse(TableOption): - """Class to represent the warehouse clause. - This configuration option is used to specify the warehouse for the dynamic table. - For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table - - - Warehouse example usage: - DynamicTable("sometable", metadata, - Column("name", String(50)), - Column("address", String(100)), - Warehouse('my_warehouse_name') - ) - """ - - __option_name__ = "warehouse" - __priority__ = Priority.HIGH - - def __init__( - self, - name: Optional[str], - ) -> None: - r"""Construct a Warehouse object. - - :param \*expressions: - Dynamic table warehouse option. - WAREHOUSE = - - """ - self.name = name - - @staticmethod - def template() -> str: - return "WAREHOUSE = %s" - - def get_expression(self): - return self.name - - def render_option(self, compiler) -> str: - return Warehouse.template() % (self.get_expression()) - - def __repr__(self) -> str: - return "Warehouse(%s)" % self.get_expression() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py index 60e8995f..fccc7a0b 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py @@ -6,29 +6,31 @@ from sqlalchemy.sql import Selectable from sqlalchemy.sql.schema import Column, MetaData, SchemaItem -from sqlalchemy.util import NoneType from .custom_table_base import CustomTableBase -from .options.as_query import AsQuery +from .options.as_query_option import AsQueryOption, AsQueryOptionType +from .options.table_option import TableOptionKey class TableFromQueryBase(CustomTableBase): @property - def as_query(self): - return self._get_dialect_option(AsQuery.__option_name__) + def as_query(self) -> Optional[AsQueryOption]: + return self._get_dialect_option(TableOptionKey.AS_QUERY) def __init__( self, name: str, metadata: MetaData, *args: SchemaItem, + as_query: AsQueryOptionType = None, **kw: Any, ) -> None: items = [item for item in args] - as_query: AsQuery = self.__get_as_query_from_items(items) + as_query = AsQueryOption.create(as_query) # noqa + kw.update(self._as_dialect_options([as_query])) if ( - as_query is not NoneType + isinstance(as_query, AsQueryOption) and isinstance(as_query.query, Selectable) and not self.__has_defined_columns(items) ): @@ -36,14 +38,6 @@ def __init__( args = items + columns super().__init__(name, metadata, *args, **kw) - def __get_as_query_from_items( - self, items: typing.List[SchemaItem] - ) -> Optional[AsQuery]: - for item in items: - if isinstance(item, AsQuery): - return item - return NoneType - def __has_defined_columns(self, items: typing.List[SchemaItem]) -> bool: for item in items: if isinstance(item, Column): diff --git a/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr index 81c7f90f..95d2e5c6 100644 --- a/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr +++ b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr @@ -6,8 +6,50 @@ "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" # --- # name: test_compile_dynamic_table_orm_with_str_keys - "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" + 'CREATE DYNAMIC TABLE "SCHEMA_DB".test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = \'10 seconds\'\tAS SELECT * FROM table' +# --- +# name: test_compile_dynamic_table_with_multiple_wrong_option_types + ''' + Invalid parameter type 'IdentifierOption' provided for 'refresh_mode'. Expected one of the following types: 'KeywordOption', 'SnowflakeKeyword'. + + Invalid parameter type 'IdentifierOption' provided for 'target_lag'. Expected one of the following types: 'TargetLagOption', 'Tuple[int, TimeUnit])', 'SnowflakeKeyword'. + + Invalid parameter type 'KeywordOption' provided for 'as_query'. Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'. + + Invalid parameter type 'KeywordOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. + + ''' +# --- +# name: test_compile_dynamic_table_with_one_wrong_option_types + ''' + Invalid parameter type 'LiteralOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. + + ''' +# --- +# name: test_compile_dynamic_table_with_options_objects + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = AUTO\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.AUTO] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = AUTO\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.FULL] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = FULL\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.INCREMENTAL] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = INCREMENTAL\tAS SELECT * FROM table" # --- # name: test_compile_dynamic_table_with_selectable "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" # --- +# name: test_compile_dynamic_table_with_wrong_option_types + ''' + Invalid parameter type 'IdentifierOption' provided for 'refresh_mode'. Expected one of the following types: 'KeywordOption', 'SnowflakeKeyword'. + + Invalid parameter type 'IdentifierOption' provided for 'target_lag'. Expected one of the following types: 'TargetLagOption', 'Tuple[int, TimeUnit])', 'SnowflakeKeyword'. + + Invalid parameter type 'KeywordOption' provided for 'as_query'. Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'. + + Invalid parameter type 'KeywordOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. + + ''' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr new file mode 100644 index 00000000..80201495 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_create_dynamic_table_without_dynamictable_and_defined_options + CustomOptionsAreOnlySupportedOnSnowflakeTables('Identifier, Literal, TargetLag and other custom options are only supported on Snowflake tables.') +# --- +# name: test_create_dynamic_table_without_dynamictable_class + UnexpectedOptionTypeError('The following options are either unsupported or should be defined using a Snowflake table: as_query, warehouse.') +# --- diff --git a/tests/custom_tables/__snapshots__/test_generic_options.ambr b/tests/custom_tables/__snapshots__/test_generic_options.ambr new file mode 100644 index 00000000..fe84351a --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_generic_options.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_identifier_option_with_wrong_type + OptionKeyNotProvidedError('Expected option key in IdentifierOption option but got NoneType instead.') +# --- +# name: test_identifier_option_without_name + OptionKeyNotProvidedError('Expected option key in IdentifierOption option but got NoneType instead.') +# --- +# name: test_invalid_as_query_option + OptionKeyNotProvidedError('Expected option key in IdentifierOption option but got NoneType instead.') +# --- +# name: test_literal_option_with_wrong_type + OptionKeyNotProvidedError('Expected option key in LiteralOption option but got NoneType instead.') +# --- diff --git a/tests/custom_tables/test_compile_dynamic_table.py b/tests/custom_tables/test_compile_dynamic_table.py index 16a039e7..935c61cd 100644 --- a/tests/custom_tables/test_compile_dynamic_table.py +++ b/tests/custom_tables/test_compile_dynamic_table.py @@ -12,16 +12,21 @@ exc, select, ) +from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import declarative_base from sqlalchemy.sql.ddl import CreateTable from snowflake.sqlalchemy import GEOMETRY, DynamicTable -from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery -from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( - TargetLag, +from snowflake.sqlalchemy.exc import MultipleErrors +from snowflake.sqlalchemy.sql.custom_schema.options import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + TargetLagOption, TimeUnit, ) -from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse +from snowflake.sqlalchemy.sql.custom_schema.options.keywords import SnowflakeKeyword def test_compile_dynamic_table(sql_compiler, snapshot): @@ -32,9 +37,9 @@ def test_compile_dynamic_table(sql_compiler, snapshot): metadata, Column("id", Integer), Column("geom", GEOMETRY), - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", ) value = CreateTable(test_geometry) @@ -44,11 +49,99 @@ def test_compile_dynamic_table(sql_compiler, snapshot): assert actual == snapshot +@pytest.mark.parametrize( + "refresh_mode_keyword", + [ + SnowflakeKeyword.AUTO, + SnowflakeKeyword.FULL, + SnowflakeKeyword.INCREMENTAL, + ], +) +def test_compile_dynamic_table_with_refresh_mode( + sql_compiler, snapshot, refresh_mode_keyword +): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + refresh_mode=refresh_mode_keyword, + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_options_objects(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=TargetLagOption(10, TimeUnit.SECONDS), + warehouse=IdentifierOption("warehouse"), + as_query=AsQueryOption("SELECT * FROM table"), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_one_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + with pytest.raises(ArgumentError) as argument_error: + DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=TargetLagOption(10, TimeUnit.SECONDS), + warehouse=LiteralOption("warehouse"), + as_query=AsQueryOption("SELECT * FROM table"), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_dynamic_table_with_multiple_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + with pytest.raises(MultipleErrors) as argument_error: + DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=IdentifierOption(SnowflakeKeyword.AUTO), + warehouse=KeywordOption(SnowflakeKeyword.AUTO), + as_query=KeywordOption(SnowflakeKeyword.AUTO), + refresh_mode=IdentifierOption(SnowflakeKeyword.AUTO), + ) + + assert str(argument_error.value) == snapshot + + def test_compile_dynamic_table_without_required_args(sql_compiler): with pytest.raises( exc.ArgumentError, - match="DYNAMIC TABLE must have the following arguments: TargetLag, " - "Warehouse, AsQuery", + match="DynamicTable requires the following parameters: warehouse, " + "as_query, target_lag.", ): DynamicTable( "test_dynamic_table", @@ -61,33 +154,33 @@ def test_compile_dynamic_table_without_required_args(sql_compiler): def test_compile_dynamic_table_with_primary_key(sql_compiler): with pytest.raises( exc.ArgumentError, - match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + match="Primary key and foreign keys are not supported in DynamicTable.", ): DynamicTable( "test_dynamic_table", MetaData(), Column("id", Integer, primary_key=True), Column("geom", GEOMETRY), - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", ) def test_compile_dynamic_table_with_foreign_key(sql_compiler): with pytest.raises( exc.ArgumentError, - match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + match="Primary key and foreign keys are not supported in DynamicTable.", ): DynamicTable( "test_dynamic_table", MetaData(), Column("id", Integer), Column("geom", GEOMETRY), - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), ForeignKeyConstraint(["id"], ["table.id"]), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", ) @@ -100,9 +193,9 @@ def test_compile_dynamic_table_orm(sql_compiler, snapshot): metadata, Column("id", Integer), Column("name", String), - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", ) class TestDynamicTableOrm(Base): @@ -121,23 +214,22 @@ def __repr__(self): assert actual == snapshot -def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, db_parameters, snapshot): +def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, snapshot): Base = declarative_base() - schema = db_parameters["schema"] class TestDynamicTableOrm(Base): __tablename__ = "test_dynamic_table_orm_2" - __table_args__ = {"schema": schema} @classmethod def __table_cls__(cls, name, metadata, *arg, **kw): return DynamicTable(name, metadata, *arg, **kw) - __table_args__ = ( - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), - ) + __table_args__ = { + "schema": "SCHEMA_DB", + "target_lag": (10, TimeUnit.SECONDS), + "warehouse": "warehouse", + "as_query": "SELECT * FROM table", + } id = Column(Integer) name = Column(String) @@ -167,9 +259,9 @@ def test_compile_dynamic_table_with_selectable(sql_compiler, snapshot): dynamic_test_table = DynamicTable( "dynamic_test_table_1", Base.metadata, - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery(select(test_table_1).where(test_table_1.c.id == 23)), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query=select(test_table_1).where(test_table_1.c.id == 23), ) value = CreateTable(dynamic_test_table) diff --git a/tests/custom_tables/test_create_dynamic_table.py b/tests/custom_tables/test_create_dynamic_table.py index 4e6c48ca..b583faad 100644 --- a/tests/custom_tables/test_create_dynamic_table.py +++ b/tests/custom_tables/test_create_dynamic_table.py @@ -1,15 +1,20 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # +import pytest from sqlalchemy import Column, Integer, MetaData, String, Table, select -from snowflake.sqlalchemy import DynamicTable -from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery -from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( - TargetLag, +from snowflake.sqlalchemy import DynamicTable, exc +from snowflake.sqlalchemy.sql.custom_schema.options.as_query_option import AsQueryOption +from snowflake.sqlalchemy.sql.custom_schema.options.identifier_option import ( + IdentifierOption, +) +from snowflake.sqlalchemy.sql.custom_schema.options.keywords import SnowflakeKeyword +from snowflake.sqlalchemy.sql.custom_schema.options.table_option import TableOptionKey +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag_option import ( + TargetLagOption, TimeUnit, ) -from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse def test_create_dynamic_table(engine_testaccount, db_parameters): @@ -32,9 +37,10 @@ def test_create_dynamic_table(engine_testaccount, db_parameters): metadata, Column("id", Integer), Column("name", String), - TargetLag(1, TimeUnit.HOURS), - Warehouse(warehouse), - AsQuery("SELECT id, name from test_table_1;"), + target_lag=(1, TimeUnit.HOURS), + warehouse=warehouse, + as_query="SELECT id, name from test_table_1;", + refresh_mode=SnowflakeKeyword.FULL, ) metadata.create_all(engine_testaccount) @@ -52,7 +58,7 @@ def test_create_dynamic_table(engine_testaccount, db_parameters): def test_create_dynamic_table_without_dynamictable_class( - engine_testaccount, db_parameters + engine_testaccount, db_parameters, snapshot ): warehouse = db_parameters.get("warehouse", "default") metadata = MetaData() @@ -68,26 +74,51 @@ def test_create_dynamic_table_without_dynamictable_class( conn.execute(ins) conn.commit() - dynamic_test_table_1 = Table( + Table( "dynamic_test_table_1", metadata, Column("id", Integer), Column("name", String), - TargetLag(1, TimeUnit.HOURS), - Warehouse(warehouse), - AsQuery("SELECT id, name from test_table_1;"), + snowflake_warehouse=warehouse, + snowflake_as_query="SELECT id, name from test_table_1;", prefixes=["DYNAMIC"], ) + with pytest.raises(exc.UnexpectedOptionTypeError) as exc_info: + metadata.create_all(engine_testaccount) + assert exc_info.value == snapshot + + +def test_create_dynamic_table_without_dynamictable_and_defined_options( + engine_testaccount, db_parameters, snapshot +): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + metadata.create_all(engine_testaccount) - try: - with engine_testaccount.connect() as conn: - s = select(dynamic_test_table_1) - results_dynamic_table = conn.execute(s).fetchall() - s = select(test_table_1) - results_table = conn.execute(s).fetchall() - assert results_dynamic_table == results_table + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") - finally: - metadata.drop_all(engine_testaccount) + conn.execute(ins) + conn.commit() + + Table( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + snowflake_target_lag=TargetLagOption.create((1, TimeUnit.HOURS)), + snowflake_warehouse=IdentifierOption.create( + TableOptionKey.WAREHOUSE, warehouse + ), + snowflake_as_query=AsQueryOption.create("SELECT id, name from test_table_1;"), + prefixes=["DYNAMIC"], + ) + + with pytest.raises(exc.CustomOptionsAreOnlySupportedOnSnowflakeTables) as exc_info: + metadata.create_all(engine_testaccount) + assert exc_info.value == snapshot diff --git a/tests/custom_tables/test_generic_options.py b/tests/custom_tables/test_generic_options.py new file mode 100644 index 00000000..040d1c25 --- /dev/null +++ b/tests/custom_tables/test_generic_options.py @@ -0,0 +1,80 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import pytest + +from snowflake.sqlalchemy import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + TableOptionKey, + TargetLagOption, + exc, +) +from snowflake.sqlalchemy.sql.custom_schema.options.invalid_table_option import ( + InvalidTableOption, +) + + +def test_identifier_option(): + identifier = IdentifierOption.create(TableOptionKey.WAREHOUSE, "xsmall") + assert identifier.render_option(None) == "WAREHOUSE = xsmall" + + +def test_literal_option(): + literal = LiteralOption.create(TableOptionKey.WAREHOUSE, "xsmall") + assert literal.render_option(None) == "WAREHOUSE = 'xsmall'" + + +def test_identifier_option_without_name(snapshot): + identifier = IdentifierOption("xsmall") + with pytest.raises(exc.OptionKeyNotProvidedError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +def test_identifier_option_with_wrong_type(snapshot): + identifier = IdentifierOption(23) + with pytest.raises(exc.OptionKeyNotProvidedError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +def test_literal_option_with_wrong_type(snapshot): + literal = LiteralOption(0.32) + with pytest.raises(exc.OptionKeyNotProvidedError) as exc_info: + literal.render_option(None) + assert exc_info.value == snapshot + + +def test_invalid_as_query_option(snapshot): + identifier = IdentifierOption(23) + with pytest.raises(exc.OptionKeyNotProvidedError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +@pytest.mark.parametrize( + "table_option", + [ + IdentifierOption, + LiteralOption, + KeywordOption, + ], +) +def test_generic_option_with_wrong_type(table_option): + literal = table_option.create(TableOptionKey.WAREHOUSE, 0.32) + assert isinstance(literal, InvalidTableOption), "Expected InvalidTableOption" + + +@pytest.mark.parametrize( + "table_option", + [ + TargetLagOption, + AsQueryOption, + ], +) +def test_non_generic_option_with_wrong_type(table_option): + literal = table_option.create(0.32) + assert isinstance(literal, InvalidTableOption), "Expected InvalidTableOption" diff --git a/tests/custom_tables/test_reflect_dynamic_table.py b/tests/custom_tables/test_reflect_dynamic_table.py index 8a4a8445..52eb4457 100644 --- a/tests/custom_tables/test_reflect_dynamic_table.py +++ b/tests/custom_tables/test_reflect_dynamic_table.py @@ -74,7 +74,7 @@ def test_simple_reflection_without_options_loading(engine_testaccount, db_parame ) # TODO: Add support for loading options when table is reflected - assert dynamic_test_table.warehouse is NoneType + assert isinstance(dynamic_test_table.warehouse, NoneType) try: with engine_testaccount.connect() as conn: